Skip to content

Commit

Permalink
Architecture slices in halo filling (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
Delcior authored Feb 23, 2023
1 parent b10ff60 commit b6e22a8
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 345 deletions.
15 changes: 9 additions & 6 deletions PyMPDATA/boundary_conditions/constant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
""" boundary condition filling halos with a constant value """
# pylint: disable=too-many-arguments
from functools import lru_cache

import numba

from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Constant:
"""class which instances are to be passed in boundary_conditions tuple to the
Expand All @@ -12,19 +15,19 @@ class Constant:
def __init__(self, value):
self.value = value

def make_scalar(self, ats, halo, dtype, jit_flags):
def make_scalar(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_constant(self.value, ats, halo, dtype, jit_flags)
return _make_scalar_constant(self.value, ats, set_value, halo, dtype, jit_flags)

def make_vector(self, ats, halo, dtype, jit_flags):
def make_vector(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_scalar_constant(self.value, ats, halo, dtype, jit_flags)
return _make_scalar_constant(self.value, ats, set_value, halo, dtype, jit_flags)


@lru_cache()
def _make_scalar_constant(value, _, __, ___, jit_flags):
def _make_scalar_constant(value, _, set_value, __, ___, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(_1, _2, _3):
return value

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
20 changes: 12 additions & 8 deletions PyMPDATA/boundary_conditions/extrapolated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" boundary condition extrapolating values from the edge to the halo """
# pylint: disable=too-many-arguments
from functools import lru_cache

import numba
Expand All @@ -10,6 +11,7 @@
META_AND_DATA_META,
SIGN_LEFT,
)
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Extrapolated:
Expand All @@ -21,20 +23,22 @@ def __init__(self, dim=INNER, eps=1e-10):
self.eps = eps
self.dim = dim

def make_scalar(self, ats, halo, dtype, jit_flags):
def make_scalar(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_extrapolated(
self.dim, self.eps, ats, halo, dtype, jit_flags
self.dim, self.eps, ats, set_value, halo, dtype, jit_flags
)

def make_vector(self, ats, halo, dtype, jit_flags):
def make_vector(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_extrapolated(self.dim, ats, halo, dtype, jit_flags)
return _make_vector_extrapolated(
self.dim, ats, set_value, halo, dtype, jit_flags
)


@lru_cache()
# pylint: disable=too-many-arguments
def _make_scalar_extrapolated(dim, eps, ats, halo, dtype, jit_flags):
def _make_scalar_extrapolated(dim, eps, ats, set_value, halo, dtype, jit_flags):
@numba.njit(**jit_flags)
def impl(psi, span, sign):
if sign == SIGN_LEFT:
Expand Down Expand Up @@ -68,13 +72,13 @@ def fill_halos_scalar(psi, span, sign):
def fill_halos_scalar(psi, span, sign):
return impl(psi, span, sign)

return fill_halos_scalar
return make_fill_halos_loop(jit_flags, set_value, fill_halos_scalar)


@lru_cache()
def _make_vector_extrapolated(_, ats, __, ___, jit_flags):
def _make_vector_extrapolated(_, ats, set_value, __, ___, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, ____, sign):
return ats(*psi, sign)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
25 changes: 13 additions & 12 deletions PyMPDATA/boundary_conditions/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numba

from PyMPDATA.impl.enumerations import SIGN_LEFT, SIGN_RIGHT
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Periodic:
Expand All @@ -16,29 +17,29 @@ def __init__(self):
assert SIGN_LEFT == +1

@staticmethod
def make_scalar(ats, _, __, jit_flags):
def make_scalar(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_periodic(ats, jit_flags)
return _make_scalar_periodic(ats, set_value, jit_flags)

@staticmethod
def make_vector(ats, _, __, jit_flags):
def make_vector(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_periodic(ats, jit_flags)
return _make_vector_periodic(ats, set_value, jit_flags)


@lru_cache()
def _make_scalar_periodic(ats, jit_flags):
def _make_scalar_periodic(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, span, sign):
return ats(*psi, sign * span)
def fill_halos(focus_psi, span, sign):
return ats(*focus_psi, sign * span)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)


@lru_cache()
def _make_vector_periodic(ats, jit_flags):
def _make_vector_periodic(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, span, sign):
return ats(*psi, sign * span)
def fill_halos(focus_psi, span, sign):
return ats(*focus_psi, sign * span)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
13 changes: 7 additions & 6 deletions PyMPDATA/boundary_conditions/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numba

from PyMPDATA.impl.enumerations import ARG_FOCUS, SIGN_LEFT, SIGN_RIGHT
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Polar:
Expand All @@ -22,7 +23,7 @@ def __init__(self, grid, longitude_idx, latitude_idx):
self.lon_idx = longitude_idx
self.lat_idx = latitude_idx

def make_scalar(self, ats, halo, _, jit_flags):
def make_scalar(self, ats, set_value, halo, _, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
nlon_half = self.nlon_half
nlat = self.nlat
Expand All @@ -43,18 +44,18 @@ def fill_halos(psi, _, sign):
val = nlon_half * (-1 if lon > nlon_half else 1)
return ats(*psi, sign * step, val)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)

@staticmethod
def make_vector(ats, _, __, jit_flags):
def make_vector(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_polar(ats, jit_flags)
return _make_vector_polar(ats, set_value, jit_flags)


@lru_cache()
def _make_vector_polar(ats, jit_flags):
def _make_vector_polar(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, ___, ____):
return ats(*psi, 0) # TODO #120

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
1 change: 1 addition & 0 deletions PyMPDATA/impl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def assemble(self, traversals):
self.__impl = (self.__properties.meta, *self._impl_data), tuple(
getattr(fill_halos, method)(
traversals.indexers[self.n_dims].ats[i],
traversals.indexers[self.n_dims].set,
self.halo,
self.dtype,
traversals.jit_flags,
Expand Down
2 changes: 1 addition & 1 deletion PyMPDATA/impl/traversals.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(self, *, grid, halo, jit_flags, n_threads, left_first):
)

common_kwargs = {
"indexers": self.indexers,
"jit_flags": jit_flags,
"halo": halo,
"n_dims": self.n_dims,
Expand All @@ -59,6 +58,7 @@ def __init__(self, *, grid, halo, jit_flags, n_threads, left_first):
}
common_kwargs = {
**common_kwargs,
"indexers": self.indexers,
**{
"boundary_cond_vector": self._code["fill_halos_vector"],
"boundary_cond_scalar": self._code["fill_halos_scalar"],
Expand Down
19 changes: 18 additions & 1 deletion PyMPDATA/impl/traversals_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
""" commons for scalar and vector field traversals """
# pylint: disable=too-many-arguments
import numba

from .enumerations import OUTER, RNG_STOP


def _make_common(jit_flags, spanner, chunker):
def make_common(jit_flags, spanner, chunker):
"""returns Numba-compiled callable producing common parameters"""

@numba.njit(**jit_flags)
def common(meta, thread_id):
span = spanner(meta)
Expand All @@ -14,3 +17,17 @@ def common(meta, thread_id):
return span, rng_outer, last_thread, first_thread

return common


def make_fill_halos_loop(jit_flags, set_value, fill_halos):
"""returns Numba-compiled halo-filling callable"""

@numba.njit(**jit_flags)
def fill_halos_loop(i_rng, j_rng, k_rng, psi, span, sign):
for i in i_rng:
for j in j_rng:
for k in k_rng:
focus = (i, j, k)
set_value(psi, *focus, fill_halos((focus, psi), span, sign))

return fill_halos_loop
Loading

0 comments on commit b6e22a8

Please sign in to comment.