Skip to content

Commit

Permalink
Intp fields factory others (#627)
Browse files Browse the repository at this point in the history
Factory implementation of left-over interpolation fields, namely:
- e_flx_avg
- e_bln_c_s
- pos_on_tplane_e_x_y

---------

Co-authored-by: Magdalena Luz <[email protected]>
  • Loading branch information
nfarabullini and halungge authored Jan 28, 2025
1 parent e61ca68 commit 8b6b344
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 89 deletions.
27 changes: 26 additions & 1 deletion model/common/src/icon4py/model/common/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
icon,
)
from icon4py.model.common.states import factory, model, utils as state_utils
from icon4py.model.common.utils import data_allocation as alloc


InputGeometryFieldType: TypeAlias = Literal[attrs.CELL_AREA, attrs.TANGENT_ORIENTATION]
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
"""
self._providers = {}
self._backend = backend
self._xp = alloc.import_array_ns(backend)
self._allocator = gtx.constructors.zeros.partial(allocator=backend)
self._grid = grid
self._decomposition_info = decomposition_info
Expand Down Expand Up @@ -257,7 +259,7 @@ def _register_computed_fields(self):
self.register_provider(coriolis_params)

# normals:
# 1. edges%primal_cart_normal (cartesian coordinates for primal_normal
# 1. edges%primal_cart_normal (cartesian coordinates for primal_normal)
tangent_normal_coordinates = factory.ProgramFieldProvider(
func=stencils.compute_cartesian_coordinates_of_edge_tangent_and_normal,
deps={
Expand All @@ -283,6 +285,7 @@ def _register_computed_fields(self):
},
)
self.register_provider(tangent_normal_coordinates)

# 2. primal_normals: gridfile%zonal_normal_primal_edge - edges%primal_normal%v1, gridfile%meridional_normal_primal_edge - edges%primal_normal%v2,
normal_uv = factory.ProgramFieldProvider(
func=math_helpers.compute_zonal_and_meridional_components_on_edges,
Expand All @@ -306,6 +309,28 @@ def _register_computed_fields(self):
)
self.register_provider(normal_uv)

dual_uv = factory.ProgramFieldProvider(
func=math_helpers.compute_zonal_and_meridional_components_on_edges,
deps={
"lat": attrs.EDGE_LAT,
"lon": attrs.EDGE_LON,
"x": attrs.EDGE_TANGENT_X,
"y": attrs.EDGE_TANGENT_Y,
"z": attrs.EDGE_TANGENT_Z,
},
fields={
"u": attrs.EDGE_DUAL_U,
"v": attrs.EDGE_DUAL_V,
},
domain={
dims.EdgeDim: (
self._edge_domain(h_grid.Zone.LOCAL),
self._edge_domain(h_grid.Zone.END),
)
},
)
self.register_provider(dual_uv)

# 3. primal_normal_vert, primal_normal_cell
normal_vert = factory.ProgramFieldProvider(
func=stencils.compute_zonal_and_meridional_component_of_edge_field_at_vertex,
Expand Down
18 changes: 18 additions & 0 deletions model/common/src/icon4py/model/common/grid/geometry_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
EDGE_NORMAL_Z: Final[str] = "z_component_of_edge_normal_unit_vector"
EDGE_NORMAL_U: Final[str] = "eastward_component_of_edge_normal"
EDGE_NORMAL_V: Final[str] = "northward_component_of_edge_normal"
EDGE_DUAL_U: Final[str] = "eastward_component_of_edge_tangent"
EDGE_DUAL_V: Final[str] = "northward_component_of_edge_tangent"
EDGE_NORMAL_VERTEX_U: Final[str] = "eastward_component_of_edge_normal_on_vertex"
EDGE_NORMAL_VERTEX_V: Final[str] = "northward_component_of_edge_normal_on_vertex"
EDGE_NORMAL_CELL_U: Final[str] = "eastward_component_of_edge_normal_on_cell"
Expand Down Expand Up @@ -327,6 +329,22 @@
icon_var_name="t_grid_vertex%edge_orientation",
dtype=ta.wpfloat,
),
EDGE_DUAL_U: dict(
standard_name=EDGE_DUAL_U,
long_name="eastward component of the dual edge (edge tangent)",
units="", # TODO
dims=(dims.EdgeDim,),
icon_var_name="ptr_patch%edges%dual_normal%v1",
dtype=ta.wpfloat,
),
EDGE_DUAL_V: dict(
standard_name="northward component of the dual edge (edge tangent)",
long_name="ptr_patch%edges%dual_normal_vert_y",
units="", # TODO
dims=(dims.EdgeDim,),
icon_var_name="ptr_patch%edges%dual_normal%v2",
dtype=ta.wpfloat,
),
}


Expand Down
22 changes: 22 additions & 0 deletions model/common/src/icon4py/model/common/grid/geometry_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from types import ModuleType

import numpy as np
from gt4py import next as gtx
from gt4py.next import sin, where

Expand All @@ -19,6 +22,7 @@
normalize_cartesian_vector_on_edges,
zonal_and_meridional_components_on_edges,
)
from icon4py.model.common.utils import data_allocation as alloc


@gtx.field_operator(grid_type=gtx.GridType.UNSTRUCTURED)
Expand Down Expand Up @@ -567,3 +571,21 @@ def compute_coriolis_parameter_on_edges(
out=coriolis_parameter,
domain={dims.EdgeDim: (horizontal_start, horizontal_end)},
)


def compute_primal_cart_normal(
primal_cart_normal_x: alloc.NDArray,
primal_cart_normal_y: alloc.NDArray,
primal_cart_normal_z: alloc.NDArray,
array_ns: ModuleType = np,
) -> alloc.NDArray:
primal_cart_normal = array_ns.transpose(
array_ns.stack(
(
primal_cart_normal_x,
primal_cart_normal_y,
primal_cart_normal_z,
)
)
)
return primal_cart_normal
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

C_LIN_E: Final[str] = "interpolation_coefficient_from_cell_to_edge"
C_BLN_AVG: Final[str] = "bilinear_cell_average_weight"
E_BLN_C_S: Final[str] = "bilinear_edge_cell_weight"
GEOFAC_DIV: Final[str] = "geometrical_factor_for_divergence"
GEOFAC_ROT: Final[str] = "geometrical_factor_for_curl"
GEOFAC_N2S: Final[str] = "geometrical_factor_for_nabla_2_scalar"
GEOFAC_GRDIV: Final[str] = "geometrical_factor_for_gradient_of_divergence"
GEOFAC_GRG_X: Final[str] = "geometrical_factor_for_green_gauss_gradient_x"
GEOFAC_GRG_Y: Final[str] = "geometrical_factor_for_green_gauss_gradient_y"
E_FLX_AVG: Final[str] = "e_flux_average"
POS_ON_TPLANE_E_X: Final[str] = "pos_on_tplane_e_x"
POS_ON_TPLANE_E_Y: Final[str] = "pos_on_tplane_e_y"
CELL_AW_VERTS: Final[str] = "cell_to_vertex_interpolation_factor_by_area_weighting"

attrs: dict[str, model.FieldMetaData] = {
Expand All @@ -39,6 +43,14 @@
icon_var_name="c_lin_e",
dtype=ta.wpfloat,
),
E_BLN_C_S: dict(
standard_name=E_BLN_C_S,
long_name="mass conserving bilinear edge cell weight",
units="", # TODO check or confirm
dims=(dims.CellDim, dims.C2EDim),
icon_var_name="e_bln_c_s",
dtype=ta.wpfloat,
),
GEOFAC_DIV: dict(
standard_name=GEOFAC_DIV,
long_name="geometrical factor for divergence", # TODO (@halungge) find proper description
Expand Down Expand Up @@ -87,6 +99,30 @@
icon_var_name="geofac_grg",
dtype=ta.wpfloat,
),
E_FLX_AVG: dict(
standard_name=E_FLX_AVG,
long_name="e flux average",
units="", # TODO check or confirm
dims=(dims.EdgeDim, dims.E2C2EODim),
icon_var_name="e_flx_avg",
dtype=ta.wpfloat,
),
POS_ON_TPLANE_E_X: dict(
standard_name=POS_ON_TPLANE_E_X,
long_name="position on tplane x",
units="", # TODO check or confirm
dims=(dims.ECDim,),
icon_var_name="pos_on_tplane_e_x",
dtype=ta.wpfloat,
),
POS_ON_TPLANE_E_Y: dict(
standard_name=POS_ON_TPLANE_E_Y,
long_name="position on tplane y",
units="", # TODO check or confirm
dims=(dims.ECDim,),
icon_var_name="pos_on_tplane_e_y",
dtype=ta.wpfloat,
),
CELL_AW_VERTS: dict(
standard_name=CELL_AW_VERTS,
long_name="coefficient for interpolation from cells to verts by area weighting",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import gt4py.next as gtx
from gt4py.next import backend as gtx_backend

from icon4py.model.common import dimension as dims
from icon4py.model.common import constants, dimension as dims
from icon4py.model.common.decomposition import definitions
from icon4py.model.common.grid import (
geometry,
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
self._providers: dict[str, factory.FieldProvider] = {}
self._geometry = geometry_source
# TODO @halungge: Dummy config dict - to be replaced by real configuration
self._config = {"divavg_cntrwgt": 0.5}
self._config = {"divavg_cntrwgt": 0.5, "weighting_factor": 0.0}
self._register_computed_fields()

def __repr__(self):
Expand All @@ -73,6 +73,7 @@ def _register_computed_fields(self):
},
)
self.register_provider(geofac_div)

geofac_rot = factory.FieldOperatorProvider(
# needs to be computed on fieldview-embedded backend
func=interpolation_fields.compute_geofac_rot.with_backend(None),
Expand Down Expand Up @@ -165,6 +166,7 @@ def _register_computed_fields(self):
},
)
self.register_provider(c_lin_e)

geofac_grg = factory.NumpyFieldsProvider(
func=functools.partial(interpolation_fields.compute_geofac_grg, array_ns=self._xp),
fields=(attrs.GEOFAC_GRG_X, attrs.GEOFAC_GRG_Y),
Expand All @@ -185,6 +187,79 @@ def _register_computed_fields(self):
)
self.register_provider(geofac_grg)

e_flx_avg = factory.NumpyFieldsProvider(
func=functools.partial(interpolation_fields.compute_e_flx_avg, array_ns=self._xp),
fields=(attrs.E_FLX_AVG,),
domain=(dims.EdgeDim, dims.E2C2EODim),
deps={
"c_bln_avg": attrs.C_BLN_AVG,
"geofac_div": attrs.GEOFAC_DIV,
"owner_mask": "edge_owner_mask",
"primal_cart_normal_x": geometry_attrs.EDGE_NORMAL_X,
"primal_cart_normal_y": geometry_attrs.EDGE_NORMAL_Y,
"primal_cart_normal_z": geometry_attrs.EDGE_NORMAL_Z,
},
connectivities={
"e2c": dims.E2CDim,
"c2e": dims.C2EDim,
"c2e2c": dims.C2E2CDim,
"e2c2e": dims.E2C2EDim,
},
params={
"horizontal_start_p3": self.grid.start_index(
edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_4)
),
"horizontal_start_p4": self.grid.start_index(
edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_5)
),
},
)
self.register_provider(e_flx_avg)

e_bln_c_s = factory.NumpyFieldsProvider(
func=functools.partial(interpolation_fields.compute_e_bln_c_s, array_ns=self._xp),
fields=(attrs.E_BLN_C_S,),
domain=(dims.CellDim, dims.C2EDim),
deps={
"cells_lat": geometry_attrs.CELL_LAT,
"cells_lon": geometry_attrs.CELL_LON,
"edges_lat": geometry_attrs.EDGE_LAT,
"edges_lon": geometry_attrs.EDGE_LON,
},
connectivities={"c2e": dims.C2EDim},
params={"weighting_factor": self._config["weighting_factor"]},
)
self.register_provider(e_bln_c_s)

pos_on_tplane_e_x_y = factory.NumpyFieldsProvider(
func=functools.partial(
interpolation_fields.compute_pos_on_tplane_e_x_y, array_ns=self._xp
),
fields=(attrs.POS_ON_TPLANE_E_X, attrs.POS_ON_TPLANE_E_Y),
domain=(dims.ECDim,),
deps={
"primal_normal_v1": geometry_attrs.EDGE_NORMAL_U,
"primal_normal_v2": geometry_attrs.EDGE_NORMAL_V,
"dual_normal_v1": geometry_attrs.EDGE_DUAL_U,
"dual_normal_v2": geometry_attrs.EDGE_DUAL_V,
"cells_lon": geometry_attrs.CELL_LON,
"cells_lat": geometry_attrs.CELL_LAT,
"edges_lon": geometry_attrs.EDGE_LON,
"edges_lat": geometry_attrs.EDGE_LAT,
"vertex_lon": geometry_attrs.VERTEX_LON,
"vertex_lat": geometry_attrs.VERTEX_LAT,
"owner_mask": "edge_owner_mask",
},
connectivities={"e2c": dims.E2CDim, "e2v": dims.E2VDim, "e2c2e": dims.E2C2EDim},
params={
"grid_sphere_radius": constants.EARTH_RADIUS,
"horizontal_start": self.grid.start_index(
edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2)
),
},
)
self.register_provider(pos_on_tplane_e_x_y)

cells_aw_verts = factory.NumpyFieldsProvider(
func=functools.partial(interpolation_fields.compute_cells_aw_verts, array_ns=self._xp),
fields=(attrs.CELL_AW_VERTS,),
Expand Down
Loading

0 comments on commit 8b6b344

Please sign in to comment.