Skip to content

Commit

Permalink
Uncombine everything needed to compute vn
Browse files Browse the repository at this point in the history
  • Loading branch information
muellch committed Dec 4, 2024
1 parent ab41fa4 commit 240f985
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,34 @@
from icon4py.model.atmosphere.dycore.stencils.init_cell_kdim_field_with_zero_wp import (
init_cell_kdim_field_with_zero_wp,
)
from icon4py.model.atmosphere.dycore.stencils.init_cell_kdim_field_with_zero_vp import (
init_cell_kdim_field_with_zero_vp,
)

from icon4py.model.atmosphere.dycore.stencils.accumulate_prep_adv_fields import (
accumulate_prep_adv_fields,
)
from icon4py.model.atmosphere.dycore.stencils.add_analysis_increments_from_data_assimilation import (
add_analysis_increments_from_data_assimilation,
)
from icon4py.model.atmosphere.dycore.stencils.set_theta_v_prime_ic_at_lower_boundary import (
set_theta_v_prime_ic_at_lower_boundary,
)
from icon4py.model.atmosphere.dycore.stencils.extrapolate_temporally_exner_pressure import (
extrapolate_temporally_exner_pressure,
)
from icon4py.model.atmosphere.dycore.stencils.add_analysis_increments_to_vn import (
add_analysis_increments_to_vn,
)
from icon4py.model.atmosphere.dycore.stencils.compute_first_vertical_derivative import (
compute_first_vertical_derivative,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_to_half_levels_vp import (
interpolate_to_half_levels_vp,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_to_surface import (
interpolate_to_surface,
)
from icon4py.model.atmosphere.dycore.stencils.add_temporal_tendencies_to_vn import (
add_temporal_tendencies_to_vn,
)
Expand Down Expand Up @@ -471,6 +489,15 @@ def __init__(
self._compute_theta_and_exner = compute_theta_and_exner.with_backend(self._backend)
self._compute_exner_from_rhotheta = compute_exner_from_rhotheta.with_backend(self._backend)
self._update_theta_v = update_theta_v.with_backend(self._backend)
self._compute_first_vertical_derivative = (
compute_first_vertical_derivative.with_backend(self._backend)
)
self._interpolate_to_half_levels_vp = (
interpolate_to_half_levels_vp.with_backend(self._backend)
)
self._interpolate_to_surface = (
interpolate_to_surface.with_backend(self._backend)
)
self._init_two_cell_kdim_fields_with_zero_vp = (
init_two_cell_kdim_fields_with_zero_vp.with_backend(self._backend)
)
Expand Down Expand Up @@ -504,6 +531,12 @@ def __init__(
self._backend
)
)
self._extrapolate_temporally_exner_pressure = (
extrapolate_temporally_exner_pressure.with_backend(self._backend)
)
self._set_theta_v_prime_ic_at_lower_boundary = (
set_theta_v_prime_ic_at_lower_boundary.with_backend(self._backend)
)
self._compute_hydrostatic_correction_term = (
compute_hydrostatic_correction_term.with_backend(self._backend)
)
Expand Down Expand Up @@ -582,6 +615,9 @@ def __init__(
self._init_cell_kdim_field_with_zero_wp = init_cell_kdim_field_with_zero_wp.with_backend(
self._backend
)
self._init_cell_kdim_field_with_zero_vp = init_cell_kdim_field_with_zero_vp.with_backend(
self._backend
)
self._update_mass_flux_weighted = update_mass_flux_weighted.with_backend(self._backend)
self._compute_z_raylfac = dycore_utils.compute_z_raylfac.with_backend(self._backend)
self._predictor_stencils_2_3 = nhsolve_stencils.predictor_stencils_2_3.with_backend(
Expand Down Expand Up @@ -994,7 +1030,20 @@ def run_predictor_step(
# - $\exnerprime{\n}{\c}{\k}$ : exner - exner_ref_mc
# - $\exnerprime{\n-1}{\c}{\k}$ : exner_pr
#
self._predictor_stencils_2_3(
#self._predictor_stencils_2_3(
# exner_exfac=self._metric_state_nonhydro.exner_exfac,
# exner=prognostic_state[nnow].exner,
# exner_ref_mc=self._metric_state_nonhydro.exner_ref_mc,
# exner_pr=diagnostic_state_nh.exner_pr,
# z_exner_ex_pr=self.z_exner_ex_pr,
# horizontal_start=self._start_cell_lateral_boundary_level_3,
# horizontal_end=self._end_cell_halo,
# vertical_start=0,
# vertical_end=self._grid.num_levels + 1,
# offset_provider={},
#)

self._extrapolate_temporally_exner_pressure(
exner_exfac=self._metric_state_nonhydro.exner_exfac,
exner=prognostic_state[nnow].exner,
exner_ref_mc=self._metric_state_nonhydro.exner_ref_mc,
Expand All @@ -1003,6 +1052,15 @@ def run_predictor_step(
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=0,
vertical_end=self._grid.num_levels,
offset_provider={},
)

self._init_cell_kdim_field_with_zero_wp(
field_with_zero_wp=self.z_exner_ex_pr,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=self._grid.num_levels,
vertical_end=self._grid.num_levels + 1,
offset_provider={},
)
Expand Down Expand Up @@ -1033,17 +1091,50 @@ def run_predictor_step(
# - $\exnerprime{\ntilde}{\c}{\k}$ : z_exner_ex_pr
# - $1 / \Dz{\k}$ : inv_ddqz_z_full
#
self._predictor_stencils_4_5_6(
wgtfacq_c_dsl=self._metric_state_nonhydro.wgtfacq_c,
z_exner_ex_pr=self.z_exner_ex_pr,
z_exner_ic=self.z_exner_ic,
#self._predictor_stencils_4_5_6(
# wgtfacq_c_dsl=self._metric_state_nonhydro.wgtfacq_c,
# z_exner_ex_pr=self.z_exner_ex_pr,
# z_exner_ic=self.z_exner_ic,
# wgtfac_c=self._metric_state_nonhydro.wgtfac_c,
# inv_ddqz_z_full=self._metric_state_nonhydro.inv_ddqz_z_full,
# z_dexner_dz_c_1=self.z_dexner_dz_c_1,
# horizontal_start=self._start_cell_lateral_boundary_level_3,
# horizontal_end=self._end_cell_halo,
# vertical_start=max(1, self._vertical_params.nflatlev),
# vertical_end=self._grid.num_levels + 1,
# offset_provider=self._grid.offset_providers,
#)

self._interpolate_to_surface(
wgtfacq_c=self._metric_state_nonhydro.wgtfacq_c,
interpolant=self.z_exner_ex_pr,
interpolation_to_surface=self.z_exner_ic,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=self._grid.num_levels,
vertical_end=self._grid.num_levels + 1,
offset_provider=self._grid.offset_providers,
)

self._interpolate_to_half_levels_vp(
wgtfac_c=self._metric_state_nonhydro.wgtfac_c,
interpolant=self.z_exner_ex_pr,
interpolation_to_half_levels_vp=self.z_exner_ic,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=max(1, self._vertical_params.nflatlev),
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)

self._compute_first_vertical_derivative(
z_exner_ic=self.z_exner_ic,
inv_ddqz_z_full=self._metric_state_nonhydro.inv_ddqz_z_full,
z_dexner_dz_c_1=self.z_dexner_dz_c_1,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=max(1, self._vertical_params.nflatlev),
vertical_end=self._grid.num_levels + 1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)

Expand Down Expand Up @@ -1076,17 +1167,24 @@ def run_predictor_step(
)

# Perturbation theta at top and surface levels
self._predictor_stencils_11_lower_upper(
wgtfacq_c_dsl=self._metric_state_nonhydro.wgtfacq_c,
self._init_cell_kdim_field_with_zero_vp(
field_with_zero_vp=self.z_theta_v_pr_ic,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=0,
vertical_end=1,
offset_provider=self._grid.offset_providers,
)

self._set_theta_v_prime_ic_at_lower_boundary(
wgtfacq_c=self._metric_state_nonhydro.wgtfacq_c,
z_rth_pr=self.z_rth_pr_2,
theta_ref_ic=self._metric_state_nonhydro.theta_ref_ic,
z_theta_v_pr_ic=self.z_theta_v_pr_ic,
theta_v_ic=diagnostic_state_nh.theta_v_ic,
k_field=self.k_field,
nlev=self._grid.num_levels,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=0,
vertical_start=self._grid.num_levels,
vertical_end=self._grid.num_levels + 1,
offset_provider=self._grid.offset_providers,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,21 @@
from icon4py.model.atmosphere.dycore.stencils.copy_cell_kdim_field_to_vp import (
copy_cell_kdim_field_to_vp,
)
from icon4py.model.atmosphere.dycore.stencils.compute_advective_vertical_wind_tendency import (
compute_advective_vertical_wind_tendency,
)
from icon4py.model.atmosphere.dycore.stencils.compute_contravariant_correction import (
compute_contravariant_correction,
)
from icon4py.model.atmosphere.dycore.stencils.add_interpolated_horizontal_advection_of_w import (
add_interpolated_horizontal_advection_of_w,
)
from icon4py.model.atmosphere.dycore.stencils.compute_horizontal_kinetic_energy import (
compute_horizontal_kinetic_energy,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_to_half_levels_vp import (
interpolate_to_half_levels_vp,
)
from icon4py.model.atmosphere.dycore.stencils.init_cell_kdim_field_with_zero_vp import (
init_cell_kdim_field_with_zero_vp,
)
Expand Down Expand Up @@ -102,6 +111,9 @@ def __init__(
self._interpolate_vn_to_ie_and_compute_ekin_on_edges = (
interpolate_vn_to_ie_and_compute_ekin_on_edges.with_backend(self._backend)
)
self._interpolate_to_half_levels_vp = interpolate_to_half_levels_vp.with_backend(
self._backend
)
self._interpolate_vt_to_interface_edges = interpolate_vt_to_interface_edges.with_backend(
self._backend
)
Expand All @@ -111,9 +123,15 @@ def __init__(
compute_horizontal_advection_term_for_vertical_velocity.with_backend(self._backend)
)
self._interpolate_to_cell_center = interpolate_to_cell_center.with_backend(self._backend)
self._compute_advective_vertical_wind_tendency = (
compute_advective_vertical_wind_tendency.with_backend(self._backend)
)
self._fused_stencils_9_10 = velocity_stencils.fused_stencils_9_10.with_backend(
self._backend
)
self._add_interpolated_horizontal_advection_of_w = (
add_interpolated_horizontal_advection_of_w.with_backend(self._backend)
)
self._fused_stencils_11_to_13 = velocity_stencils.fused_stencils_11_to_13.with_backend(
self._backend
)
Expand Down Expand Up @@ -356,17 +374,17 @@ def run_predictor_step(
)

self._interpolate_to_cell_center(
z_w_concorr_me=z_w_concorr_me,
interpolant=z_w_concorr_me,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
z_w_concorr_mc=self.z_w_concorr_mc,
interpolation=self.z_w_concorr_mc,
horizontal_start=self._start_cell_lateral_boundary_level_4,
horizontal_end=self._end_cell_halo,
vertical_start=self.vertical_params.nflatlev,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)

self.interpolate_to_half_levels_vp(
self._interpolate_to_half_levels_vp(
interpolant=self.z_w_concorr_mc,
wgtfac_c=self.metric_state.wgtfac_c,
interpolation_to_half_levels_vp=diagnostic_state.w_concorr_c,
Expand Down Expand Up @@ -435,11 +453,9 @@ def run_predictor_step(
)

if not vn_only:
self._fused_stencils_16_to_17(
self._compute_advective_vertical_wind_tendency(
z_w_con_c=self.z_w_con_c,
w=prognostic_state.w,
local_z_v_grad_w=self.z_v_grad_w,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
local_z_w_con_c=self.z_w_con_c,
coeff1_dwdz=self.metric_state.coeff1_dwdz,
coeff2_dwdz=self.metric_state.coeff2_dwdz,
ddt_w_adv=diagnostic_state.ddt_w_adv_pc[ntnd],
Expand All @@ -450,6 +466,17 @@ def run_predictor_step(
offset_provider=self.grid.offset_providers,
)

self._add_interpolated_horizontal_advection_of_w(
e_bln_c_s=self.interpolation_state.e_bln_c_s,
z_v_grad_w=self.z_v_grad_w,
ddt_w_adv=diagnostic_state.ddt_w_adv_pc[ntnd],
horizontal_start=self._start_cell_nudging,
horizontal_end=self._end_cell_local,
vertical_start=1,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)

self._add_extra_diffusion_for_w_con_approaching_cfl(
levmask=self.levmask,
cfl_clipping=self.cfl_clipping,
Expand Down Expand Up @@ -591,20 +618,35 @@ def run_corrector_step(
offset_provider=self.grid.offset_providers,
)

self._fused_stencils_11_to_13(
w=prognostic_state.w,
w_concorr_c=diagnostic_state.w_concorr_c,
local_z_w_con_c=self.z_w_con_c,
k_field=self.k_field,
nflatlev_startindex=self.vertical_params.nflatlev,
nlev=self.grid.num_levels,
self._copy_cell_kdim_field_to_vp(
field=prognostic_state.w,
field_copy=self.z_w_con_c,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

self._init_cell_kdim_field_with_zero_vp(
field_with_zero_vp=self.z_w_con_c,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=self.grid.num_levels,
vertical_end=self.grid.num_levels + 1,
offset_provider={},
)

self._correct_contravariant_vertical_velocity(
z_w_con_c=self.z_w_con_c,
w_concorr_c=diagnostic_state.w_concorr_c,
horizontal_start=self._start_cell_lateral_boundary_level_3,
horizontal_end=self._end_cell_halo,
vertical_start=self.vertical_params.nflatlev + 1,
vertical_end=self.grid.num_levels,
offset_provider={},
)

self._fused_stencil_14(
ddqz_z_half=self.metric_state.ddqz_z_half,
local_z_w_con_c=self.z_w_con_c,
Expand All @@ -631,11 +673,9 @@ def run_corrector_step(
offset_provider=self.grid.offset_providers,
)

self._fused_stencils_16_to_17(
self._compute_advective_vertical_wind_tendency(
z_w_con_c=self.z_w_con_c,
w=prognostic_state.w,
local_z_v_grad_w=self.z_v_grad_w,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
local_z_w_con_c=self.z_w_con_c,
coeff1_dwdz=self.metric_state.coeff1_dwdz,
coeff2_dwdz=self.metric_state.coeff2_dwdz,
ddt_w_adv=diagnostic_state.ddt_w_adv_pc[ntnd],
Expand All @@ -646,6 +686,17 @@ def run_corrector_step(
offset_provider=self.grid.offset_providers,
)

self._add_interpolated_horizontal_advection_of_w(
e_bln_c_s=self.interpolation_state.e_bln_c_s,
z_v_grad_w=self.z_v_grad_w,
ddt_w_adv=diagnostic_state.ddt_w_adv_pc[ntnd],
horizontal_start=self._start_cell_nudging,
horizontal_end=self._end_cell_local,
vertical_start=1,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)

self._add_extra_diffusion_for_w_con_approaching_cfl(
levmask=self.levmask,
cfl_clipping=self.cfl_clipping,
Expand Down

0 comments on commit 240f985

Please sign in to comment.