Skip to content

Commit

Permalink
DaCe Orchestration for the Diffusion Granule (#514)
Browse files Browse the repository at this point in the history
Co-authored-by: egparedes
  • Loading branch information
kotsaloscv authored Oct 15, 2024
1 parent b38c748 commit 8836483
Show file tree
Hide file tree
Showing 16 changed files with 2,121 additions and 103 deletions.
15 changes: 15 additions & 0 deletions ci/dace.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ test_model_stencils_x86_64:
test_model_stencils_aarch64:
extends: [.test_model_stencils, .test_template_aarch64]

# Shared template for the model datatest jobs: installs the DaCe release named
# by $DACE_VERSION and runs the `run_model_tests` tox environment once per
# combination of the matrix variables below.
.test_model_datatests:
stage: test
script:
- pip install dace==$DACE_VERSION
# $BACKEND, $DACE_ORCHESTRATION and $COMPONENT come from the matrix; $DACE_ORCHESTRATION may expand to nothing.
- tox -r -e run_model_tests -c model/ -- --backend=$BACKEND $DACE_ORCHESTRATION $COMPONENT --verbose
parallel:
matrix:
- COMPONENT: [atmosphere/diffusion/tests/diffusion_tests]
BACKEND: [dace_cpu_noopt]
# Two variants: one passing the orchestration flag, one passing no flag (empty string).
DACE_ORCHESTRATION: ['--dace-orchestration=True', '']
# Concrete datatest jobs: each combines the datatest template with one
# architecture-specific template.
test_model_datatests_x86_64:
extends: [.test_model_datatests, .test_template_x86_64]
test_model_datatests_aarch64:
extends: [.test_model_datatests, .test_template_aarch64]

.benchmark_model_stencils:
stage: benchmark
script:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Final, Optional

import gt4py.next as gtx
from gt4py.next import int32

import icon4py.model.common.states.prognostic_state as prognostics
from gt4py.next import backend
Expand Down Expand Up @@ -64,6 +65,9 @@
from icon4py.model.common.settings import xp
from icon4py.model.common.utils import gt4py_field_allocation as field_alloc

from icon4py.model.common.orchestration import decorator as orchestration


"""
Diffusion module ported from ICON mo_nh_diffusion.f90.
Expand Down Expand Up @@ -356,6 +360,9 @@ def __init__(
):
self._backend = backend
self._exchange = exchange
self.halo_exchange_wait = decomposition.create_halo_exchange_wait(
self._exchange
) # wait on a communication handle
self._initialized = False
self.rd_o_cvd: float = constants.GAS_CONSTANT_DRY_AIR / (
constants.CPD - constants.GAS_CONSTANT_DRY_AIR
Expand Down Expand Up @@ -487,6 +494,10 @@ def init(

self._determine_horizontal_domains()

self.compile_time_connectivities = orchestration.build_compile_time_connectivities(
self.grid.offset_providers
)

self._initialized = True

@property
Expand Down Expand Up @@ -649,6 +660,7 @@ def _sync_cell_fields(self, prognostic_state):
)
log.debug("communication of prognostic cell fields: theta, w, exner - done")

@orchestration.orchestrate
def _do_diffusion_step(
self,
diagnostic_state: diffusion_states.DiffusionDiagnosticState,
Expand All @@ -670,12 +682,15 @@ def _do_diffusion_step(
smag_offset:
"""
num_levels = self.grid.num_levels
# dtime dependent: enh_smag_factor,
self.scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})
self.scale_k.with_connectivities(self.compile_time_connectivities)(
self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={}
)

