Coverage for pySDC/implementations/transfer_classes/TransferMesh_MPIFFT.py: 86%

70 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 14:07 +0000

1from pySDC.core.Errors import TransferError 

2from pySDC.core.SpaceTransfer import space_transfer 

3from mpi4py_fft import PFFT, newDistArray 

4 

5 

class fft_to_fft(space_transfer):
    """
    Custom base_transfer class, implements Transfer.py

    This implementation can restrict and prolong between PMESH datatypes meshes with FFT for periodic boundaries.

    Restriction is strided injection in physical space. Spectral data is first transformed back with the fine FFT.
    Every ``ratio``-th point along each axis is kept, and the result is transformed forward with the coarse FFT if
    the problem works in spectral space.

    Prolongation goes through ``self.fft_pad``, an mpi4py-fft ``PFFT`` on the coarse grid created with
    ``padding=self.ratio``. Its backward transform produces data at the fine resolution.

    For problems that define ``ncomp``, the solution components may be stored either in the last or in the first
    array axis. Each component is transferred separately.
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        # invoke super initialization
        super().__init__(fine_prob, coarse_prob, params)

        # both levels must store their data in the same space (physical or spectral)
        assert self.fine_prob.spectral == self.coarse_prob.spectral

        self.spectral = self.fine_prob.spectral

        Nf = list(self.fine_prob.fft.global_shape())
        Nc = list(self.coarse_prob.fft.global_shape())
        # integer coarsening factor per axis
        self.ratio = [nf // nc for nf, nc in zip(Nf, Nc)]
        axes = tuple(range(len(Nf)))

        fft_args = {}
        # cupy-based data types need the cupy FFT backend and NCCL communication
        useGPU = 'cupy' in self.fine_prob.dtype_u.__name__.lower()
        if useGPU:
            fft_args['backend'] = 'cupy'
            fft_args['comm_backend'] = 'NCCL'

        # coarse-grid transform padded by `ratio`; used by prolong to move coarse data to the fine grid
        self.fft_pad = PFFT(
            self.coarse_prob.comm,
            Nc,
            padding=self.ratio,
            axes=axes,
            dtype=self.coarse_prob.fft.dtype(False),
            slab=True,
            **fft_args,
        )

    def _component_indices(self, u, action):
        """
        Index expressions selecting the individual solution components of ``u``

        Args:
            u: array of one level (fine or coarse)
            action (str): name of the operation ('restrict' or 'prolong'), only used in the error message

        Returns:
            list: ``[Ellipsis]`` if the fine problem has no ``ncomp`` attribute (the whole array is one component),
            otherwise one index tuple per component, for components stored in the last or in the first axis

        Raises:
            TransferError: if the problem defines ``ncomp`` but neither the last nor the first axis has that length
        """
        if not hasattr(self.fine_prob, 'ncomp'):
            return [Ellipsis]

        ncomp = self.fine_prob.ncomp
        # a components-last layout takes precedence if both end axes happen to match
        if u.shape[-1] == ncomp:
            return [(Ellipsis, i) for i in range(ncomp)]
        if u.shape[0] == ncomp:
            return [(i, Ellipsis) for i in range(ncomp)]
        raise TransferError(f'Don\'t know how to {action} for this problem with multiple components')

    def restrict(self, F):
        """
        Restriction implementation

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            the coarse level data, of the same type as ``F``

        Raises:
            TransferError: for unsupported data types or unsupported component layouts
        """
        G = type(F)(self.coarse_prob.init)
        # keep every ratio-th point along each spatial axis, for any number of dimensions
        subsample = tuple(slice(None, None, r) for r in self.ratio)

        def _restrict(fine, coarse):
            indices = self._component_indices(fine, 'restrict')
            if self.spectral:
                # multi-component data reuses one physical-space work array for all components;
                # with None, backward() chooses its own output array
                tmpF = newDistArray(self.fine_prob.fft, False) if hasattr(self.fine_prob, 'ncomp') else None
                for idx in indices:
                    tmpF = self.fine_prob.fft.backward(fine[idx], tmpF)
                    coarse[idx] = self.coarse_prob.fft.forward(tmpF[subsample], coarse[idx])
            else:
                for idx in indices:
                    coarse[idx] = fine[idx][subsample]

        if hasattr(type(F), 'components'):
            for comp in F.components:
                _restrict(F.__getattr__(comp), G.__getattr__(comp))
        elif type(F).__name__ in ['mesh', 'cupy_mesh']:
            _restrict(F, G)
        else:
            raise TransferError('Wrong data type for restriction, got %s' % type(F))

        return G

    def prolong(self, G):
        """
        Prolongation implementation

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            the fine level data, of the same type as ``G``

        Raises:
            TransferError: for unsupported data types or unsupported component layouts
        """
        F = type(G)(self.fine_prob.init)

        def _prolong(coarse, fine):
            for idx in self._component_indices(coarse, 'prolong'):
                if self.spectral:
                    # padded backward transform gives the component on the fine grid in physical space
                    tmpF = self.fft_pad.backward(coarse[idx])
                    fine[idx] = self.fine_prob.fft.forward(tmpF, fine[idx])
                else:
                    # go to coarse spectral space, then back through the padded transform onto the fine grid
                    G_hat = self.coarse_prob.fft.forward(coarse[idx])
                    fine[idx] = self.fft_pad.backward(G_hat, fine[idx])

        if hasattr(type(F), 'components'):
            for comp in F.components:
                _prolong(G.__getattr__(comp), F.__getattr__(comp))
        elif type(G).__name__ in ['mesh', 'cupy_mesh']:
            _prolong(G, F)
        else:
            raise TransferError('Unknown data type, got %s' % type(G))

        return F