Coverage for pySDC/implementations/datatype_classes/cupy_mesh.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 14:07 +0000

1import cupy as cp 

2from pySDC.core.Errors import DataError 

3 

4try: 

5 from mpi4py import MPI 

6except ImportError: 

7 MPI = None 

8 

9 

10class cupy_mesh(cp.ndarray): 

11 """ 

12 CuPy-based datatype for serial or parallel meshes. 

13 """ 

14 

15 def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None): 

16 """ 

17 Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh. 

18 

19 Args: 

20 init: either another mesh or a tuple containing the dimensions, the communicator and the dtype 

21 val: value to initialize 

22 

23 Returns: 

24 obj of type mesh 

25 

26 """ 

27 if isinstance(init, cupy_mesh): 

28 obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, strides=strides, order=order) 

29 obj[:] = init[:] 

30 obj._comm = init._comm 

31 elif ( 

32 isinstance(init, tuple) 

33 and (init[1] is None or isinstance(init[1], MPI.Intracomm)) 

34 and isinstance(init[2], cp.dtype) 

35 ): 

36 obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], strides=strides, order=order) 

37 obj.fill(val) 

38 obj._comm = init[1] 

39 else: 

40 raise NotImplementedError(type(init)) 

41 return obj 

42 

43 @property 

44 def comm(self): 

45 """ 

46 Getter for the communicator 

47 """ 

48 return self._comm 

49 

50 def __array_finalize__(self, obj): 

51 """ 

52 Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator. 

53 """ 

54 if obj is None: 

55 return 

56 self._comm = getattr(obj, '_comm', None) 

57 

58 def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs): 

59 """ 

60 Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs 

61 """ 

62 args = [] 

63 comm = None 

64 for _, input_ in enumerate(inputs): 

65 if isinstance(input_, cupy_mesh): 

66 args.append(input_.view(cp.ndarray)) 

67 comm = input_.comm 

68 else: 

69 args.append(input_) 

70 results = super(cupy_mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs).view(cupy_mesh) 

71 if not method == 'reduce': 

72 results._comm = comm 

73 return results 

74 

75 def __abs__(self): 

76 """ 

77 Overloading the abs operator 

78 

79 Returns: 

80 float: absolute maximum of all mesh values 

81 """ 

82 # take absolute values of the mesh values 

83 local_absval = float(cp.amax(cp.ndarray.__abs__(self))) 

84 

85 if self.comm is not None: 

86 if self.comm.Get_size() > 1: 

87 global_absval = 0.0 

88 global_absval = max(self.comm.allreduce(sendobj=local_absval, op=MPI.MAX), global_absval) 

89 else: 

90 global_absval = local_absval 

91 else: 

92 global_absval = local_absval 

93 

94 return float(global_absval) 

95 

96 def isend(self, dest=None, tag=None, comm=None): 

97 """ 

98 Routine for sending data forward in time (non-blocking) 

99 

100 Args: 

101 dest (int): target rank 

102 tag (int): communication tag 

103 comm: communicator 

104 

105 Returns: 

106 request handle 

107 """ 

108 return comm.Issend(self[:], dest=dest, tag=tag) 

109 

110 def irecv(self, source=None, tag=None, comm=None): 

111 """ 

112 Routine for receiving in time 

113 

114 Args: 

115 source (int): source rank 

116 tag (int): communication tag 

117 comm: communicator 

118 

119 Returns: 

120 None 

121 """ 

122 return comm.Irecv(self[:], source=source, tag=tag) 

123 

124 def bcast(self, root=None, comm=None): 

125 """ 

126 Routine for broadcasting values 

127 

128 Args: 

129 root (int): process with value to broadcast 

130 comm: communicator 

131 

132 Returns: 

133 broadcasted values 

134 """ 

135 comm.Bcast(self[:], root=root) 

136 return self 

137 

138 

139class CuPyMultiComponentMesh(cupy_mesh): 

140 r""" 

141 Generic mesh with multiple components. 

142 

143 To make a specific multi-component mesh, derive from this class and list the components as strings in the class 

144 attribute ``components``. An example: 

145 

146 ``` 

147 class imex_cupy_mesh(CuPyMultiComponentMesh): 

148 components = ['impl', 'expl'] 

149 ``` 

150 

151 Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access 

152 to the components with ``.``. Continuing the above example: 

153 

154 ``` 

155 init = ((100,), None, numpy.dtype('d')) 

156 f = imex_cupy_mesh(init) 

157 f.shape # (2, 100) 

158 f.expl.shape # (100,) 

159 ``` 

160 

161 Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the 

162 components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view 

163 on to the component if appropriate. Importantly, this means that you cannot name a component like something that 

164 is already an attribute of ``cupy_mesh`` or ``cupy.ndarray`` because this will not result in calls to ``__getattr__``. 

165 

166 There are a couple more things to keep in mind: 

167 - Because a ``CuPyMultiComponentMesh`` is just a ``cupy.ndarray`` with one more dimension, all components must have 

168 the same shape. 

169 - You can use the entire ``CuPyMultiComponentMesh`` like a ``cupy.ndarray`` in operations that accept arrays, but make 

170 sure that you really want to apply the same operation on all components if you do. 

171 - If you omit the assignment operator ``[:]`` during assignment, you will not change the mesh at all. Omitting this 

172 leads to all kinds of trouble throughout the code. But here you really cannot get away without. 

173 """ 

174 

175 components = [] 

176 

177 def __new__(cls, init, *args, **kwargs): 

178 if isinstance(init, tuple): 

179 shape = (init[0],) if type(init[0]) is int else init[0] 

180 obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs) 

181 else: 

182 obj = super().__new__(cls, init, *args, **kwargs) 

183 

184 return obj 

185 

186 def __getattr__(self, name): 

187 if name in self.components: 

188 if self.shape[0] == len(self.components): 

189 return self[self.components.index(name)].view(cupy_mesh) 

190 else: 

191 raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.') 

192 else: 

193 raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!") 

194 

195 

196class imex_cupy_mesh(CuPyMultiComponentMesh): 

197 components = ['impl', 'expl'] 

198 

199 

200class comp2_cupy_mesh(CuPyMultiComponentMesh): 

201 components = ['comp1', 'comp2']