Skip to content

Commit

Permalink
[Py2F]: Make dycore run in parallel (#631)
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals authored and muellch committed Jan 22, 2025
1 parent 7bff4ba commit 2a09cbc
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -494,15 +494,13 @@ def __init__(
self._compute_theta_and_exner = compute_theta_and_exner.with_backend(self._backend)
self._compute_exner_from_rhotheta = compute_exner_from_rhotheta.with_backend(self._backend)
self._update_theta_v = update_theta_v.with_backend(self._backend)
self._compute_first_vertical_derivative = (
compute_first_vertical_derivative.with_backend(self._backend)
)
self._interpolate_to_half_levels_vp = (
interpolate_to_half_levels_vp.with_backend(self._backend)
self._compute_first_vertical_derivative = compute_first_vertical_derivative.with_backend(
self._backend
)
self._interpolate_to_surface = (
interpolate_to_surface.with_backend(self._backend)
self._interpolate_to_half_levels_vp = interpolate_to_half_levels_vp.with_backend(
self._backend
)
self._interpolate_to_surface = interpolate_to_surface.with_backend(self._backend)
self._init_two_cell_kdim_fields_with_zero_vp = (
init_two_cell_kdim_fields_with_zero_vp.with_backend(self._backend)
)
Expand Down Expand Up @@ -1029,7 +1027,7 @@ def run_predictor_step(
ntnd=self.ntl1,
cell_areas=self._cell_params.area,
w_now=w_now,
vn_now=vn_now
vn_now=vn_now,
)

# Precompute Rayleigh damping factor
Expand Down Expand Up @@ -1777,9 +1775,7 @@ def run_predictor_step(
offset_provider=self._grid.offset_providers,
)
log.debug("exchanging prognostic field 'w' and local field 'z_dwdz_dd'")
self._exchange.exchange_and_wait(
dims.CellDim, w_new, z_fields.z_dwdz_dd
)
self._exchange.exchange_and_wait(dims.CellDim, w_new, z_fields.z_dwdz_dd)
else:
log.debug("exchanging prognostic field 'w'")
self._exchange.exchange_and_wait(dims.CellDim, w_new)
Expand Down Expand Up @@ -1846,7 +1842,7 @@ def run_corrector_step(
ntnd=self.ntl2,
cell_areas=self._cell_params.area,
w_new=w_new,
vn_new=vn_new
vn_new=vn_new,
)

nvar = nnew
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,46 @@
from icon4py.model.atmosphere.dycore.stencils.add_extra_diffusion_for_normal_wind_tendency_approaching_cfl import (
add_extra_diffusion_for_normal_wind_tendency_approaching_cfl,
)
from icon4py.model.atmosphere.dycore.stencils.correct_contravariant_vertical_velocity import (
correct_contravariant_vertical_velocity,
from icon4py.model.atmosphere.dycore.stencils.add_extra_diffusion_for_w_con_approaching_cfl import (
add_extra_diffusion_for_w_con_approaching_cfl,
)
from icon4py.model.atmosphere.dycore.stencils.copy_cell_kdim_field_to_vp import (
copy_cell_kdim_field_to_vp,
from icon4py.model.atmosphere.dycore.stencils.add_interpolated_horizontal_advection_of_w import (
add_interpolated_horizontal_advection_of_w,
)
from icon4py.model.atmosphere.dycore.stencils.compute_advective_normal_wind_tendency import (
compute_advective_normal_wind_tendency,
)
from icon4py.model.atmosphere.dycore.stencils.compute_advective_vertical_wind_tendency import (
compute_advective_vertical_wind_tendency,
)
from icon4py.model.atmosphere.dycore.stencils.compute_contravariant_correction import (
compute_contravariant_correction,
)
from icon4py.model.atmosphere.dycore.stencils.add_interpolated_horizontal_advection_of_w import (
add_interpolated_horizontal_advection_of_w,
from icon4py.model.atmosphere.dycore.stencils.compute_horizontal_advection_term_for_vertical_velocity import (
compute_horizontal_advection_term_for_vertical_velocity,
)
from icon4py.model.atmosphere.dycore.stencils.compute_horizontal_kinetic_energy import (
compute_horizontal_kinetic_energy,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_to_half_levels_vp import (
interpolate_to_half_levels_vp,
from icon4py.model.atmosphere.dycore.stencils.compute_tangential_wind import compute_tangential_wind
from icon4py.model.atmosphere.dycore.stencils.copy_cell_kdim_field_to_vp import (
copy_cell_kdim_field_to_vp,
)
from icon4py.model.atmosphere.dycore.stencils.correct_contravariant_vertical_velocity import (
correct_contravariant_vertical_velocity,
)
from icon4py.model.atmosphere.dycore.stencils.init_cell_kdim_field_with_zero_vp import (
init_cell_kdim_field_with_zero_vp,
)
from icon4py.model.atmosphere.dycore.stencils.add_extra_diffusion_for_w_con_approaching_cfl import (
add_extra_diffusion_for_w_con_approaching_cfl,
)
from icon4py.model.atmosphere.dycore.stencils.compute_advective_normal_wind_tendency import (
compute_advective_normal_wind_tendency,
)
from icon4py.model.atmosphere.dycore.stencils.compute_horizontal_advection_term_for_vertical_velocity import (
compute_horizontal_advection_term_for_vertical_velocity,
)
from icon4py.model.atmosphere.dycore.stencils.compute_tangential_wind import compute_tangential_wind
from icon4py.model.atmosphere.dycore.stencils.interpolate_contravariant_vertical_velocity_to_full_levels import (
interpolate_contravariant_vertical_velocity_to_full_levels,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_to_cell_center import (
interpolate_to_cell_center,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_to_half_levels_vp import (
interpolate_to_half_levels_vp,
)
from icon4py.model.atmosphere.dycore.stencils.interpolate_vn_to_ie_and_compute_ekin_on_edges import (
interpolate_vn_to_ie_and_compute_ekin_on_edges,
)
Expand Down Expand Up @@ -144,17 +144,15 @@ def __init__(
self._correct_contravariant_vertical_velocity = (
correct_contravariant_vertical_velocity.with_backend(self._backend)
)
self._copy_cell_kdim_field_to_vp = (
copy_cell_kdim_field_to_vp.with_backend(self._backend)
)
self._compute_contravariant_correction = (
compute_contravariant_correction.with_backend(self._backend)
self._copy_cell_kdim_field_to_vp = copy_cell_kdim_field_to_vp.with_backend(self._backend)
self._compute_contravariant_correction = compute_contravariant_correction.with_backend(
self._backend
)
self._compute_horizontal_kinetic_energy = (
compute_horizontal_kinetic_energy.with_backend(self._backend)
self._compute_horizontal_kinetic_energy = compute_horizontal_kinetic_energy.with_backend(
self._backend
)
self._init_cell_kdim_field_with_zero_vp = (
init_cell_kdim_field_with_zero_vp.with_backend(self._backend)
self._init_cell_kdim_field_with_zero_vp = init_cell_kdim_field_with_zero_vp.with_backend(
self._backend
)
self._add_extra_diffusion_for_w_con_approaching_cfl = (
add_extra_diffusion_for_w_con_approaching_cfl.with_backend(self._backend)
Expand Down Expand Up @@ -246,7 +244,6 @@ def run_predictor_step(
cell_areas: fa.CellField[float],
w_now,
vn_now,

):
cfl_w_limit, scalfac_exdiff = self._scale_factors_by_dtime(dtime)

Expand Down Expand Up @@ -565,7 +562,7 @@ def run_corrector_step(
ntnd: int,
cell_areas: fa.CellField[float],
w_new,
vn_new
vn_new,
):
cfl_w_limit, scalfac_exdiff = self._scale_factors_by_dtime(dtime)

Expand Down
95 changes: 62 additions & 33 deletions tools/src/icon4pytools/py2fgen/wrappers/dycore_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from gt4py.next import common as gt4py_common
from icon4py.model.atmosphere.dycore import dycore_states, solve_nonhydro
from icon4py.model.common import dimension as dims, settings
from icon4py.model.common.decomposition import definitions
from icon4py.model.common.dimension import (
C2E2CODim,
C2EDim,
Expand All @@ -57,8 +58,7 @@
from icon4py.model.common.grid import icon
from icon4py.model.common.grid.icon import GlobalGridParams
from icon4py.model.common.grid.vertical import VerticalGrid, VerticalGridConfig
from icon4py.model.common.settings import backend
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.settings import backend, parallel_run
from icon4py.model.common.test_utils.helpers import (
as_1D_sparse_field,
flatten_first_two_dims,
Expand All @@ -67,9 +67,13 @@

from icon4pytools.common.logger import setup_logger
from icon4pytools.py2fgen.wrappers import common as wrapper_common
from icon4pytools.py2fgen.wrappers.debug_utils import print_grid_decomp_info
from icon4pytools.py2fgen.wrappers.wrapper_dimension import (
CellGlobalIndexDim,
CellIndexDim,
EdgeGlobalIndexDim,
EdgeIndexDim,
VertexGlobalIndexDim,
VertexIndexDim,
)

Expand All @@ -78,6 +82,7 @@

dycore_wrapper_state = {
"profiler": cProfile.Profile(),
"exchange_runtime": definitions.ExchangeRuntime,
}


Expand Down Expand Up @@ -341,6 +346,7 @@ def solve_nh_init(
cell_geometry=cell_geometry,
owner_mask=c_owner_mask,
backend=backend,
exchange=dycore_wrapper_state["exchange_runtime"],
)


Expand Down Expand Up @@ -420,22 +426,6 @@ def solve_nh_run(
exner_dyn_incr=exner_dyn_incr,
)

# prognostic_state_nnow = PrognosticState(
# w=w_now,
# vn=vn_now,
# theta_v=theta_v_now,
# rho=rho_now,
# exner=exner_now,
# )
# prognostic_state_nnew = PrognosticState(
# w=w_new,
# vn=vn_new,
# theta_v=theta_v_new,
# rho=rho_new,
# exner=exner_new,
# )
# prognostic_state_ls = [prognostic_state_nnow, prognostic_state_nnew]

# adjust for Fortran indexes
nnow = nnow - 1
nnew = nnew - 1
Expand Down Expand Up @@ -468,21 +458,28 @@ def solve_nh_run(


def grid_init(
cell_starts: gt4py_common.Field[[CellIndexDim], gtx.int32],
cell_ends: gt4py_common.Field[[CellIndexDim], gtx.int32],
vertex_starts: gt4py_common.Field[[VertexIndexDim], gtx.int32],
vertex_ends: gt4py_common.Field[[VertexIndexDim], gtx.int32],
edge_starts: gt4py_common.Field[[EdgeIndexDim], gtx.int32],
edge_ends: gt4py_common.Field[[EdgeIndexDim], gtx.int32],
c2e: gt4py_common.Field[[dims.CellDim, dims.C2EDim], gtx.int32],
e2c: gt4py_common.Field[[dims.EdgeDim, dims.E2CDim], gtx.int32],
c2e2c: gt4py_common.Field[[dims.CellDim, dims.C2E2CDim], gtx.int32],
e2c2e: gt4py_common.Field[[dims.EdgeDim, dims.E2C2EDim], gtx.int32],
e2v: gt4py_common.Field[[dims.EdgeDim, dims.E2VDim], gtx.int32],
v2e: gt4py_common.Field[[dims.VertexDim, dims.V2EDim], gtx.int32],
v2c: gt4py_common.Field[[dims.VertexDim, dims.V2CDim], gtx.int32],
e2c2v: gt4py_common.Field[[dims.EdgeDim, dims.E2C2VDim], gtx.int32],
c2v: gt4py_common.Field[[dims.CellDim, dims.C2VDim], gtx.int32],
cell_starts: gtx.Field[gtx.Dims[CellIndexDim], gtx.int32],
cell_ends: gtx.Field[gtx.Dims[CellIndexDim], gtx.int32],
vertex_starts: gtx.Field[gtx.Dims[VertexIndexDim], gtx.int32],
vertex_ends: gtx.Field[gtx.Dims[VertexIndexDim], gtx.int32],
edge_starts: gtx.Field[gtx.Dims[EdgeIndexDim], gtx.int32],
edge_ends: gtx.Field[gtx.Dims[EdgeIndexDim], gtx.int32],
c2e: gtx.Field[gtx.Dims[dims.CellDim, dims.C2EDim], gtx.int32],
e2c: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], gtx.int32],
c2e2c: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2CDim], gtx.int32],
e2c2e: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2C2EDim], gtx.int32],
e2v: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2VDim], gtx.int32],
v2e: gtx.Field[gtx.Dims[dims.VertexDim, dims.V2EDim], gtx.int32],
v2c: gtx.Field[gtx.Dims[dims.VertexDim, dims.V2CDim], gtx.int32],
e2c2v: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2C2VDim], gtx.int32],
c2v: gtx.Field[gtx.Dims[dims.CellDim, dims.C2VDim], gtx.int32],
c_owner_mask: gtx.Field[[dims.CellDim], bool],
e_owner_mask: gtx.Field[[dims.EdgeDim], bool],
v_owner_mask: gtx.Field[[dims.VertexDim], bool],
c_glb_index: gtx.Field[[CellGlobalIndexDim], gtx.int32],
e_glb_index: gtx.Field[[EdgeGlobalIndexDim], gtx.int32],
v_glb_index: gtx.Field[[VertexGlobalIndexDim], gtx.int32],
comm_id: gtx.int32,
global_root: gtx.int32,
global_level: gtx.int32,
num_vertices: gtx.int32,
Expand Down Expand Up @@ -522,3 +519,35 @@ def grid_init(
e2c2v=e2c2v,
c2v=c2v,
)

if parallel_run:
(
processor_props,
decomposition_info,
exchange_runtime,
) = wrapper_common.construct_decomposition(
c_glb_index,
e_glb_index,
v_glb_index,
c_owner_mask,
e_owner_mask,
v_owner_mask,
num_cells,
num_edges,
num_vertices,
vertical_size,
comm_id,
)
print_grid_decomp_info(
dycore_wrapper_state["grid"],
processor_props,
decomposition_info,
num_cells,
num_edges,
num_vertices,
)
# set exchange runtime to MultiNodeExchange
dycore_wrapper_state["exchange_runtime"] = exchange_runtime
else:
# set exchange runtime to SingleNodeExchange
dycore_wrapper_state["exchange_runtime"] = definitions.SingleNodeExchange()

0 comments on commit 2a09cbc

Please sign in to comment.