Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Architecture slices in halo filling #370

Merged
merged 4 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions PyMPDATA/boundary_conditions/constant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
""" boundary condition filling halos with a constant value """
# pylint: disable=too-many-arguments
from functools import lru_cache

import numba

from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Constant:
"""class which instances are to be passed in boundary_conditions tuple to the
Expand All @@ -12,19 +15,19 @@ class Constant:
def __init__(self, value):
self.value = value

def make_scalar(self, ats, halo, dtype, jit_flags):
def make_scalar(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_constant(self.value, ats, halo, dtype, jit_flags)
return _make_scalar_constant(self.value, ats, set_value, halo, dtype, jit_flags)

def make_vector(self, ats, halo, dtype, jit_flags):
def make_vector(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_scalar_constant(self.value, ats, halo, dtype, jit_flags)
return _make_scalar_constant(self.value, ats, set_value, halo, dtype, jit_flags)


@lru_cache()
def _make_scalar_constant(value, _, __, ___, jit_flags):
def _make_scalar_constant(value, _, set_value, __, ___, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(_1, _2, _3):
return value

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
20 changes: 12 additions & 8 deletions PyMPDATA/boundary_conditions/extrapolated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" boundary condition extrapolating values from the edge to the halo """
# pylint: disable=too-many-arguments
from functools import lru_cache

import numba
Expand All @@ -10,6 +11,7 @@
META_AND_DATA_META,
SIGN_LEFT,
)
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Extrapolated:
Expand All @@ -21,20 +23,22 @@ def __init__(self, dim=INNER, eps=1e-10):
self.eps = eps
self.dim = dim

def make_scalar(self, ats, halo, dtype, jit_flags):
def make_scalar(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_extrapolated(
self.dim, self.eps, ats, halo, dtype, jit_flags
self.dim, self.eps, ats, set_value, halo, dtype, jit_flags
)

def make_vector(self, ats, halo, dtype, jit_flags):
def make_vector(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_extrapolated(self.dim, ats, halo, dtype, jit_flags)
return _make_vector_extrapolated(
self.dim, ats, set_value, halo, dtype, jit_flags
)


@lru_cache()
# pylint: disable=too-many-arguments
def _make_scalar_extrapolated(dim, eps, ats, halo, dtype, jit_flags):
def _make_scalar_extrapolated(dim, eps, ats, set_value, halo, dtype, jit_flags):
@numba.njit(**jit_flags)
def impl(psi, span, sign):
if sign == SIGN_LEFT:
Expand Down Expand Up @@ -68,13 +72,13 @@ def fill_halos_scalar(psi, span, sign):
def fill_halos_scalar(psi, span, sign):
return impl(psi, span, sign)

return fill_halos_scalar
return make_fill_halos_loop(jit_flags, set_value, fill_halos_scalar)


@lru_cache()
def _make_vector_extrapolated(_, ats, __, ___, jit_flags):
def _make_vector_extrapolated(_, ats, set_value, __, ___, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, ____, sign):
return ats(*psi, sign)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
25 changes: 13 additions & 12 deletions PyMPDATA/boundary_conditions/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numba

from PyMPDATA.impl.enumerations import SIGN_LEFT, SIGN_RIGHT
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Periodic:
Expand All @@ -16,29 +17,29 @@ def __init__(self):
assert SIGN_LEFT == +1

@staticmethod
def make_scalar(ats, _, __, jit_flags):
def make_scalar(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_periodic(ats, jit_flags)
return _make_scalar_periodic(ats, set_value, jit_flags)

@staticmethod
def make_vector(ats, _, __, jit_flags):
def make_vector(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_periodic(ats, jit_flags)
return _make_vector_periodic(ats, set_value, jit_flags)


@lru_cache()
def _make_scalar_periodic(ats, jit_flags):
def _make_scalar_periodic(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, span, sign):
return ats(*psi, sign * span)
def fill_halos(focus_psi, span, sign):
return ats(*focus_psi, sign * span)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)


@lru_cache()
def _make_vector_periodic(ats, jit_flags):
def _make_vector_periodic(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, span, sign):
return ats(*psi, sign * span)
def fill_halos(focus_psi, span, sign):
return ats(*focus_psi, sign * span)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
13 changes: 7 additions & 6 deletions PyMPDATA/boundary_conditions/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numba

from PyMPDATA.impl.enumerations import ARG_FOCUS, SIGN_LEFT, SIGN_RIGHT
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Polar:
Expand All @@ -22,7 +23,7 @@ def __init__(self, grid, longitude_idx, latitude_idx):
self.lon_idx = longitude_idx
self.lat_idx = latitude_idx

def make_scalar(self, ats, halo, _, jit_flags):
def make_scalar(self, ats, set_value, halo, _, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
nlon_half = self.nlon_half
nlat = self.nlat
Expand All @@ -43,18 +44,18 @@ def fill_halos(psi, _, sign):
val = nlon_half * (-1 if lon > nlon_half else 1)
return ats(*psi, sign * step, val)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)

@staticmethod
def make_vector(ats, _, __, jit_flags):
def make_vector(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_polar(ats, jit_flags)
return _make_vector_polar(ats, set_value, jit_flags)


@lru_cache()
def _make_vector_polar(ats, jit_flags):
def _make_vector_polar(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, ___, ____):
return ats(*psi, 0) # TODO #120

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
1 change: 1 addition & 0 deletions PyMPDATA/impl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def assemble(self, traversals):
self.__impl = (self.__properties.meta, *self._impl_data), tuple(
getattr(fill_halos, method)(
traversals.indexers[self.n_dims].ats[i],
traversals.indexers[self.n_dims].set,
self.halo,
self.dtype,
traversals.jit_flags,
Expand Down
2 changes: 1 addition & 1 deletion PyMPDATA/impl/traversals.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(self, *, grid, halo, jit_flags, n_threads, left_first):
)

common_kwargs = {
"indexers": self.indexers,
"jit_flags": jit_flags,
"halo": halo,
"n_dims": self.n_dims,
Expand All @@ -59,6 +58,7 @@ def __init__(self, *, grid, halo, jit_flags, n_threads, left_first):
}
common_kwargs = {
**common_kwargs,
"indexers": self.indexers,
**{
"boundary_cond_vector": self._code["fill_halos_vector"],
"boundary_cond_scalar": self._code["fill_halos_scalar"],
Expand Down
19 changes: 18 additions & 1 deletion PyMPDATA/impl/traversals_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
""" commons for scalar and vector field traversals """
# pylint: disable=too-many-arguments
import numba

from .enumerations import OUTER, RNG_STOP


def _make_common(jit_flags, spanner, chunker):
def make_common(jit_flags, spanner, chunker):
"""returns Numba-compiled callable producing common parameters"""

@numba.njit(**jit_flags)
def common(meta, thread_id):
span = spanner(meta)
Expand All @@ -14,3 +17,17 @@ def common(meta, thread_id):
return span, rng_outer, last_thread, first_thread

return common


def make_fill_halos_loop(jit_flags, set_value, fill_halos):
"""returns Numba-compiled halo-filling callable"""

@numba.njit(**jit_flags)
def fill_halos_loop(i_rng, j_rng, k_rng, psi, span, sign):
for i in i_rng:
for j in j_rng:
for k in k_rng:
focus = (i, j, k)
set_value(psi, *focus, fill_halos((focus, psi), span, sign))

return fill_halos_loop
Loading