Skip to content

Commit

Permalink
uncombine some more
Browse files Browse the repository at this point in the history
  • Loading branch information
muellch committed Dec 1, 2024
1 parent 8157a90 commit 3ec3e8f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
from icon4py.model.atmosphere.dycore.copy_cell_kdim_field_to_vp import (
copy_cell_kdim_field_to_vp,
)
from icon4py.model.atmosphere.dycore.compute_contravariant_correction import (
compute_contravariant_correction,
)
from icon4py.model.atmosphere.dycore.compute_horizontal_kinetic_energy import (
compute_horizontal_kinetic_energy,
)
from icon4py.model.atmosphere.dycore.init_cell_kdim_field_with_zero_vp import (
init_cell_kdim_field_with_zero_vp,
)
Expand Down Expand Up @@ -121,6 +127,12 @@ def __init__(
self._copy_cell_kdim_field_to_vp = (
copy_cell_kdim_field_to_vp.with_backend(self._backend)
)
self._compute_contravariant_correction = (
compute_contravariant_correction.with_backend(self._backend)
)
self._compute_horizontal_kinetic_energy = (
compute_horizontal_kinetic_energy.with_backend(self._backend)
)
self._init_cell_kdim_field_with_zero_vp = (
init_cell_kdim_field_with_zero_vp.with_backend(self._backend)
)
Expand Down Expand Up @@ -279,24 +291,32 @@ def run_predictor_step(
offset_provider=self.grid.offset_providers,
)

self._fused_stencils_4_5(
self._compute_contravariant_correction(
vn=prognostic_state.vn,
ddxn_z_full=self.metric_state.ddxn_z_full,
ddxt_z_full=self.metric_state.ddxt_z_full,
vt=diagnostic_state.vt,
z_w_concorr_me=z_w_concorr_me,
horizontal_start=self._start_edge_lateral_boundary_level_5,
horizontal_end=self._end_edge_halo_level_2,
vertical_start=self.vertical_params.nflatlev,
vertical_end=self.grid.num_levels,
offset_provider={},
)

self._compute_horizontal_kinetic_energy(
vn=prognostic_state.vn,
vt=diagnostic_state.vt,
vn_ie=diagnostic_state.vn_ie,
z_vt_ie=z_vt_ie,
z_kin_hor_e=z_kin_hor_e,
ddxn_z_full=self.metric_state.ddxn_z_full,
ddxt_z_full=self.metric_state.ddxt_z_full,
z_w_concorr_me=z_w_concorr_me,
k_field=self.k_field,
nflatlev_startindex=self.vertical_params.nflatlev,
nlev=self.grid.num_levels,
horizontal_start=self._start_edge_lateral_boundary_level_5,
horizontal_end=self._end_edge_halo_level_2,
vertical_start=0,
vertical_end=self.grid.num_levels,
vertical_end=1,
offset_provider={},
)

self._extrapolate_at_top(
wgtfacq_e=self.metric_state.wgtfacq_e,
vn=prognostic_state.vn,
Expand Down Expand Up @@ -336,18 +356,24 @@ def run_predictor_step(
offset_provider=self.grid.offset_providers,
)

self._fused_stencils_9_10(
self._interpolate_to_cell_center(
z_w_concorr_me=z_w_concorr_me,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
local_z_w_concorr_mc=self.z_w_concorr_mc,
z_w_concorr_mc=self.z_w_concorr_mc,
horizontal_start=self._start_cell_lateral_boundary_level_4,
horizontal_end=self._end_cell_halo,
vertical_start=self.vertical_params.nflatlev,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)

self.interpolate_to_half_levels_vp(
interpolant=self.z_w_concorr_mc,
wgtfac_c=self.metric_state.wgtfac_c,
w_concorr_c=diagnostic_state.w_concorr_c,
k_field=self.k_field,
nflatlev_startindex=self.vertical_params.nflatlev,
nlev=self.grid.num_levels,
interpolation_to_half_levels_vp=diagnostic_state.w_concorr_c,
horizontal_start=self._start_cell_lateral_boundary_level_4,
horizontal_end=self._end_cell_halo,
vertical_start=0,
vertical_start=self.vertical_params.nflatlev + 1,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,26 @@ def fused_stencils_4_5(
vertical_start: gtx.int32,
vertical_end: gtx.int32,
):
_compute_contravariant_correction(
_fused_stencils_4_5(
vn,
vt,
vn_ie,
z_vt_ie,
z_kin_hor_e,
ddxn_z_full,
ddxt_z_full,
vt,
out=z_w_concorr_me,
domain = {
dims.EdgeDim: (horizontal_start, horizontal_end),
dims.KDim: (nflatlev_startindex, vertical_end),
},
)
_compute_horizontal_kinetic_energy(
vn,
vt,
out = (vn_ie, z_vt_ie, z_kin_hor_e),
domain = {
z_w_concorr_me,
k_field,
nflatlev_startindex,
nlev,
out=(z_w_concorr_me, vn_ie, z_vt_ie, z_kin_hor_e),
domain={
dims.EdgeDim: (horizontal_start, horizontal_end),
dims.KDim: (vertical_start, vertical_start + 1),
dims.KDim: (vertical_start, vertical_end),
},
)


@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED)
def extrapolate_at_top(
wgtfacq_e: fa.EdgeKField[float],
Expand Down Expand Up @@ -173,23 +172,19 @@ def fused_stencils_9_10(
vertical_start: gtx.int32,
vertical_end: gtx.int32,
):
_interpolate_to_cell_center(
_fused_stencils_9_10(
z_w_concorr_me,
e_bln_c_s,
out=local_z_w_concorr_mc,
domain = {
dims.CellDim: (horizontal_start, horizontal_end),
dims.KDim: (nflatlev_startindex, vertical_end),
},
)

_interpolate_to_half_levels_vp(
interpolant=local_z_w_concorr_mc,
wgtfac_c=wgtfac_c,
out=w_concorr_c,
domain = {
local_z_w_concorr_mc,
wgtfac_c,
w_concorr_c,
k_field,
nflatlev_startindex,
nlev,
out=(local_z_w_concorr_mc, w_concorr_c),
domain={
dims.CellDim: (horizontal_start, horizontal_end),
dims.KDim: (nflatlev_startindex + 1, vertical_end),
dims.KDim: (vertical_start, vertical_end),
},
)

Expand Down Expand Up @@ -232,32 +227,20 @@ def fused_stencils_11_to_13(
vertical_start: gtx.int32,
vertical_end: gtx.int32,
):
_copy_cell_kdim_field_to_vp(
_fused_stencils_11_to_13(
w,
w_concorr_c,
local_z_w_con_c,
k_field,
nflatlev_startindex,
nlev,
out=local_z_w_con_c,
domain={
dims.CellDim: (horizontal_start, horizontal_end),
dims.KDim: (vertical_start, vertical_end - 1),
},
)

_init_cell_kdim_field_with_zero_vp(
out=local_z_w_con_c,
domain={
dims.CellDim: (horizontal_start, horizontal_end),
dims.KDim: (vertical_end - 1, vertical_end),
dims.KDim: (vertical_start, vertical_end),
},
)

_correct_contravariant_vertical_velocity(
local_z_w_con_c,
w_concorr_c,
out=local_z_w_con_c,
domain = {
dims.CellDim: (horizontal_start, horizontal_end),
dims.KDim: (nflatlev_startindex + 1, vertical_end - 1),
},
)

@gtx.field_operator
def _fused_stencil_14(
Expand Down Expand Up @@ -350,4 +333,3 @@ def fused_stencils_16_to_17(
dims.KDim: (vertical_start, vertical_end),
},
)

0 comments on commit 3ec3e8f

Please sign in to comment.