Skip to content

Commit 1083484

Browse files
committed
Adding mpifft.destroy to some tests when cleaning up
1 parent fb22b87 commit 1083484

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

mpi4py_fft/io/file_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
comm = MPI.COMM_WORLD
77

88
class FileBase(object):
9-
"""Base class for reading/writing structure arrays
9+
"""Base class for reading/writing distributed arrays
1010
1111
Parameters
1212
----------

mpi4py_fft/mpifft.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ def __init__(self, comm, shape=None, axes=None, dtype=float, slab=False,
323323
self.pencil[::-1])
324324

325325
def destroy(self):
326-
self.subcomm.destroy()
326+
if isinstance(self.subcomm, Subcomm):
327+
self.subcomm.destroy()
327328
for trans in self.transfer:
328329
trans.destroy()
329330

tests/test_darray.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,22 @@ def test_newDistArray():
107107
for view in (True, False):
108108
for rank in (0, 1, 2):
109109
a = newDistArray(pfft, forward_output=forward_output,
110-
rank=rank, view=view)
110+
rank=rank, view=view)
111111
if view is False:
112112
assert isinstance(a, DistArray)
113113
assert a.rank == rank
114114
if rank == 0:
115115
qfft = PFFT(MPI.COMM_WORLD, darray=a)
116116
elif rank == 1:
117117
qfft = PFFT(MPI.COMM_WORLD, darray=a[0])
118+
else:
119+
qfft = PFFT(MPI.COMM_WORLD, darray=a[0, 0])
120+
qfft.destroy()
121+
118122
else:
119123
assert isinstance(a, np.ndarray)
120124
assert a.base.rank == rank
125+
pfft.destroy()
121126

122127
if __name__ == '__main__':
123128
test_1Darray()

tests/test_io.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_2D(backend, forward_output):
5656
u0.read(filename, 'u', 2)
5757
u0.read(read, 'u', 2)
5858
assert np.allclose(u0, u)
59+
T.destroy()
5960

6061
def test_3D(backend, forward_output):
6162
if backend == 'netcdf4':
@@ -118,6 +119,7 @@ def test_3D(backend, forward_output):
118119
assert np.allclose(u0, u)
119120
read.read(u0, 'v', step=0)
120121
assert np.allclose(u0, v)
122+
T.destroy()
121123

122124
def test_4D(backend, forward_output):
123125
if backend == 'netcdf4':
@@ -153,6 +155,7 @@ def test_4D(backend, forward_output):
153155
assert np.allclose(u0, u)
154156
read.read(u0, 'v', step=0)
155157
assert np.allclose(u0, v)
158+
T.destroy()
156159

157160
if __name__ == '__main__':
158161
#pylint: disable=unused-import

0 commit comments

Comments
 (0)