Skip to content

Commit

Permalink
use make_fill_halos_loop in Extrapolated and Polar
Browse files Browse the repository at this point in the history
  • Loading branch information
Delcior committed Feb 20, 2023
1 parent 4c0b111 commit 09fcae5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
20 changes: 12 additions & 8 deletions PyMPDATA/boundary_conditions/extrapolated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" boundary condition extrapolating values from the edge to the halo """
# pylint: disable=too-many-arguments
from functools import lru_cache

import numba
Expand All @@ -10,6 +11,7 @@
META_AND_DATA_META,
SIGN_LEFT,
)
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Extrapolated:
Expand All @@ -21,20 +23,22 @@ def __init__(self, dim=INNER, eps=1e-10):
self.eps = eps
self.dim = dim

def make_scalar(self, ats, halo, dtype, jit_flags):
def make_scalar(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
return _make_scalar_extrapolated(
self.dim, self.eps, ats, halo, dtype, jit_flags
self.dim, self.eps, ats, set_value, halo, dtype, jit_flags
)

def make_vector(self, ats, halo, dtype, jit_flags):
def make_vector(self, ats, set_value, halo, dtype, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_extrapolated(self.dim, ats, halo, dtype, jit_flags)
return _make_vector_extrapolated(
self.dim, ats, set_value, halo, dtype, jit_flags
)


@lru_cache()
# pylint: disable=too-many-arguments
def _make_scalar_extrapolated(dim, eps, ats, halo, dtype, jit_flags):
def _make_scalar_extrapolated(dim, eps, ats, set_value, halo, dtype, jit_flags):
@numba.njit(**jit_flags)
def impl(psi, span, sign):
if sign == SIGN_LEFT:
Expand Down Expand Up @@ -68,13 +72,13 @@ def fill_halos_scalar(psi, span, sign):
def fill_halos_scalar(psi, span, sign):
return impl(psi, span, sign)

return fill_halos_scalar
return make_fill_halos_loop(jit_flags, set_value, fill_halos_scalar)


@lru_cache()
def _make_vector_extrapolated(_, ats, __, ___, jit_flags):
def _make_vector_extrapolated(_, ats, set_value, __, ___, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, ____, sign):
return ats(*psi, sign)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
13 changes: 7 additions & 6 deletions PyMPDATA/boundary_conditions/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numba

from PyMPDATA.impl.enumerations import ARG_FOCUS, SIGN_LEFT, SIGN_RIGHT
from PyMPDATA.impl.traversals_common import make_fill_halos_loop


class Polar:
Expand All @@ -22,7 +23,7 @@ def __init__(self, grid, longitude_idx, latitude_idx):
self.lon_idx = longitude_idx
self.lat_idx = latitude_idx

def make_scalar(self, ats, halo, _, jit_flags):
def make_scalar(self, ats, set_value, halo, _, jit_flags):
"""returns (lru-cached) Numba-compiled scalar halo-filling callable"""
nlon_half = self.nlon_half
nlat = self.nlat
Expand All @@ -43,18 +44,18 @@ def fill_halos(psi, _, sign):
val = nlon_half * (-1 if lon > nlon_half else 1)
return ats(*psi, sign * step, val)

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)

@staticmethod
def make_vector(ats, _, __, jit_flags):
def make_vector(ats, set_value, _, __, jit_flags):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
return _make_vector_polar(ats, jit_flags)
return _make_vector_polar(ats, set_value, jit_flags)


@lru_cache()
def _make_vector_polar(ats, jit_flags):
def _make_vector_polar(ats, set_value, jit_flags):
@numba.njit(**jit_flags)
def fill_halos(psi, ___, ____):
return ats(*psi, 0) # TODO #120

return fill_halos
return make_fill_halos_loop(jit_flags, set_value, fill_halos)
2 changes: 2 additions & 0 deletions PyMPDATA/impl/traversals_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def make_common(jit_flags, spanner, chunker):
"""returns Numba-compiled callable producing common parameters"""

@numba.njit(**jit_flags)
def common(meta, thread_id):
span = spanner(meta)
Expand All @@ -20,6 +21,7 @@ def common(meta, thread_id):

def make_fill_halos_loop(jit_flags, set_value, fill_halos):
"""returns Numba-compiled halo-filling callable"""

@numba.njit(**jit_flags)
def fill_halos_loop(i_rng, j_rng, k_rng, psi, span, sign):
for i in i_rng:
Expand Down

0 comments on commit 09fcae5

Please sign in to comment.