Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup solve nonhydro interface #333

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -990,10 +990,11 @@ def stencils_47_48_49(
ddt_exner_phy: Field[[CellDim, KDim], float],
k_field: Field[[KDim], int32],
dtime: float,
cell_startindex_nudging_plus1: int32,
cell_endindex_interior: int32,
nlev: int32,
nlev_k: int32,
horizontal_start: int32,
horizontal_end: int32,
vertical_start: int32,
vertical_end: int32,
):
_stencils_47_48_49(
w_nnew,
Expand All @@ -1014,8 +1015,8 @@ def stencils_47_48_49(
nlev,
out=(w_nnew, z_contr_w_fl_l, z_rho_expl, z_exner_expl),
domain={
CellDim: (cell_startindex_nudging_plus1, cell_endindex_interior),
KDim: (0, nlev_k),
CellDim: (horizontal_start, horizontal_end),
KDim: (vertical_start, vertical_end),
},
)

Expand Down Expand Up @@ -1043,10 +1044,6 @@ def predictor_stencils_59_60(
ddt_exner_phy: Field[[CellDim, KDim], float],
ndyn_substeps_var: float,
dtime: float,
cell_startindex_nudging_plus1: int32,
cell_endindex_interior: int32,
kstart_moist: int32,
nlev: int32,
horizontal_start: int32,
horizontal_end: int32,
vertical_start: int32,
Expand All @@ -1060,8 +1057,8 @@ def predictor_stencils_59_60(
dtime,
out=exner_dyn_incr,
domain={
CellDim: (cell_startindex_nudging_plus1, cell_endindex_interior),
KDim: (kstart_moist, nlev),
CellDim: (horizontal_start, horizontal_end),
KDim: (vertical_start, vertical_end),
},
)

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from gt4py.next.common import Field

from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
from icon4py.model.common.grid.base import BaseGrid


@dataclass
Expand All @@ -40,3 +42,32 @@ class ZFields:
z_graddiv_vn: Field[[EdgeDim, KDim], float]
z_rho_expl: Field[[CellDim, KDim], float]
z_dwdz_dd: Field[[CellDim, KDim], float]


def _allocate_z_fields(grid: BaseGrid):
return ZFields(
z_gradh_exner=_allocate(EdgeDim, KDim, grid=grid),
z_alpha=_allocate(CellDim, KDim, is_halfdim=True, grid=grid),
z_beta=_allocate(
CellDim, KDim, is_halfdim=True, grid=grid
), # TODO (@halungge) overallocated with fake halfdim
z_w_expl=_allocate(CellDim, KDim, is_halfdim=True, grid=grid),
z_exner_expl=_allocate(CellDim, KDim, is_halfdim=True, grid=grid),
z_q=_allocate(
CellDim, KDim, is_halfdim=True, grid=grid
), # TODO (@halungge) overallocated with fake halfdim
z_contr_w_fl_l=_allocate(CellDim, KDim, is_halfdim=True, grid=grid),
z_rho_e=_allocate(EdgeDim, KDim, grid=grid),
z_theta_v_e=_allocate(EdgeDim, KDim, grid=grid),
z_graddiv_vn=_allocate(EdgeDim, KDim, grid=grid),
z_rho_expl=_allocate(
CellDim, KDim, is_halfdim=True, grid=grid
), # TODO (@halungge) overallocated with fake halfdim
z_dwdz_dd=_allocate(CellDim, KDim, grid=grid),
z_kin_hor_e=_allocate(
EdgeDim, KDim, is_halfdim=True, grid=grid
), # TODO (@halungge) overallocated with fake halfdim
z_vt_ie=_allocate(
EdgeDim, KDim, is_halfdim=True, grid=grid
), # TODO (@halungge) overallocated with fake halfdim
)
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _allocate_local_fields(self):
self.zeta = _allocate(VertexDim, KDim, grid=self.grid)
self.z_w_con_c_full = _allocate(CellDim, KDim, grid=self.grid)
self.cfl_clipping = _allocate(CellDim, KDim, grid=self.grid, dtype=bool)
self.levmask = _allocate(KDim, grid=self.grid, dtype=bool)
self.level_mask = _allocate(KDim, grid=self.grid, dtype=bool)
self.vcfl_dsl = _allocate(CellDim, KDim, grid=self.grid)
self.k_field = _allocate_indices(KDim, grid=self.grid, is_halfdim=True)

Expand Down Expand Up @@ -376,7 +376,7 @@ def run_predictor_step(
)

mo_velocity_advection_stencil_18.with_backend(run_gtfn)(
levmask=self.levmask,
levmask=self.level_mask,
cfl_clipping=self.cfl_clipping,
owner_mask=self.c_owner_mask,
z_w_con_c=self.z_w_con_c,
Expand All @@ -397,8 +397,6 @@ def run_predictor_step(
},
)

self.levelmask = self.levmask

mo_velocity_advection_stencil_19.with_backend(backend)(
z_kin_hor_e=z_kin_hor_e,
coeff_gradekin=self.metric_state.coeff_gradekin,
Expand All @@ -424,7 +422,7 @@ def run_predictor_step(
)

mo_velocity_advection_stencil_20.with_backend(backend)(
levelmask=self.levelmask,
levelmask=self.level_mask,
c_lin_e=self.interpolation_state.c_lin_e,
z_w_con_c_full=self.z_w_con_c_full,
ddqz_z_full_e=self.metric_state.ddqz_z_full_e,
Expand All @@ -451,7 +449,7 @@ def run_predictor_step(
)

def _update_levmask_from_cfl_clipping(self):
self.levmask = as_field(
self.level_mask = as_field(
domain=(KDim,), data=(np.any(self.cfl_clipping.asnumpy(), 0)), dtype=bool
)

Expand Down Expand Up @@ -623,7 +621,7 @@ def run_corrector_step(
)

mo_velocity_advection_stencil_18.with_backend(backend)(
levmask=self.levmask,
levmask=self.level_mask,
cfl_clipping=self.cfl_clipping,
owner_mask=self.c_owner_mask,
z_w_con_c=self.z_w_con_c,
Expand All @@ -644,9 +642,6 @@ def run_corrector_step(
},
)

# This behaviour needs to change for multiple blocks
self.levelmask = self.levmask

mo_velocity_advection_stencil_19.with_backend(backend)(
z_kin_hor_e=z_kin_hor_e,
coeff_gradekin=self.metric_state.coeff_gradekin,
Expand All @@ -672,7 +667,7 @@ def run_corrector_step(
)

mo_velocity_advection_stencil_20.with_backend(backend)(
levelmask=self.levelmask,
levelmask=self.level_mask,
c_lin_e=self.interpolation_state.c_lin_e,
z_w_con_c_full=self.z_w_con_c_full,
ddqz_z_full_e=self.metric_state.ddqz_z_full_e,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@
NonHydrostaticParams,
SolveNonhydro,
)
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
from icon4py.model.atmosphere.dycore.state_utils.states import (
DiagnosticStateNonHydro,
PrepAdvection,
)
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.common.decomposition import definitions
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.dimension import CellDim, EdgeDim, VertexDim
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
Expand Down Expand Up @@ -157,30 +154,6 @@ def test_run_solve_nonhydro_single_step(
exner=sp.exner_new(),
)

z_fields = ZFields(
z_gradh_exner=_allocate(EdgeDim, KDim, grid=icon_grid),
z_alpha=_allocate(CellDim, KDim, is_halfdim=True, grid=icon_grid),
z_beta=_allocate(CellDim, KDim, grid=icon_grid),
z_w_expl=_allocate(CellDim, KDim, is_halfdim=True, grid=icon_grid),
z_exner_expl=_allocate(CellDim, KDim, grid=icon_grid),
z_q=_allocate(CellDim, KDim, grid=icon_grid),
z_contr_w_fl_l=_allocate(CellDim, KDim, is_halfdim=True, grid=icon_grid),
z_rho_e=_allocate(EdgeDim, KDim, grid=icon_grid),
z_theta_v_e=_allocate(EdgeDim, KDim, grid=icon_grid),
z_graddiv_vn=_allocate(EdgeDim, KDim, grid=icon_grid),
z_rho_expl=_allocate(CellDim, KDim, grid=icon_grid),
z_dwdz_dd=_allocate(CellDim, KDim, grid=icon_grid),
z_kin_hor_e=_allocate(EdgeDim, KDim, grid=icon_grid),
z_vt_ie=_allocate(EdgeDim, KDim, grid=icon_grid),
)

nh_constants = NHConstants(
wgt_nnow_rth=sp.wgt_nnow_rth(),
wgt_nnew_rth=sp.wgt_nnew_rth(),
wgt_nnow_vel=sp.wgt_nnow_vel(),
wgt_nnew_vel=sp.wgt_nnew_vel(),
)

interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)

Expand Down Expand Up @@ -210,8 +183,6 @@ def test_run_solve_nonhydro_single_step(
diagnostic_state_nh=diagnostic_state_nh,
prognostic_state_ls=prognostic_state_ls,
prep_adv=prep_adv,
z_fields=z_fields,
nh_constants=nh_constants,
divdamp_fac_o2=0.032,
dtime=dtime,
idyn_timestep=dyn_timestep,
Expand Down
Loading