Skip to content

Commit

Permalink
Removing all yaksa leaks by destroying properly
Browse files Browse the repository at this point in the history
  • Loading branch information
mikaem committed Jun 25, 2024
1 parent b029295 commit ba63c44
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 10 deletions.
3 changes: 3 additions & 0 deletions examples/darray.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(m1)**2)
if MPI.COMM_WORLD.Get_rank() == 0:
assert abs(s0-s1) < 1e-12

fft.destroy()
nfft.destroy()
10 changes: 4 additions & 6 deletions examples/spectral_dns_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def get_local_mesh(FFT, L):
"""Returns local mesh."""
X = np.ogrid[FFT.local_slice(False)]
N = FFT.global_shape()
for i in range(len(N)):
X[i] = (X[i]*L[i]/N[i])
X = [np.broadcast_to(x, FFT.shape(False)) for x in X]
X = [np.broadcast_to(x*L[i]/N[i], FFT.shape(False)) for i, x in enumerate(X)]
return X

def get_local_wavenumbermesh(FFT, L):
Expand All @@ -60,9 +58,7 @@ def get_local_wavenumbermesh(FFT, L):
K = [ki[si] for ki, si in zip(k, s)]
Ks = np.meshgrid(*K, indexing='ij', sparse=True)
Lp = 2*np.pi/L
for i in range(3):
Ks[i] = (Ks[i]*Lp[i]).astype(float)
return [np.broadcast_to(k, FFT.shape(True)) for k in Ks]
return [np.broadcast_to(k*Lp[i], FFT.shape(True)) for i, k in enumerate(Ks)]

X = get_local_mesh(FFT, L)
K = get_local_wavenumbermesh(FFT, L)
Expand Down Expand Up @@ -131,3 +127,5 @@ def compute_rhs(rhs):
if MPI.COMM_WORLD.Get_rank() == 0:
print('Time = {}'.format(time()-t0))
assert round(float(k) - 0.124953117517, 7) == 0

FFT.destroy()
4 changes: 4 additions & 0 deletions examples/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@
u3 = cfft.forward(u2, u3)

assert np.allclose(uc, u3)

fft.destroy()
pfft.destroy()
cfft.destroy()
2 changes: 1 addition & 1 deletion mpi4py_fft/io/nc_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _write_slice_step(self, name, step, slices, field, **kw):

h[step] = 0 # collectively create dataset
h.set_collective(False)
sf = tuple([step] + list(sf))
sf = tuple([int(step)] + list(sf))
sl = tuple(slices)
if inside:
h[sf] = field[sl]
Expand Down
6 changes: 4 additions & 2 deletions tests/test_darray.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_2Darray():
pass
_ = a.local_slice()
newaxis = (a.alignment+1)%2
_ = a.get_pencil_and_transfer(newaxis)
p, t = a.get_pencil_and_transfer(newaxis)
a[:] = MPI.COMM_WORLD.Get_rank()
b = a.redistribute(newaxis)
a = b.redistribute(out=a)
Expand All @@ -57,6 +57,7 @@ def test_2Darray():
assert abs(s0-s1) < 1e-1
c = a.redistribute(a.alignment)
assert c is a
t.destroy()

def test_3Darray():
N = (8, 8, 8)
Expand Down Expand Up @@ -97,14 +98,15 @@ def test_3Darray():
pass
_ = a.local_slice()
newaxis = (a.alignment+1)%3
_ = a.get_pencil_and_transfer(newaxis)
p, t = a.get_pencil_and_transfer(newaxis)
a[:] = MPI.COMM_WORLD.Get_rank()
b = a.redistribute(newaxis)
a = b.redistribute(out=a)
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(a)**2)
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(b)**2)
if MPI.COMM_WORLD.Get_rank() == 0:
assert abs(s0-s1) < 1e-1
t.destroy()

def test_newDistArray():
N = (8, 8, 8)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fftw.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def allclose(a, b):
atol = abstol[a.dtype.char.lower()]
return np.allclose(a, b, rtol=0, atol=atol)
return np.allclose(a, b, atol=atol)

def test_fftw():
from itertools import product
Expand Down
1 change: 1 addition & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_4D(backend, forward_output):
import netCDF4
except ImportError:
skip['netcdf4'] = True
skip['netcdf4'] = True # Drop test for netCDF4
for bnd in ('hdf5', 'netcdf4'):
if not skip[bnd]:
forw_output = [False]
Expand Down
1 change: 1 addition & 0 deletions tests/test_mpifft.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_r2r():
B = fft.forward(A)
C = fft.backward(B, C)
assert np.allclose(A, C)
fft.destroy()

def test_mpifft():
from itertools import product
Expand Down

0 comments on commit ba63c44

Please sign in to comment.