|
3 | 3 | import numpy as np |
4 | 4 | from mpi4py import MPI |
5 | 5 | from .pencil import Pencil, Subcomm |
| 6 | +from .io import HDF5File, NCFile, FileBase |
6 | 7 |
|
7 | 8 | comm = MPI.COMM_WORLD |
8 | 9 |
|
@@ -54,8 +55,12 @@ class DistArray(np.ndarray): |
54 | 55 | """ |
55 | 56 | def __new__(cls, global_shape, subcomm=None, val=None, dtype=np.float, |
56 | 57 | buffer=None, alignment=None, rank=0): |
57 | | - if rank > 0: |
58 | | - assert global_shape[:rank] == (len(global_shape[rank:]),)*rank |
| 58 | + if len(global_shape) < 2: |
| 59 | + obj = np.ndarray.__new__(cls, global_shape, dtype=dtype, buffer=buffer) |
| 60 | + if buffer is None and isinstance(val, Number): |
| 61 | + obj.fill(val) |
| 62 | + obj._rank = rank |
| 63 | + return obj |
59 | 64 |
|
60 | 65 | if isinstance(subcomm, Subcomm): |
61 | 66 | pass |
@@ -139,13 +144,24 @@ def rank(self): |
139 | 144 | """Return tensor rank of ``self``""" |
140 | 145 | return self._rank |
141 | 146 |
|
| 147 | + @property |
| 148 | + def dimensions(self): |
| 149 | + """Return dimensions of array not including rank""" |
| 150 | + return len(self._p0.shape) |
| 151 | + |
142 | 152 | def __getitem__(self, i): |
143 | 153 | # Return DistArray if the result is a component of a tensor |
144 | 154 | # Otherwise return ndarray view |
145 | | - if isinstance(i, int) and self.rank > 0: |
| 155 | + if self.ndim == 1: |
| 156 | + return np.ndarray.__getitem__(self, i) |
| 157 | + |
| 158 | + if isinstance(i, (int, slice)) and self.rank > 0: |
146 | 159 | v0 = np.ndarray.__getitem__(self, i) |
147 | | - v0._rank -= 1 |
| 160 | + v0._rank = self.rank - (self.ndim - v0.ndim) |
| 161 | + #if v0.ndim < self.ndim: |
| 162 | + # v0._rank -= 1 |
148 | 163 | return v0 |
| 164 | + |
149 | 165 | if isinstance(i, tuple) and len(i) == 2 and self.rank == 2: |
150 | 166 | v0 = np.ndarray.__getitem__(self, i) |
151 | 167 | v0._rank = 0 |
@@ -246,14 +262,14 @@ def local_slice(self): |
246 | 262 | ... print(l)''') |
247 | 263 | >>> fx.close() |
248 | 264 | >>> print(subprocess.getoutput('mpirun -np 4 python ls_script.py')) |
249 | | - [slice(0, 16, None), slice(0, 7, None), slice(0, 6, None)] |
250 | | - [slice(0, 16, None), slice(0, 7, None), slice(6, 12, None)] |
251 | | - [slice(0, 16, None), slice(7, 14, None), slice(0, 6, None)] |
252 | | - [slice(0, 16, None), slice(7, 14, None), slice(6, 12, None)] |
| 265 | + (slice(0, 16, None), slice(0, 7, None), slice(0, 6, None)) |
| 266 | + (slice(0, 16, None), slice(0, 7, None), slice(6, 12, None)) |
| 267 | + (slice(0, 16, None), slice(7, 14, None), slice(0, 6, None)) |
| 268 | + (slice(0, 16, None), slice(7, 14, None), slice(6, 12, None)) |
253 | 269 | """ |
254 | 270 | v = [slice(start, start+shape) for start, shape in zip(self._p0.substart, |
255 | 271 | self._p0.subshape)] |
256 | | - return [slice(0, s) for s in self.shape[:self.rank]] + v |
| 272 | + return tuple([slice(0, s) for s in self.shape[:self.rank]] + v) |
257 | 273 |
|
258 | 274 | def get_pencil_and_transfer(self, axis): |
259 | 275 | """Return pencil and transfer objects for alignment along ``axis`` |
@@ -339,6 +355,74 @@ def redistribute(self, axis=None, out=None): |
339 | 355 |
|
340 | 356 | return out |
341 | 357 |
|
| 358 | + def write(self, filename, name='darray', step=0, global_slice=None, |
| 359 | + as_scalar=False): |
| 360 | + """Write snapshot ``step`` of ``self`` to file ``filename`` |
| 361 | +
|
| 362 | + Parameters |
| 363 | + ---------- |
| 364 | + filename : str or instance of :class:`.FileBase` |
| 365 | + The name of the file (or the file itself) that is used to store the |
| 366 | + requested data in ``self`` |
| 367 | + name : str, optional |
| 368 | + Name used for storing snapshot in file. |
| 369 | + step : int, optional |
| 370 | + Index used for snapshot in file. |
| 371 | + global_slice : sequence of slices or integers, optional |
| 372 | + Store only this global slice of ``self`` |
| 373 | + as_scalar : boolean, optional |
| 374 | + Whether to store rank > 0 arrays as scalars. Default is False. |
| 375 | +
|
| 376 | + Example |
| 377 | + ------- |
| 378 | + >>> from mpi4py_fft import DistArray |
| 379 | + >>> u = DistArray((8, 8), val=1) |
| 380 | + >>> u.write('h5file.h5', 'u', 0) |
| 381 | + >>> u.write('h5file.h5', 'u', (slice(None), 4)) |
| 382 | + """ |
| 383 | + if isinstance(filename, str): |
| 384 | + writer = HDF5File if filename.endswith('.h5') else NCFile |
| 385 | + f = writer(filename, u=self, mode='a') |
| 386 | + elif isinstance(filename, FileBase): |
| 387 | + f = filename |
| 388 | + field = [self] if global_slice is None else [(self, global_slice)] |
| 389 | + f.write(step, {name: field}, as_scalar=as_scalar) |
| 390 | + |
| 391 | + def read(self, filename, name='darray', step=0): |
| 392 | + """Read from file ``filename`` into array ``self`` |
| 393 | +
|
| 394 | + Note |
| 395 | + ---- |
| 396 | + Only whole arrays can be read from file, not slices. |
| 397 | +
|
| 398 | + Parameters |
| 399 | + ---------- |
| 400 | + filename : str or instance of :class:`.FileBase` |
| 401 | + The name of the file (or the file itself) holding the data that is |
| 402 | + loaded into ``self``. |
| 403 | + name : str, optional |
| 404 | + Internal name in file of snapshot to be read. |
| 405 | + step : int, optional |
| 406 | + Index of field to be read. Default is 0. |
| 407 | +
|
| 408 | + Example |
| 409 | + ------- |
| 410 | + >>> from mpi4py_fft import DistArray |
| 411 | + >>> u = DistArray((8, 8), val=1) |
| 412 | + >>> u.write('h5file.h5', 'u', 0) |
| 413 | + >>> v = DistArray((8, 8)) |
| 414 | + >>> v.read('h5file.h5', 'u', 0) |
| 415 | + >>> assert np.allclose(u, v) |
| 416 | +
|
| 417 | + """ |
| 418 | + if isinstance(filename, str): |
| 419 | + writer = HDF5File if filename.endswith('.h5') else NCFile |
| 420 | + f = writer(filename, u=self, mode='r') |
| 421 | + elif isinstance(filename, FileBase): |
| 422 | + f = filename |
| 423 | + f.read(self, name, step=step) |
| 424 | + |
| 425 | + |
342 | 426 | def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False): |
343 | 427 | """Return a new :class:`.DistArray` object for provided :class:`.PFFT` object |
344 | 428 |
|
|
0 commit comments