log.debug("rbf interpolation 1: start")
self.mo_intp_rbf_rbf_vec_interpol_vertex(
self.mo_intp_rbf_rbf_vec_interpol_vertex.with_connectivities(
self.compile_time_connectivities
)(
p_e_in=prognostic_state.vn,
ptr_coeff_1=self.interpolation_state.rbf_coeff_1,
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
Expand All @@ -684,18 +699,20 @@ def _do_diffusion_step(
horizontal_start=self._vertex_start_lateral_boundary_level_2,
horizontal_end=self._vertex_end_local,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug("rbf interpolation 1: end")

# 2. HALO EXCHANGE -- CALL sync_patch_array_mult u_vert and v_vert
log.debug("communication rbf extrapolation of vn - start")
self._exchange.exchange_and_wait(dims.VertexDim, self.u_vert, self.v_vert)
self._exchange(self.u_vert, self.v_vert, dim=dims.VertexDim, wait=True)
log.debug("communication rbf extrapolation of vn - end")

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
self.calculate_nabla2_and_smag_coefficients_for_vn(
self.calculate_nabla2_and_smag_coefficients_for_vn.with_connectivities(
self.compile_time_connectivities
)(
diff_multfac_smag=self.diff_multfac_smag,
tangent_orientation=self.edge_params.tangent_orientation,
inv_primal_edge_length=self.edge_params.inverse_primal_edge_lengths,
Expand All @@ -715,7 +732,7 @@ def _do_diffusion_step(
horizontal_start=self._edge_start_lateral_boundary_level_5,
horizontal_end=self._edge_end_halo_level_2,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug("running stencil 01 (calculate_nabla2_and_smag_coefficients_for_vn): end")
Expand All @@ -726,7 +743,9 @@ def _do_diffusion_step(
log.debug(
"running stencils 02 03 (calculate_diagnostic_quantities_for_turbulence): start"
)
self.calculate_diagnostic_quantities_for_turbulence(
self.calculate_diagnostic_quantities_for_turbulence.with_connectivities(
self.compile_time_connectivities
)(
kh_smag_ec=self.kh_smag_ec,
vn=prognostic_state.vn,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
Expand All @@ -738,7 +757,7 @@ def _do_diffusion_step(
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=1,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug(
Expand All @@ -749,11 +768,13 @@ def _do_diffusion_step(
# TODO (magdalena) move this up and do asynchronous exchange
if self.config.type_vn_diffu > 1:
log.debug("communication rbf extrapolation of z_nable2_e - start")
self._exchange.exchange_and_wait(dims.EdgeDim, self.z_nabla2_e)
self._exchange(self.z_nabla2_e, dim=dims.EdgeDim, wait=True)
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("2nd rbf interpolation: start")
self.mo_intp_rbf_rbf_vec_interpol_vertex(
self.mo_intp_rbf_rbf_vec_interpol_vertex.with_connectivities(
self.compile_time_connectivities
)(
p_e_in=self.z_nabla2_e,
ptr_coeff_1=self.interpolation_state.rbf_coeff_1,
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
Expand All @@ -762,18 +783,18 @@ def _do_diffusion_step(
horizontal_start=self._vertex_start_lateral_boundary_level_2,
horizontal_end=self._vertex_end_local,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug("2nd rbf interpolation: end")

# 6. HALO EXCHANGE -- CALL sync_patch_array_mult (Vertex Fields)
log.debug("communication rbf extrapolation of z_nable2_e - start")
self._exchange.exchange_and_wait(dims.VertexDim, self.u_vert, self.v_vert)
self._exchange(self.u_vert, self.v_vert, dim=dims.VertexDim, wait=True)
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
self.apply_diffusion_to_vn(
self.apply_diffusion_to_vn.with_connectivities(self.compile_time_connectivities)(
u_vert=self.u_vert,
v_vert=self.v_vert,
primal_normal_vert_v1=self.edge_params.primal_normal_vert[0],
Expand All @@ -794,42 +815,49 @@ def _do_diffusion_step(
horizontal_start=self._edge_start_lateral_boundary_level_5,
horizontal_end=self._edge_end_local,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): end")

log.debug("communication of prognistic.vn : start")
handle_edge_comm = self._exchange.exchange(dims.EdgeDim, prognostic_state.vn)
handle_edge_comm = self._exchange(prognostic_state.vn, dim=dims.EdgeDim, wait=False)

log.debug(
"running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): start"
)
# TODO (magdalena) get rid of this copying. So far passing an empty buffer instead did not verify?
self.copy_field(prognostic_state.w, self.w_tmp, offset_provider={})
self.copy_field.with_connectivities(self.compile_time_connectivities)(
prognostic_state.w, self.w_tmp, offset_provider={}
)

self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence.with_connectivities(
self.compile_time_connectivities
)(
area=self.cell_params.area,
geofac_n2s=self.interpolation_state.geofac_n2s,
geofac_grg_x=self.interpolation_state.geofac_grg_x,
geofac_grg_y=self.interpolation_state.geofac_grg_y,
w_old=self.w_tmp,
w=prognostic_state.w,
type_shear=gtx.int32(self.config.shear_type.value),
type_shear=int32(
self.config.shear_type.value
), # DaCe parser peculiarity (does not work as gtx.int32)
dwdx=diagnostic_state.dwdx,
dwdy=diagnostic_state.dwdy,
diff_multfac_w=self.diff_multfac_w,
diff_multfac_n2w=self.diff_multfac_n2w,
k=self.vertical_index,
cell=self.horizontal_cell_index,
nrdmax=gtx.int32(
nrdmax=int32( # DaCe parser peculiarity (does not work as gtx.int32)
self.vertical_grid.end_index_of_damping_layer + 1
), # +1 since Fortran includes boundaries
interior_idx=self._cell_start_interior,
halo_idx=self._cell_end_local,
horizontal_start=self._horizontal_start_index_w_diffusion,
horizontal_end=self._cell_end_halo,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug(
Expand All @@ -840,24 +868,26 @@ def _do_diffusion_step(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)

self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_connectivities(
self.compile_time_connectivities
)(
theta_v=prognostic_state.theta_v,
theta_ref_mc=self.metric_state.theta_ref_mc,
thresh_tdiff=self.thresh_tdiff,
smallest_vpfloat=constants.DBL_EPS,
kh_smag_e=self.kh_smag_e,
horizontal_start=self._edge_start_nudging,
horizontal_end=self._edge_end_halo,
vertical_start=(num_levels - 2),
vertical_end=num_levels,
vertical_start=(self.grid.num_levels - 2),
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug(
"running stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): end"
)

log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
self.calculate_nabla2_for_theta(
self.calculate_nabla2_for_theta.with_connectivities(self.compile_time_connectivities)(
kh_smag_e=self.kh_smag_e,
inv_dual_edge_length=self.edge_params.inverse_dual_edge_lengths,
theta_v=prognostic_state.theta_v,
Expand All @@ -866,15 +896,17 @@ def _do_diffusion_step(
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)
log.debug("running stencils 13_14 (calculate_nabla2_for_theta): end")
log.debug(
"running stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): start"
)
if self.config.apply_zdiffusion_t:
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_connectivities(
self.compile_time_connectivities
)(
mask=self.metric_state.mask_hdiff,
zd_vertoffset=self.metric_state.zd_vertoffset,
zd_diffcoef=self.metric_state.zd_diffcoef,
Expand All @@ -886,15 +918,15 @@ def _do_diffusion_step(
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider=self.grid.offset_providers,
)

log.debug(
"running fused stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): end"
)
log.debug("running stencil 16 (update_theta_and_exner): start")
self.update_theta_and_exner(
self.update_theta_and_exner.with_connectivities(self.compile_time_connectivities)(
z_temp=self.z_temp,
area=self.cell_params.area,
theta_v=prognostic_state.theta_v,
Expand All @@ -903,9 +935,36 @@ def _do_diffusion_step(
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=num_levels,
vertical_end=self.grid.num_levels,
offset_provider={},
)
log.debug("running stencil 16 (update_theta_and_exner): end")
handle_edge_comm.wait() # need to do this here, since we currently only use 1 communication object.

self.halo_exchange_wait(
handle_edge_comm
) # need to do this here, since we currently only use 1 communication object.
log.debug("communication of prognogistic.vn - end")

# TODO (kotsaloscv): Turning this into a cached property is not yet safe; it needs more testing first.
def orchestration_uid(self) -> str:
    """Return the unique id that DaCe Orchestration uses as its cache key.

    The id is derived from the runtime state of this Diffusion object by
    ``orchestration.generate_orchestration_uid``. The attributes collected
    below are handed over as ``members_to_disregard``.
    """
    # Backend and halo-exchange handles.
    infrastructure_members = ["_backend", "_exchange"]

    # Programs invoked from `_do_diffusion_step`.
    diffusion_step_programs = [
        "mo_intp_rbf_rbf_vec_interpol_vertex",
        "calculate_nabla2_and_smag_coefficients_for_vn",
        "calculate_diagnostic_quantities_for_turbulence",
        "apply_diffusion_to_vn",
        "apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence",
        "calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools",
        "calculate_nabla2_for_theta",
        "truly_horizontal_diffusion_nabla_of_theta_over_steep_points",
        "update_theta_and_exner",
        "copy_field",
        "scale_k",
    ]

    # Presumably further program attributes used outside the diffusion step — TODO confirm.
    remaining_members = [
        "setup_fields_for_initial_step",
        "init_diffusion_local_fields_for_regular_timestep",
    ]

    # Concatenate in the original order so the callee sees the same list.
    excluded = [*infrastructure_members, *diffusion_step_programs, *remaining_members]
    return orchestration.generate_orchestration_uid(self, members_to_disregard=excluded)
Loading

0 comments on commit 8836483

Please sign in to comment.