Skip to content

Commit

Permalink
diffusion backend edits (#566)
Browse files Browse the repository at this point in the history
* diffusion backend edits
  • Loading branch information
nfarabullini authored Oct 14, 2024
1 parent 489a498 commit b38c748
Show file tree
Hide file tree
Showing 19 changed files with 159 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import gt4py.next as gtx

import icon4py.model.common.states.prognostic_state as prognostics
from gt4py.next import backend

from icon4py.model.atmosphere.diffusion import diffusion_utils, diffusion_states
from icon4py.model.atmosphere.diffusion.diffusion_utils import (
copy_field,
Expand Down Expand Up @@ -348,8 +350,11 @@ class Diffusion:
"""Class that configures diffusion and does one diffusion step."""

def __init__(
self, exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange()
self,
backend: backend.Backend,
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
self._backend = backend
self._exchange = exchange
self._initialized = False
self.rd_o_cvd: float = constants.GAS_CONSTANT_DRY_AIR / (
Expand All @@ -373,6 +378,41 @@ def __init__(
self.cell_params: Optional[geometry.CellParams] = None
self._horizontal_start_index_w_diffusion: gtx.int32 = gtx.int32(0)

self.mo_intp_rbf_rbf_vec_interpol_vertex = mo_intp_rbf_rbf_vec_interpol_vertex.with_backend(
self._backend
)
self.calculate_nabla2_and_smag_coefficients_for_vn = (
calculate_nabla2_and_smag_coefficients_for_vn.with_backend(self._backend)
)
self.calculate_diagnostic_quantities_for_turbulence = (
calculate_diagnostic_quantities_for_turbulence.with_backend(self._backend)
)
self.apply_diffusion_to_vn = apply_diffusion_to_vn.with_backend(self._backend)
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence = (
apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence.with_backend(
self._backend
)
)
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools = (
calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_backend(
self._backend
)
)
self.calculate_nabla2_for_theta = calculate_nabla2_for_theta.with_backend(self._backend)
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points = (
truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_backend(self._backend)
)
self.update_theta_and_exner = update_theta_and_exner.with_backend(self._backend)
self.copy_field = copy_field.with_backend(self._backend)
self.scale_k = scale_k.with_backend(self._backend)
self.setup_fields_for_initial_step = setup_fields_for_initial_step.with_backend(
self._backend
)

self.init_diffusion_local_fields_for_regular_timestep = (
init_diffusion_local_fields_for_regular_timestep.with_backend(self._backend)
)

def init(
self,
grid: icon_grid.IconGrid,
Expand Down Expand Up @@ -419,7 +459,7 @@ def init(
self.smag_offset: float = 0.25 * params.K4 * config.substep_as_float
self.diff_multfac_w: float = min(1.0 / 48.0, params.K4W * config.substep_as_float)

init_diffusion_local_fields_for_regular_timestep(
self.init_diffusion_local_fields_for_regular_timestep(
params.K4,
config.substep_as_float,
*params.smagorinski_factor,
Expand Down Expand Up @@ -454,26 +494,51 @@ def initialized(self):
return self._initialized

def _allocate_temporary_fields(self):
self.diff_multfac_vn = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
self.diff_multfac_n2w = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
self.smag_limit = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
self.enh_smag_fac = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
self.u_vert = field_alloc.allocate_zero_field(dims.VertexDim, dims.KDim, grid=self.grid)
self.v_vert = field_alloc.allocate_zero_field(dims.VertexDim, dims.KDim, grid=self.grid)
self.kh_smag_e = field_alloc.allocate_zero_field(dims.EdgeDim, dims.KDim, grid=self.grid)
self.kh_smag_ec = field_alloc.allocate_zero_field(dims.EdgeDim, dims.KDim, grid=self.grid)
self.z_nabla2_e = field_alloc.allocate_zero_field(dims.EdgeDim, dims.KDim, grid=self.grid)
self.z_temp = field_alloc.allocate_zero_field(dims.CellDim, dims.KDim, grid=self.grid)
self.diff_multfac_smag = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
self.diff_multfac_vn = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)
self.diff_multfac_n2w = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)
self.smag_limit = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)
self.enh_smag_fac = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)
self.u_vert = field_alloc.allocate_zero_field(
dims.VertexDim, dims.KDim, grid=self.grid, backend=self._backend
)
self.v_vert = field_alloc.allocate_zero_field(
dims.VertexDim, dims.KDim, grid=self.grid, backend=self._backend
)
self.kh_smag_e = field_alloc.allocate_zero_field(
dims.EdgeDim, dims.KDim, grid=self.grid, backend=self._backend
)
self.kh_smag_ec = field_alloc.allocate_zero_field(
dims.EdgeDim, dims.KDim, grid=self.grid, backend=self._backend
)
self.z_nabla2_e = field_alloc.allocate_zero_field(
dims.EdgeDim, dims.KDim, grid=self.grid, backend=self._backend
)
self.z_temp = field_alloc.allocate_zero_field(
dims.CellDim, dims.KDim, grid=self.grid, backend=self._backend
)
self.diff_multfac_smag = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)
# TODO(Magdalena): this is KHalfDim
self.vertical_index = field_alloc.allocate_indices(
dims.KDim, grid=self.grid, is_halfdim=True
dims.KDim, grid=self.grid, is_halfdim=True, backend=self._backend
)
self.horizontal_cell_index = field_alloc.allocate_indices(
dims.CellDim, grid=self.grid, backend=self._backend
)
self.horizontal_cell_index = field_alloc.allocate_indices(dims.CellDim, grid=self.grid)
self.horizontal_edge_index = field_alloc.allocate_indices(dims.EdgeDim, grid=self.grid)
self.w_tmp = gtx.as_field(
(dims.CellDim, dims.KDim),
xp.zeros((self.grid.num_cells, self.grid.num_levels + 1), dtype=float),
self.horizontal_edge_index = field_alloc.allocate_indices(
dims.EdgeDim, grid=self.grid, backend=self._backend
)
self.w_tmp = field_alloc.allocate_zero_field(
dims.CellDim, dims.KDim, grid=self.grid, is_halfdim=True, backend=self._backend
)

def _determine_horizontal_domains(self):
Expand Down Expand Up @@ -528,23 +593,22 @@ def initial_run(
This run uses special values for diff_multfac_vn, smag_limit and smag_offset
"""
diff_multfac_vn = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
smag_limit = field_alloc.allocate_zero_field(dims.KDim, grid=self.grid)
diff_multfac_vn = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)
smag_limit = field_alloc.allocate_zero_field(
dims.KDim, grid=self.grid, backend=self._backend
)

setup_fields_for_initial_step(
self.setup_fields_for_initial_step(
self.params.K4,
self.config.hdiff_efdt_ratio,
diff_multfac_vn,
smag_limit,
offset_provider={},
)
self._do_diffusion_step(
diagnostic_state,
prognostic_state,
dtime,
diff_multfac_vn,
smag_limit,
0.0,
diagnostic_state, prognostic_state, dtime, diff_multfac_vn, smag_limit, 0.0
)
self._sync_cell_fields(prognostic_state)

Expand Down Expand Up @@ -608,10 +672,10 @@ def _do_diffusion_step(
"""
num_levels = self.grid.num_levels
# dtime dependent: enh_smag_factor,
scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})
self.scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})

log.debug("rbf interpolation 1: start")
mo_intp_rbf_rbf_vec_interpol_vertex(
self.mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in=prognostic_state.vn,
ptr_coeff_1=self.interpolation_state.rbf_coeff_1,
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
Expand All @@ -631,7 +695,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of vn - end")

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
calculate_nabla2_and_smag_coefficients_for_vn(
self.calculate_nabla2_and_smag_coefficients_for_vn(
diff_multfac_smag=self.diff_multfac_smag,
tangent_orientation=self.edge_params.tangent_orientation,
inv_primal_edge_length=self.edge_params.inverse_primal_edge_lengths,
Expand Down Expand Up @@ -662,7 +726,7 @@ def _do_diffusion_step(
log.debug(
"running stencils 02 03 (calculate_diagnostic_quantities_for_turbulence): start"
)
calculate_diagnostic_quantities_for_turbulence(
self.calculate_diagnostic_quantities_for_turbulence(
kh_smag_ec=self.kh_smag_ec,
vn=prognostic_state.vn,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
Expand All @@ -689,7 +753,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("2nd rbf interpolation: start")
mo_intp_rbf_rbf_vec_interpol_vertex(
self.mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in=self.z_nabla2_e,
ptr_coeff_1=self.interpolation_state.rbf_coeff_1,
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
Expand All @@ -709,7 +773,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
apply_diffusion_to_vn(
self.apply_diffusion_to_vn(
u_vert=self.u_vert,
v_vert=self.v_vert,
primal_normal_vert_v1=self.edge_params.primal_normal_vert[0],
Expand Down Expand Up @@ -741,9 +805,9 @@ def _do_diffusion_step(
"running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): start"
)
# TODO (magdalena) get rid of this copying. So far passing an empty buffer instead did not verify?
copy_field(prognostic_state.w, self.w_tmp, offset_provider={})
self.copy_field(prognostic_state.w, self.w_tmp, offset_provider={})

apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
area=self.cell_params.area,
geofac_n2s=self.interpolation_state.geofac_n2s,
geofac_grg_x=self.interpolation_state.geofac_grg_x,
Expand Down Expand Up @@ -776,7 +840,7 @@ def _do_diffusion_step(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)

calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
theta_v=prognostic_state.theta_v,
theta_ref_mc=self.metric_state.theta_ref_mc,
thresh_tdiff=self.thresh_tdiff,
Expand All @@ -793,7 +857,7 @@ def _do_diffusion_step(
)

log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
calculate_nabla2_for_theta(
self.calculate_nabla2_for_theta(
kh_smag_e=self.kh_smag_e,
inv_dual_edge_length=self.edge_params.inverse_dual_edge_lengths,
theta_v=prognostic_state.theta_v,
Expand All @@ -810,7 +874,7 @@ def _do_diffusion_step(
"running stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): start"
)
if self.config.apply_zdiffusion_t:
truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
mask=self.metric_state.mask_hdiff,
zd_vertoffset=self.metric_state.zd_vertoffset,
zd_diffcoef=self.metric_state.zd_diffcoef,
Expand All @@ -830,7 +894,7 @@ def _do_diffusion_step(
"running fused stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): end"
)
log.debug("running stencil 16 (update_theta_and_exner): start")
update_theta_and_exner(
self.update_theta_and_exner(
z_temp=self.z_temp,
area=self.cell_params.area,
theta_v=prognostic_state.theta_v,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _identity_c_k(field: fa.CellKField[float]) -> fa.CellKField[float]:
return field


@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED, backend=backend)
@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED)
def copy_field(old_f: fa.CellKField[float], new_f: fa.CellKField[float]):
_identity_c_k(old_f, out=new_f)

Expand All @@ -36,7 +36,7 @@ def _scale_k(field: fa.KField[float], factor: float) -> fa.KField[float]:
return field * factor


@gtx.program(backend=backend)
@gtx.program
def scale_k(field: fa.KField[float], factor: float, scaled_field: fa.KField[float]):
_scale_k(field, factor, out=scaled_field)

Expand Down Expand Up @@ -77,7 +77,7 @@ def _setup_fields_for_initial_step(
return diff_multfac_vn, smag_limit


@gtx.program(backend=backend)
@gtx.program
def setup_fields_for_initial_step(
k4: float,
hdiff_efdt_ratio: float,
Expand Down Expand Up @@ -121,7 +121,7 @@ def _init_diffusion_local_fields_for_regular_timestemp(
)


@gtx.program(backend=backend)
@gtx.program
def init_diffusion_local_fields_for_regular_timestep(
k4: float,
dyn_substeps: float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
)
from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla4 import _calculate_nabla4
from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import vpfloat, wpfloat


Expand Down Expand Up @@ -89,7 +88,7 @@ def _apply_diffusion_to_vn(
return vn


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def apply_diffusion_to_vn(
u_vert: fa.VertexKField[vpfloat],
v_vert: fa.VertexKField[vpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from icon4py.model.common import field_type_aliases as fa
from icon4py.model.common.dimension import C2E2CODim, CellDim, KDim
from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import vpfloat, wpfloat


Expand Down Expand Up @@ -76,7 +75,7 @@ def _apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
return w, dwdx, dwdy


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
area: fa.CellField[wpfloat],
geofac_n2s: gtx.Field[gtx.Dims[CellDim, C2E2CODim], wpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
_temporary_fields_for_turbulence_diagnostics,
)
from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import vpfloat, wpfloat


Expand All @@ -36,7 +35,7 @@ def _calculate_diagnostic_quantities_for_turbulence(
return div_ic_vp, hdef_ic_vp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_diagnostic_quantities_for_turbulence(
kh_smag_ec: fa.EdgeKField[vpfloat],
vn: fa.EdgeKField[wpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
_temporary_field_for_grid_point_cold_pools_enhancement,
)
from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import vpfloat, wpfloat


Expand All @@ -38,7 +37,7 @@ def _calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
return kh_smag_e_vp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
theta_v: fa.CellKField[wpfloat],
theta_ref_mc: fa.CellKField[vpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common.dimension import E2C2V, E2ECV
from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import vpfloat, wpfloat


Expand Down Expand Up @@ -147,7 +146,7 @@ def _calculate_nabla2_and_smag_coefficients_for_vn(
return kh_smag_e_vp, astype(kh_smag_ec_wp, vpfloat), z_nabla2_e_wp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_nabla2_and_smag_coefficients_for_vn(
diff_multfac_smag: gtx.Field[gtx.Dims[dims.KDim], vpfloat],
tangent_orientation: fa.EdgeField[wpfloat],
Expand Down
Loading

0 comments on commit b38c748

Please sign in to comment.