Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Halo exchanges nh solve #302

Merged
merged 26 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7c19efd
add exchange to solve_nonydro.py
halungge Oct 12, 2023
132adef
add convenience method exchange_and_wait in protocol
halungge Oct 13, 2023
a1bc2bf
Merge branch 'main' into halo_exchanges_nh_solve
halungge Oct 13, 2023
5129040
Merge branch 'main' into halo_exchanges_nh_solve
halungge Oct 18, 2023
9fe8035
remove SimpleMesh from test_solve_nonhydro.py
halungge Oct 18, 2023
2bd2c24
merge main
halungge Oct 23, 2023
4b91541
predictor test: extract domain boundaries for Cells, remove total size
halungge Oct 23, 2023
5b2e343
predictor test: extract domain boundaries for Edges remove total size,
halungge Oct 23, 2023
15f3555
add test for parallel
halungge Oct 23, 2023
6321026
remove fields not registered in the savepoint
halungge Oct 24, 2023
f71d08a
WIP parallel tests
halungge Oct 24, 2023
3474aeb
split istep_init, istep_exit, jstep_init/jstep_exit make single node …
halungge Oct 24, 2023
8419e58
fix theta_v and exner calculation on halo cells in nh_solve, remove u…
halungge Oct 25, 2023
80d99d7
pre-commit
halungge Oct 25, 2023
3dd6416
- update data urls
halungge Oct 26, 2023
0477384
remove obsolete parametrization on processor_props fixture
halungge Oct 26, 2023
6f8d3a0
fix licence header
halungge Oct 26, 2023
4dca4ad
update path of gtfn backends in gt4py
halungge Oct 26, 2023
1d07373
add TODO for hydro_corr_horizontal
halungge Nov 1, 2023
15c2e11
Merge branch 'main' into halo_exchanges_nh_solve
halungge Nov 1, 2023
45009eb
fix clean up of MPI
halungge Nov 6, 2023
b1b7280
Merge branch 'main' into halo_exchanges_nh_solve
halungge Nov 7, 2023
b867b57
pre-commit fix
halungge Nov 7, 2023
c91cfaa
Merge branch 'main' into halo_exchanges_nh_solve
halungge Nov 8, 2023
8606b82
merge main
halungge Nov 8, 2023
f5ac80f
pre-commit
halungge Nov 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,12 @@ def _sync_cell_fields(self, prognostic_state):
IF ( .NOT. lhdiff_rcf .OR. linit .OR. (iforcing /= inwp .AND. iforcing /= iaes) ) THEN
"""
log.debug("communication of prognostic cell fields: theta, w, exner - start")
handle_cell_comm = self._exchange.exchange(
self._exchange.exchange_and_wait(
CellDim,
prognostic_state.w,
prognostic_state.theta_v,
prognostic_state.exner,
)
handle_cell_comm.wait()
log.debug("communication of prognostic cell fields: theta, w, exner - done")

def _do_diffusion_step(
Expand Down Expand Up @@ -618,8 +617,7 @@ def _do_diffusion_step(

# 2. HALO EXCHANGE -- CALL sync_patch_array_mult u_vert and v_vert
log.debug("communication rbf extrapolation of vn - start")
h = self._exchange.exchange(VertexDim, self.u_vert, self.v_vert)
h.wait()
h = self._exchange.exchange_and_wait(VertexDim, self.u_vert, self.v_vert)
log.debug("communication rbf extrapolation of vn - end")

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
Expand Down Expand Up @@ -676,8 +674,7 @@ def _do_diffusion_step(

if self.config.type_vn_diffu > 1:
log.debug("communication rbf extrapolation of z_nable2_e - start")
h_z = self._exchange.exchange(EdgeDim, self.z_nabla2_e)
h_z.wait()
self._exchange.exchange_and_wait(EdgeDim, self.z_nabla2_e)
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("2nd rbf interpolation: start")
Expand All @@ -697,8 +694,7 @@ def _do_diffusion_step(

# 6. HALO EXCHANGE -- CALL sync_patch_array_mult (Vertex Fields)
log.debug("communication rbf extrapolation of z_nable2_e - start")
h = self._exchange.exchange(VertexDim, self.u_vert, self.v_vert)
h.wait()
self._exchange.exchange_and_wait(VertexDim, self.u_vert, self.v_vert)
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
@pytest.mark.mpi
@pytest.mark.parametrize("ndyn_substeps", [2])
@pytest.mark.parametrize("linit", [True, False])
@pytest.mark.parametrize("processor_props", [True], indirect=True)
def test_parallel_diffusion(
r04b09_diffusion_config,
step_date_init,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from gt4py.next.common import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, exp, int32, log, where
from gt4py.next.ffront.fbuiltins import Field, exp, log, where

from icon4py.model.common.dimension import CellDim, KDim

Expand All @@ -40,10 +40,6 @@ def mo_solve_nonhydro_stencil_66(
exner: Field[[CellDim, KDim], float],
rd_o_cvd: float,
rd_o_p0ref: float,
horizontal_start: int32,
horizontal_end: int32,
vertical_start: int32,
vertical_end: int32,
):
_mo_solve_nonhydro_stencil_66(
bdy_halo_c,
Expand All @@ -53,8 +49,4 @@ def mo_solve_nonhydro_stencil_66(
rd_o_cvd,
rd_o_p0ref,
out=(theta_v, exner),
domain={
CellDim: (horizontal_start, horizontal_end),
KDim: (vertical_start, vertical_end),
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
)
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection
from icon4py.model.common.decomposition.definitions import ExchangeRuntime, SingleNodeExchange
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.grid.horizontal import EdgeParams, HorizontalMarkerIndex
from icon4py.model.common.grid.icon_grid import IconGrid
Expand Down Expand Up @@ -289,7 +290,8 @@ def __init__(self, config: NonHydrostaticConfig):


class SolveNonhydro:
def __init__(self):
def __init__(self, exchange: ExchangeRuntime = SingleNodeExchange()):
self._exchange = exchange
self._initialized = False
self.grid: Optional[IconGrid] = None
self.config: Optional[NonHydrostaticConfig] = None
Expand Down Expand Up @@ -493,18 +495,14 @@ def time_step(
lprep_adv=lprep_adv,
)

start_cell_local_minus1 = self.grid.get_start_index(
CellDim, HorizontalMarkerIndex.local(CellDim) - 1
)
end_cell_local = self.grid.get_end_index(CellDim, HorizontalMarkerIndex.local(CellDim))

start_cell_lb = self.grid.get_start_index(
CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim)
)
end_cell_nudging_minus1 = self.grid.get_end_index(
CellDim, HorizontalMarkerIndex.nudging(CellDim) - 1
)

start_cell_halo = self.grid.get_start_index(CellDim, HorizontalMarkerIndex.halo(CellDim))
end_cell_end = self.grid.get_end_index(CellDim, HorizontalMarkerIndex.end(CellDim))
if self.grid.limited_area():
mo_solve_nonhydro_stencil_66.with_backend(run_gtfn)(
bdy_halo_c=self.metric_state_nonhydro.bdy_halo_c,
Expand All @@ -513,10 +511,6 @@ def time_step(
exner=prognostic_state_ls[nnew].exner,
rd_o_cvd=self.params.rd_o_cvd,
rd_o_p0ref=self.params.rd_o_p0ref,
horizontal_start=start_cell_local_minus1,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_end=self.grid.n_lev(),
offset_provider={},
)

Expand All @@ -542,8 +536,8 @@ def time_step(
rho_new=prognostic_state_ls[nnew].rho,
theta_v_new=prognostic_state_ls[nnew].theta_v,
cvd_o_rd=self.params.cvd_o_rd,
horizontal_start=start_cell_local_minus1,
horizontal_end=end_cell_local,
horizontal_start=start_cell_halo,
horizontal_end=end_cell_end,
vertical_start=0,
vertical_end=self.grid.n_lev(),
offset_provider={},
Expand All @@ -562,6 +556,7 @@ def run_predictor_step(
nnow: int,
nnew: int,
):
# TODO (magdalena) fix this! when fixing call of velocity advection in predictor, condition is broken,
if l_init or l_recompute:
if self.config.itime_scheme == 4 and not l_init:
lvn_only = True # Recompute only vn tendency
Expand Down Expand Up @@ -973,16 +968,17 @@ def run_predictor_step(
"Koff": KDim,
},
)
# TODO (magdalena) what is 64?
self.z_hydro_corr_horizontal = np_as_located_field(EdgeDim)(
np.asarray(self.z_hydro_corr)[:, 64]
# TODO (Nikki) check when merging fused stencil
lowest_level = self.grid.n_lev() - 1
hydro_corr_horizontal = np_as_located_field(EdgeDim)(
np.asarray(self.z_hydro_corr)[:, lowest_level]
)

if self.config.igradp_method == 3:
mo_solve_nonhydro_stencil_22.with_backend(run_gtfn)(
ipeidx_dsl=self.metric_state_nonhydro.ipeidx_dsl,
pg_exdist=self.metric_state_nonhydro.pg_exdist,
z_hydro_corr=self.z_hydro_corr_horizontal,
z_hydro_corr=hydro_corr_horizontal,
z_gradh_exner=z_fields.z_gradh_exner,
horizontal_start=start_edge_nudging_plus1,
horizontal_end=end_edge_end,
Expand Down Expand Up @@ -1032,7 +1028,7 @@ def run_predictor_step(
offset_provider={},
)

# COMMUNICATION PHASE
self._exchange.exchange_and_wait(EdgeDim, prognostic_state[nnew].vn, z_fields.z_rho_e)
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved

mo_solve_nonhydro_stencil_30.with_backend(run_gtfn)(
e_flx_avg=self.interpolation_state.e_flx_avg,
Expand Down Expand Up @@ -1348,7 +1344,9 @@ def run_predictor_step(
offset_provider={"Koff": KDim},
)

# COMMUNICATION PHASE
self._exchange.exchange_and_wait(CellDim, prognostic_state[nnew].w, z_fields.z_dwdz_dd)
else:
self._exchange.exchange_and_wait(CellDim, prognostic_state[nnew].w)

def run_corrector_step(
self,
Expand Down Expand Up @@ -1562,8 +1560,7 @@ def run_corrector_step(
offset_provider={},
)

# COMMUNICATION PHASE

self._exchange.exchange_and_wait(EdgeDim, (prognostic_state[nnew].vn))
mo_solve_nonhydro_stencil_31.with_backend(run_gtfn)(
e_flx_avg=self.interpolation_state.e_flx_avg,
vn=prognostic_state[nnew].vn,
Expand Down Expand Up @@ -1872,4 +1869,9 @@ def run_corrector_step(
offset_provider={},
)

# COMMUNICATION PHASE
self._exchange.exchange_and_wait(
CellDim,
prognostic_state[nnew].rho,
prognostic_state[nnew].exner,
prognostic_state[nnew].w,
)
6 changes: 4 additions & 2 deletions model/atmosphere/dycore/tests/dycore_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
grid_savepoint,
icon_grid,
interpolation_savepoint,
istep,
jstep,
istep_exit,
istep_init,
jstep_exit,
jstep_init,
linit,
metrics_savepoint,
processor_props,
Expand Down
Loading