diff --git a/PyMPDATA_MPI/periodic.py b/PyMPDATA_MPI/mpi_periodic.py similarity index 62% rename from PyMPDATA_MPI/periodic.py rename to PyMPDATA_MPI/mpi_periodic.py index 90a7b16..fa52979 100644 --- a/PyMPDATA_MPI/periodic.py +++ b/PyMPDATA_MPI/mpi_periodic.py @@ -5,7 +5,6 @@ import numba import numba_mpi as mpi -import numpy as np from mpi4py import MPI from PyMPDATA.boundary_conditions import Periodic from PyMPDATA.impl.enumerations import INVALID_INDEX, SIGN_LEFT, SIGN_RIGHT @@ -30,7 +29,9 @@ def make_scalar(self, indexers, halo, dtype, jit_flags, dimension_index): return Periodic.make_scalar( indexers, halo, dtype, jit_flags, dimension_index ) - return _make_scalar_periodic(indexers, jit_flags, dimension_index, self.__size) + return _make_scalar_periodic( + indexers, jit_flags, dimension_index, self.__size, dtype + ) def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index): """returns (lru-cached) Numba-compiled vector halo-filling callable""" @@ -39,34 +40,27 @@ def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index): indexers, halo, dtype, jit_flags, dimension_index ) return _make_vector_periodic( - indexers, halo, jit_flags, dimension_index, self.__size + indexers, halo, jit_flags, dimension_index, self.__size, dtype ) -def _make_send_recv(set_value, jit_flags, fill_buf): +def _make_send_recv(set_value, jit_flags, fill_buf, size, dtype): @numba.njit(**jit_flags) - def _send_recv(size, psi, i_rng, j_rng, k_rng, sign, dim, output): - buf = np.empty( - ( - len(i_rng), - len(k_rng), - ), - dtype=output.dtype, - ) + def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index): + chunk_size = len(i_rng) * len(k_rng) + return buffer.view(dtype)[ + chunk_index * chunk_size : (chunk_index + 1) * chunk_size + ].reshape((len(i_rng), len(k_rng))) + @numba.njit(**jit_flags) + def get_peers(): rank = mpi.rank() - peers = (-1, (rank - 1) % size, (rank + 1) % size) # LEFT # RIGHT - - if SIGN_LEFT == sign: - fill_buf(buf, psi, i_rng, k_rng, sign, dim) - mpi.send(buf, dest=peers[sign]) - mpi.recv(buf, source=peers[sign]) - elif SIGN_RIGHT == sign: - mpi.recv(buf, source=peers[sign]) - tmp = np.empty_like(buf) - fill_buf(tmp, psi, i_rng, k_rng, sign, dim) - mpi.send(tmp, dest=peers[sign]) + left_peer = (rank - 1) % size + right_peer = (rank + 1) % size + return (-1, left_peer, right_peer) + @numba.njit(**jit_flags) + def fill_output(output, buffer, i_rng, j_rng, k_rng): for i in i_rng: for j in j_rng: for k in k_rng: @@ -75,14 +69,39 @@ def _send_recv(size, psi, i_rng, j_rng, k_rng, sign, dim, output): i, j, k, - buf[i - i_rng.start, k - k_rng.start], + buffer[i - i_rng.start, k - k_rng.start], ) + @numba.njit(**jit_flags) + def _send(buf, peer, fill_buf_args): + fill_buf(buf, *fill_buf_args) + mpi.send(buf, dest=peer) + + @numba.njit(**jit_flags) + def _recv(buf, peer): + mpi.recv(buf, source=peer) + + @numba.njit(**jit_flags) + def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output): + buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0) + peers = get_peers() + fill_buf_args = (psi, i_rng, k_rng, sign, dim) + + if SIGN_LEFT == sign: + _send(buf=buf, peer=peers[sign], fill_buf_args=fill_buf_args) + _recv(buf=buf, peer=peers[sign]) + elif SIGN_RIGHT == sign: + _recv(buf=buf, peer=peers[sign]) + tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1) + _send(buf=tmp, peer=peers[sign], fill_buf_args=fill_buf_args) + + fill_output(output, buf, i_rng, j_rng, k_rng) + return _send_recv @lru_cache() -def _make_scalar_periodic(indexers, jit_flags, dimension_index, size): +def _make_scalar_periodic(indexers, jit_flags, dimension_index, size, dtype): @numba.njit(**jit_flags) def fill_buf(buf, psi, i_rng, k_rng, sign, _dim): for i in i_rng: @@ -91,17 +110,17 @@ def fill_buf(buf, psi, i_rng, k_rng, sign, _dim): (i, INVALID_INDEX, k), psi, sign ) - send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf) + send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf, size, dtype) @numba.njit(**jit_flags) - def fill_halos(i_rng, j_rng, k_rng, psi, _, sign): - send_recv(size, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi) + def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign): + send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi) return fill_halos @lru_cache() -def _make_vector_periodic(indexers, halo, jit_flags, dimension_index, size): +def _make_vector_periodic(indexers, halo, jit_flags, dimension_index, size, dtype): @numba.njit(**jit_flags) def fill_buf(buf, components, i_rng, k_rng, sign, dim): parallel = dim % len(components) == dimension_index @@ -119,12 +138,12 @@ def fill_buf(buf, components, i_rng, k_rng, sign, dim): buf[i - i_rng.start, k - k_rng.start] = value - send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf) + send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf, size, dtype) @numba.njit(**jit_flags) - def fill_halos_loop_vector(i_rng, j_rng, k_rng, components, dim, _, sign): + def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign): if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop: return - send_recv(size, components, i_rng, j_rng, k_rng, sign, dim, components[dim]) + send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim]) return fill_halos_loop_vector diff --git a/PyMPDATA_MPI/simulation.py b/PyMPDATA_MPI/simulation.py index 9bb567b..001b9f7 100644 --- a/PyMPDATA_MPI/simulation.py +++ b/PyMPDATA_MPI/simulation.py @@ -5,7 +5,7 @@ from PyMPDATA.boundary_conditions import Periodic from .domain_decomposition import mpi_indices -from .periodic import MPIPeriodic +from .mpi_periodic import MPIPeriodic class Simulation: @@ -45,6 +45,10 @@ def __init__( n_dims=2, n_threads=n_threads, left_first=rank % 2 == 0, + # TODO https://github.com/open-atmos/PyMPDATA/issues/386 + buffer_size=((ny + 2 * halo) * halo) + * 2 # for temporary send/recv buffer on one side + * 2, # for complex dtype ) self.solver = Solver(stepper=stepper, advectee=self.advectee, advector=advector) diff --git a/pyproject.toml b/pyproject.toml index b987727..1f375a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,11 +20,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Physics" ] dependencies = [ + # TODO: these should be handled within PyMPDATA? + "numba<0.57.0", + "numpy<1.24.0", "numba_mpi>=0.30", - "PyMPDATA==1.0.10", + "PyMPDATA==1.0.11", "mpi4py", "h5py", - "pytest-mpi" + "pytest-mpi" # TODO: move it to optional dependencies (extras_require?) ] dynamic = ["version"]