Skip to content

Commit

Permalink
Halo exchanges nh solve (#302)
Browse files Browse the repository at this point in the history
Adds halo exchanges to solve_nonhydro making it runnable on several nodes.

* add convenience method exchange_and_wait in protocol of the `ExchangeRuntime`
* remove unnecessary domain bounds on `mo_solve_nonhydro_stencil_66`

includes some cleanups for `test_solve_nonhydro.py`
* remove SimpleMesh from test_solve_nonhydro.py
* predictor test: extract domain boundaries for Cells, remove total sizes magic numbers
* predictor test: extract domain boundaries for Edges remove total sizes magic numbers
* split `istep` , `jstep` fixtures into separate ones for `init` and `exit`

rename `solve_nonhydro.py`
  • Loading branch information
halungge authored Nov 8, 2023
1 parent d8bf0a3 commit d2ae0e5
Show file tree
Hide file tree
Showing 17 changed files with 524 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,12 @@ def _sync_cell_fields(self, prognostic_state):
IF ( .NOT. lhdiff_rcf .OR. linit .OR. (iforcing /= inwp .AND. iforcing /= iaes) ) THEN
"""
log.debug("communication of prognostic cell fields: theta, w, exner - start")
handle_cell_comm = self._exchange.exchange(
self._exchange.exchange_and_wait(
CellDim,
prognostic_state.w,
prognostic_state.theta_v,
prognostic_state.exner,
)
handle_cell_comm.wait()
log.debug("communication of prognostic cell fields: theta, w, exner - done")

def _do_diffusion_step(
Expand Down Expand Up @@ -618,8 +617,7 @@ def _do_diffusion_step(

# 2. HALO EXCHANGE -- CALL sync_patch_array_mult u_vert and v_vert
log.debug("communication rbf extrapolation of vn - start")
h = self._exchange.exchange(VertexDim, self.u_vert, self.v_vert)
h.wait()
h = self._exchange.exchange_and_wait(VertexDim, self.u_vert, self.v_vert)
log.debug("communication rbf extrapolation of vn - end")

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
Expand All @@ -646,7 +644,7 @@ def _do_diffusion_step(
vertical_end=klevels,
offset_provider={
"E2C2V": self.grid.get_offset_provider("E2C2V"),
"E2ECV": self.grid.get_offset_provider("E2C2V"),
"E2ECV": self.grid.get_offset_provider("E2ECV"),
},
)
log.debug("running stencil 01 (calculate_nabla2_and_smag_coefficients_for_vn): end")
Expand Down Expand Up @@ -676,8 +674,7 @@ def _do_diffusion_step(

if self.config.type_vn_diffu > 1:
log.debug("communication rbf extrapolation of z_nable2_e - start")
h_z = self._exchange.exchange(EdgeDim, self.z_nabla2_e)
h_z.wait()
self._exchange.exchange_and_wait(EdgeDim, self.z_nabla2_e)
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("2nd rbf interpolation: start")
Expand All @@ -697,8 +694,7 @@ def _do_diffusion_step(

# 6. HALO EXCHANGE -- CALL sync_patch_array_mult (Vertex Fields)
log.debug("communication rbf extrapolation of z_nable2_e - start")
h = self._exchange.exchange(VertexDim, self.u_vert, self.v_vert)
h.wait()
self._exchange.exchange_and_wait(VertexDim, self.u_vert, self.v_vert)
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
Expand Down Expand Up @@ -761,7 +757,7 @@ def _do_diffusion_step(
vertical_start=0,
vertical_end=klevels,
offset_provider={
"C2E2CO": self.grid.get_offset_provider("C2E2CO")(),
"C2E2CO": self.grid.get_offset_provider("C2E2CO"),
},
)
log.debug(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
@pytest.mark.mpi
@pytest.mark.parametrize("ndyn_substeps", [2])
@pytest.mark.parametrize("linit", [True, False])
@pytest.mark.parametrize("processor_props", [True], indirect=True)
def test_parallel_diffusion(
r04b09_diffusion_config,
step_date_init,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from gt4py.next.common import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, exp, int32, log, where
from gt4py.next.ffront.fbuiltins import Field, exp, log, where

from icon4py.model.common.dimension import CellDim, KDim

Expand All @@ -40,10 +40,6 @@ def mo_solve_nonhydro_stencil_66(
exner: Field[[CellDim, KDim], float],
rd_o_cvd: float,
rd_o_p0ref: float,
horizontal_start: int32,
horizontal_end: int32,
vertical_start: int32,
vertical_end: int32,
):
_mo_solve_nonhydro_stencil_66(
bdy_halo_c,
Expand All @@ -53,8 +49,4 @@ def mo_solve_nonhydro_stencil_66(
rd_o_cvd,
rd_o_p0ref,
out=(theta_v, exner),
domain={
CellDim: (horizontal_start, horizontal_end),
KDim: (vertical_start, vertical_end),
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
)
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection
from icon4py.model.common.decomposition.definitions import ExchangeRuntime, SingleNodeExchange
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.grid.horizontal import EdgeParams, HorizontalMarkerIndex
from icon4py.model.common.grid.icon import IconGrid
Expand Down Expand Up @@ -289,7 +290,8 @@ def __init__(self, config: NonHydrostaticConfig):


class SolveNonhydro:
def __init__(self):
def __init__(self, exchange: ExchangeRuntime = SingleNodeExchange()):
self._exchange = exchange
self._initialized = False
self.grid: Optional[IconGrid] = None
self.config: Optional[NonHydrostaticConfig] = None
Expand Down Expand Up @@ -493,18 +495,14 @@ def time_step(
lprep_adv=lprep_adv,
)

start_cell_local_minus1 = self.grid.get_start_index(
CellDim, HorizontalMarkerIndex.local(CellDim) - 1
)
end_cell_local = self.grid.get_end_index(CellDim, HorizontalMarkerIndex.local(CellDim))

start_cell_lb = self.grid.get_start_index(
CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim)
)
end_cell_nudging_minus1 = self.grid.get_end_index(
CellDim, HorizontalMarkerIndex.nudging(CellDim) - 1
)

start_cell_halo = self.grid.get_start_index(CellDim, HorizontalMarkerIndex.halo(CellDim))
end_cell_end = self.grid.get_end_index(CellDim, HorizontalMarkerIndex.end(CellDim))
if self.grid.limited_area:
mo_solve_nonhydro_stencil_66.with_backend(run_gtfn)(
bdy_halo_c=self.metric_state_nonhydro.bdy_halo_c,
Expand All @@ -513,10 +511,6 @@ def time_step(
exner=prognostic_state_ls[nnew].exner,
rd_o_cvd=self.params.rd_o_cvd,
rd_o_p0ref=self.params.rd_o_p0ref,
horizontal_start=start_cell_local_minus1,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

Expand All @@ -542,8 +536,8 @@ def time_step(
rho_new=prognostic_state_ls[nnew].rho,
theta_v_new=prognostic_state_ls[nnew].theta_v,
cvd_o_rd=self.params.cvd_o_rd,
horizontal_start=start_cell_local_minus1,
horizontal_end=end_cell_local,
horizontal_start=start_cell_halo,
horizontal_end=end_cell_end,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
Expand All @@ -562,6 +556,7 @@ def run_predictor_step(
nnow: int,
nnew: int,
):
# TODO (magdalena) fix this! when fixing call of velocity advection in predictor, condition is broken,
if l_init or l_recompute:
if self.config.itime_scheme == 4 and not l_init:
lvn_only = True # Recompute only vn tendency
Expand Down Expand Up @@ -973,16 +968,17 @@ def run_predictor_step(
"Koff": KDim,
},
)
# TODO (magdalena) what is 64?
self.z_hydro_corr_horizontal = np_as_located_field(EdgeDim)(
np.asarray(self.z_hydro_corr)[:, 64]
# TODO (Nikki) check when merging fused stencil
lowest_level = self.grid.num_levels - 1
hydro_corr_horizontal = np_as_located_field(EdgeDim)(
np.asarray(self.z_hydro_corr)[:, lowest_level]
)

if self.config.igradp_method == 3:
mo_solve_nonhydro_stencil_22.with_backend(run_gtfn)(
ipeidx_dsl=self.metric_state_nonhydro.ipeidx_dsl,
pg_exdist=self.metric_state_nonhydro.pg_exdist,
z_hydro_corr=self.z_hydro_corr_horizontal,
z_hydro_corr=hydro_corr_horizontal,
z_gradh_exner=z_fields.z_gradh_exner,
horizontal_start=start_edge_nudging_plus1,
horizontal_end=end_edge_end,
Expand Down Expand Up @@ -1032,7 +1028,7 @@ def run_predictor_step(
offset_provider={},
)

# COMMUNICATION PHASE
self._exchange.exchange_and_wait(EdgeDim, prognostic_state[nnew].vn, z_fields.z_rho_e)

mo_solve_nonhydro_stencil_30.with_backend(run_gtfn)(
e_flx_avg=self.interpolation_state.e_flx_avg,
Expand Down Expand Up @@ -1314,7 +1310,7 @@ def run_predictor_step(
offset_provider={},
)

if self.grid.limited_area: # for MPI-parallelized case
if self.grid.limited_area:
nhsolve_prog.stencils_61_62.with_backend(run_gtfn)(
rho_now=prognostic_state[nnow].rho,
grf_tend_rho=diagnostic_state_nh.grf_tend_rho,
Expand Down Expand Up @@ -1348,7 +1344,9 @@ def run_predictor_step(
offset_provider={"Koff": KDim},
)

# COMMUNICATION PHASE
self._exchange.exchange_and_wait(CellDim, prognostic_state[nnew].w, z_fields.z_dwdz_dd)
else:
self._exchange.exchange_and_wait(CellDim, prognostic_state[nnew].w)

def run_corrector_step(
self,
Expand Down Expand Up @@ -1562,8 +1560,7 @@ def run_corrector_step(
offset_provider={},
)

# COMMUNICATION PHASE

self._exchange.exchange_and_wait(EdgeDim, (prognostic_state[nnew].vn))
mo_solve_nonhydro_stencil_31.with_backend(run_gtfn)(
e_flx_avg=self.interpolation_state.e_flx_avg,
vn=prognostic_state[nnew].vn,
Expand Down Expand Up @@ -1872,4 +1869,9 @@ def run_corrector_step(
offset_provider={},
)

# COMMUNICATION PHASE
self._exchange.exchange_and_wait(
CellDim,
prognostic_state[nnew].rho,
prognostic_state[nnew].exner,
prognostic_state[nnew].w,
)
6 changes: 4 additions & 2 deletions model/atmosphere/dycore/tests/dycore_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
grid_savepoint,
icon_grid,
interpolation_savepoint,
istep,
jstep,
istep_exit,
istep_init,
jstep_exit,
jstep_init,
linit,
metrics_savepoint,
processor_props,
Expand Down
Loading

0 comments on commit d2ae0e5

Please sign in to comment.