From 8b6b3442779f6eea1d8d24a33ca702f52a570e47 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 28 Jan 2025 08:47:04 +0100 Subject: [PATCH] Intp fields factory others (#627) Factory implementation of left-over interpolation fields, namely: - e_flx_avg - e_bln_c_s - pos_on_tplane_e_x_y --------- Co-authored-by: Magdalena Luz --- .../src/icon4py/model/common/grid/geometry.py | 27 ++- .../model/common/grid/geometry_attributes.py | 18 ++ .../model/common/grid/geometry_stencils.py | 22 +++ .../interpolation/interpolation_attributes.py | 36 ++++ .../interpolation/interpolation_factory.py | 79 +++++++- .../interpolation/interpolation_fields.py | 185 ++++++++++-------- .../icon4py/model/common/states/factory.py | 1 - .../test_interpolation_factory.py | 57 ++++++ .../test_interpolation_fields.py | 8 +- 9 files changed, 344 insertions(+), 89 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/geometry.py b/model/common/src/icon4py/model/common/grid/geometry.py index 210dc8cf4..5a0366cfa 100644 --- a/model/common/src/icon4py/model/common/grid/geometry.py +++ b/model/common/src/icon4py/model/common/grid/geometry.py @@ -29,6 +29,7 @@ icon, ) from icon4py.model.common.states import factory, model, utils as state_utils +from icon4py.model.common.utils import data_allocation as alloc InputGeometryFieldType: TypeAlias = Literal[attrs.CELL_AREA, attrs.TANGENT_ORIENTATION] @@ -94,6 +95,7 @@ def __init__( """ self._providers = {} self._backend = backend + self._xp = alloc.import_array_ns(backend) self._allocator = gtx.constructors.zeros.partial(allocator=backend) self._grid = grid self._decomposition_info = decomposition_info @@ -257,7 +259,7 @@ def _register_computed_fields(self): self.register_provider(coriolis_params) # normals: - # 1. edges%primal_cart_normal (cartesian coordinates for primal_normal + # 1. edges%primal_cart_normal (cartesian coordinates for primal_normal) tangent_normal_coordinates = factory.ProgramFieldProvider( func=stencils.compute_cartesian_coordinates_of_edge_tangent_and_normal, deps={ @@ -283,6 +285,7 @@ def _register_computed_fields(self): }, ) self.register_provider(tangent_normal_coordinates) + # 2. primal_normals: gridfile%zonal_normal_primal_edge - edges%primal_normal%v1, gridfile%meridional_normal_primal_edge - edges%primal_normal%v2, normal_uv = factory.ProgramFieldProvider( func=math_helpers.compute_zonal_and_meridional_components_on_edges, @@ -306,6 +309,28 @@ def _register_computed_fields(self): ) self.register_provider(normal_uv) + dual_uv = factory.ProgramFieldProvider( + func=math_helpers.compute_zonal_and_meridional_components_on_edges, + deps={ + "lat": attrs.EDGE_LAT, + "lon": attrs.EDGE_LON, + "x": attrs.EDGE_TANGENT_X, + "y": attrs.EDGE_TANGENT_Y, + "z": attrs.EDGE_TANGENT_Z, + }, + fields={ + "u": attrs.EDGE_DUAL_U, + "v": attrs.EDGE_DUAL_V, + }, + domain={ + dims.EdgeDim: ( + self._edge_domain(h_grid.Zone.LOCAL), + self._edge_domain(h_grid.Zone.END), + ) + }, + ) + self.register_provider(dual_uv) + # 3. primal_normal_vert, primal_normal_cell normal_vert = factory.ProgramFieldProvider( func=stencils.compute_zonal_and_meridional_component_of_edge_field_at_vertex, diff --git a/model/common/src/icon4py/model/common/grid/geometry_attributes.py b/model/common/src/icon4py/model/common/grid/geometry_attributes.py index c83865ab7..00479b0f5 100644 --- a/model/common/src/icon4py/model/common/grid/geometry_attributes.py +++ b/model/common/src/icon4py/model/common/grid/geometry_attributes.py @@ -56,6 +56,8 @@ EDGE_NORMAL_Z: Final[str] = "z_component_of_edge_normal_unit_vector" EDGE_NORMAL_U: Final[str] = "eastward_component_of_edge_normal" EDGE_NORMAL_V: Final[str] = "northward_component_of_edge_normal" +EDGE_DUAL_U: Final[str] = "eastward_component_of_edge_tangent" +EDGE_DUAL_V: Final[str] = "northward_component_of_edge_tangent" EDGE_NORMAL_VERTEX_U: Final[str] = "eastward_component_of_edge_normal_on_vertex" EDGE_NORMAL_VERTEX_V: Final[str] = "northward_component_of_edge_normal_on_vertex" EDGE_NORMAL_CELL_U: Final[str] = "eastward_component_of_edge_normal_on_cell" @@ -327,6 +329,22 @@ icon_var_name="t_grid_vertex%edge_orientation", dtype=ta.wpfloat, ), + EDGE_DUAL_U: dict( + standard_name=EDGE_DUAL_U, + long_name="eastward component of the dual edge (edge tangent)", + units="", # TODO + dims=(dims.EdgeDim,), + icon_var_name="ptr_patch%edges%dual_normal%v1", + dtype=ta.wpfloat, + ), + EDGE_DUAL_V: dict( + standard_name="northward component of the dual edge (edge tangent)", + long_name="ptr_patch%edges%dual_normal_vert_y", + units="", # TODO + dims=(dims.EdgeDim,), + icon_var_name="ptr_patch%edges%dual_normal%v2", + dtype=ta.wpfloat, + ), } diff --git a/model/common/src/icon4py/model/common/grid/geometry_stencils.py b/model/common/src/icon4py/model/common/grid/geometry_stencils.py index 670919455..edfb22416 100644 --- a/model/common/src/icon4py/model/common/grid/geometry_stencils.py +++ b/model/common/src/icon4py/model/common/grid/geometry_stencils.py @@ -6,6 +6,9 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from types import ModuleType + +import numpy as np from gt4py import next as gtx from gt4py.next import sin, where @@ -19,6 +22,7 @@ normalize_cartesian_vector_on_edges, zonal_and_meridional_components_on_edges, ) +from icon4py.model.common.utils import data_allocation as alloc @gtx.field_operator(grid_type=gtx.GridType.UNSTRUCTURED) @@ -567,3 +571,21 @@ def compute_coriolis_parameter_on_edges( out=coriolis_parameter, domain={dims.EdgeDim: (horizontal_start, horizontal_end)}, ) + + +def compute_primal_cart_normal( + primal_cart_normal_x: alloc.NDArray, + primal_cart_normal_y: alloc.NDArray, + primal_cart_normal_z: alloc.NDArray, + array_ns: ModuleType = np, +) -> alloc.NDArray: + primal_cart_normal = array_ns.transpose( + array_ns.stack( + ( + primal_cart_normal_x, + primal_cart_normal_y, + primal_cart_normal_z, + ) + ) + ) + return primal_cart_normal diff --git a/model/common/src/icon4py/model/common/interpolation/interpolation_attributes.py b/model/common/src/icon4py/model/common/interpolation/interpolation_attributes.py index 3c8fa4fa8..862f47815 100644 --- a/model/common/src/icon4py/model/common/interpolation/interpolation_attributes.py +++ b/model/common/src/icon4py/model/common/interpolation/interpolation_attributes.py @@ -14,12 +14,16 @@ C_LIN_E: Final[str] = "interpolation_coefficient_from_cell_to_edge" C_BLN_AVG: Final[str] = "bilinear_cell_average_weight" +E_BLN_C_S: Final[str] = "bilinear_edge_cell_weight" GEOFAC_DIV: Final[str] = "geometrical_factor_for_divergence" GEOFAC_ROT: Final[str] = "geometrical_factor_for_curl" GEOFAC_N2S: Final[str] = "geometrical_factor_for_nabla_2_scalar" GEOFAC_GRDIV: Final[str] = "geometrical_factor_for_gradient_of_divergence" GEOFAC_GRG_X: Final[str] = "geometrical_factor_for_green_gauss_gradient_x" GEOFAC_GRG_Y: Final[str] = "geometrical_factor_for_green_gauss_gradient_y" +E_FLX_AVG: Final[str] = "e_flux_average" +POS_ON_TPLANE_E_X: Final[str] = "pos_on_tplane_e_x" +POS_ON_TPLANE_E_Y: Final[str] = "pos_on_tplane_e_y" CELL_AW_VERTS: Final[str] = "cell_to_vertex_interpolation_factor_by_area_weighting" attrs: dict[str, model.FieldMetaData] = { @@ -39,6 +43,14 @@ icon_var_name="c_lin_e", dtype=ta.wpfloat, ), + E_BLN_C_S: dict( + standard_name=E_BLN_C_S, + long_name="mass conserving bilinear edge cell weight", + units="", # TODO check or confirm + dims=(dims.CellDim, dims.C2EDim), + icon_var_name="e_bln_c_s", + dtype=ta.wpfloat, + ), GEOFAC_DIV: dict( standard_name=GEOFAC_DIV, long_name="geometrical factor for divergence", # TODO (@halungge) find proper description @@ -87,6 +99,30 @@ icon_var_name="geofac_grg", dtype=ta.wpfloat, ), + E_FLX_AVG: dict( + standard_name=E_FLX_AVG, + long_name="e flux average", + units="", # TODO check or confirm + dims=(dims.EdgeDim, dims.E2C2EODim), + icon_var_name="e_flx_avg", + dtype=ta.wpfloat, + ), + POS_ON_TPLANE_E_X: dict( + standard_name=POS_ON_TPLANE_E_X, + long_name="position on tplane x", + units="", # TODO check or confirm + dims=(dims.ECDim,), + icon_var_name="pos_on_tplane_e_x", + dtype=ta.wpfloat, + ), + POS_ON_TPLANE_E_Y: dict( + standard_name=POS_ON_TPLANE_E_Y, + long_name="position on tplane y", + units="", # TODO check or confirm + dims=(dims.ECDim,), + icon_var_name="pos_on_tplane_e_y", + dtype=ta.wpfloat, + ), CELL_AW_VERTS: dict( standard_name=CELL_AW_VERTS, long_name="coefficient for interpolation from cells to verts by area weighting", diff --git a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py index 28aa65241..306c63e89 100644 --- a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py +++ b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next import backend as gtx_backend -from icon4py.model.common import dimension as dims +from icon4py.model.common import constants, dimension as dims from icon4py.model.common.decomposition import definitions from icon4py.model.common.grid import ( geometry, @@ -50,7 +50,7 @@ def __init__( self._providers: dict[str, factory.FieldProvider] = {} self._geometry = geometry_source # TODO @halungge: Dummy config dict - to be replaced by real configuration - self._config = {"divavg_cntrwgt": 0.5} + self._config = {"divavg_cntrwgt": 0.5, "weighting_factor": 0.0} self._register_computed_fields() def __repr__(self): @@ -73,6 +73,7 @@ def _register_computed_fields(self): }, ) self.register_provider(geofac_div) + geofac_rot = factory.FieldOperatorProvider( # needs to be computed on fieldview-embedded backend func=interpolation_fields.compute_geofac_rot.with_backend(None), @@ -165,6 +166,7 @@ def _register_computed_fields(self): }, ) self.register_provider(c_lin_e) + geofac_grg = factory.NumpyFieldsProvider( func=functools.partial(interpolation_fields.compute_geofac_grg, array_ns=self._xp), fields=(attrs.GEOFAC_GRG_X, attrs.GEOFAC_GRG_Y), @@ -185,6 +187,79 @@ def _register_computed_fields(self): ) self.register_provider(geofac_grg) + e_flx_avg = factory.NumpyFieldsProvider( + func=functools.partial(interpolation_fields.compute_e_flx_avg, array_ns=self._xp), + fields=(attrs.E_FLX_AVG,), + domain=(dims.EdgeDim, dims.E2C2EODim), + deps={ + "c_bln_avg": attrs.C_BLN_AVG, + "geofac_div": attrs.GEOFAC_DIV, + "owner_mask": "edge_owner_mask", + "primal_cart_normal_x": geometry_attrs.EDGE_NORMAL_X, + "primal_cart_normal_y": geometry_attrs.EDGE_NORMAL_Y, + "primal_cart_normal_z": geometry_attrs.EDGE_NORMAL_Z, + }, + connectivities={ + "e2c": dims.E2CDim, + "c2e": dims.C2EDim, + "c2e2c": dims.C2E2CDim, + "e2c2e": dims.E2C2EDim, + }, + params={ + "horizontal_start_p3": self.grid.start_index( + edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_4) + ), + "horizontal_start_p4": self.grid.start_index( + edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_5) + ), + }, + ) + self.register_provider(e_flx_avg) + + e_bln_c_s = factory.NumpyFieldsProvider( + func=functools.partial(interpolation_fields.compute_e_bln_c_s, array_ns=self._xp), + fields=(attrs.E_BLN_C_S,), + domain=(dims.CellDim, dims.C2EDim), + deps={ + "cells_lat": geometry_attrs.CELL_LAT, + "cells_lon": geometry_attrs.CELL_LON, + "edges_lat": geometry_attrs.EDGE_LAT, + "edges_lon": geometry_attrs.EDGE_LON, + }, + connectivities={"c2e": dims.C2EDim}, + params={"weighting_factor": self._config["weighting_factor"]}, + ) + self.register_provider(e_bln_c_s) + + pos_on_tplane_e_x_y = factory.NumpyFieldsProvider( + func=functools.partial( + interpolation_fields.compute_pos_on_tplane_e_x_y, array_ns=self._xp + ), + fields=(attrs.POS_ON_TPLANE_E_X, attrs.POS_ON_TPLANE_E_Y), + domain=(dims.ECDim,), + deps={ + "primal_normal_v1": geometry_attrs.EDGE_NORMAL_U, + "primal_normal_v2": geometry_attrs.EDGE_NORMAL_V, + "dual_normal_v1": geometry_attrs.EDGE_DUAL_U, + "dual_normal_v2": geometry_attrs.EDGE_DUAL_V, + "cells_lon": geometry_attrs.CELL_LON, + "cells_lat": geometry_attrs.CELL_LAT, + "edges_lon": geometry_attrs.EDGE_LON, + "edges_lat": geometry_attrs.EDGE_LAT, + "vertex_lon": geometry_attrs.VERTEX_LON, + "vertex_lat": geometry_attrs.VERTEX_LAT, + "owner_mask": "edge_owner_mask", + }, + connectivities={"e2c": dims.E2CDim, "e2v": dims.E2VDim, "e2c2e": dims.E2C2EDim}, + params={ + "grid_sphere_radius": constants.EARTH_RADIUS, + "horizontal_start": self.grid.start_index( + edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2) + ), + }, + ) + self.register_provider(pos_on_tplane_e_x_y) + cells_aw_verts = factory.NumpyFieldsProvider( func=functools.partial(interpolation_fields.compute_cells_aw_verts, array_ns=self._xp), fields=(attrs.CELL_AW_VERTS,), diff --git a/model/common/src/icon4py/model/common/interpolation/interpolation_fields.py b/model/common/src/icon4py/model/common/interpolation/interpolation_fields.py index 3d0c85113..5026f9aab 100644 --- a/model/common/src/icon4py/model/common/interpolation/interpolation_fields.py +++ b/model/common/src/icon4py/model/common/interpolation/interpolation_fields.py @@ -19,6 +19,7 @@ from icon4py.model.common import dimension as dims from icon4py.model.common.dimension import C2E, V2E from icon4py.model.common.grid import grid_manager as gm +from icon4py.model.common.grid.geometry_stencils import compute_primal_cart_normal from icon4py.model.common.utils import data_allocation as data_alloc @@ -212,8 +213,12 @@ def _compute_geofac_grg( inverse_neighbor = create_inverse_neighbor_index(e2c, c2e, array_ns) tmp = geofac_div * c_lin_e[c2e, inverse_neighbor] - geofac_grg_x[horizontal_start:, 0] = np.sum(primal_normal_ec_u * tmp, axis=1)[horizontal_start:] - geofac_grg_y[horizontal_start:, 0] = np.sum(primal_normal_ec_v * tmp, axis=1)[horizontal_start:] + geofac_grg_x[horizontal_start:, 0] = array_ns.sum(primal_normal_ec_u * tmp, axis=1)[ + horizontal_start: + ] + geofac_grg_y[horizontal_start:, 0] = array_ns.sum(primal_normal_ec_v * tmp, axis=1)[ + horizontal_start: + ] for k in range(e2c.shape[1]): mask = (e2c[c2e, k] == c2e2c)[horizontal_start:, :] @@ -554,7 +559,7 @@ def _apply_correction( local_weight = array_ns.sum(c_bln_avg, axis=1) - 1.0 c_bln_avg[horizontal_start:, :] = c_bln_avg[horizontal_start:, :] - ( - 0.25 * local_weight[horizontal_start:, np.newaxis] + 0.25 * local_weight[horizontal_start:, array_ns.newaxis] ) # avoid runaway condition: @@ -672,17 +677,20 @@ def create_inverse_neighbor_index(source_offset, inverse_offset, array_ns: Modul # TODO (@halungge) this can be simplified using only def compute_e_flx_avg( - c_bln_avg: np.ndarray, - geofac_div: np.ndarray, - owner_mask: np.ndarray, - primal_cart_normal: np.ndarray, - e2c: np.ndarray, - c2e: np.ndarray, - c2e2c: np.ndarray, - e2c2e: np.ndarray, + c_bln_avg: data_alloc.NDArray, + geofac_div: data_alloc.NDArray, + owner_mask: data_alloc.NDArray, + primal_cart_normal_x: data_alloc.NDArray, + primal_cart_normal_y: data_alloc.NDArray, + primal_cart_normal_z: data_alloc.NDArray, + e2c: data_alloc.NDArray, + c2e: data_alloc.NDArray, + c2e2c: data_alloc.NDArray, + e2c2e: data_alloc.NDArray, horizontal_start_p3: np.int32, horizontal_start_p4: np.int32, -) -> np.ndarray: + array_ns: ModuleType = np, +) -> data_alloc.NDArray: """ Compute edge flux average @@ -701,72 +709,84 @@ def compute_e_flx_avg( Returns: e_flx_avg: numpy array, representing a gtx.Field[gtx.Dims[EdgeDim, E2C2EODim], ta.wpfloat] """ + primal_cart_normal = compute_primal_cart_normal( + primal_cart_normal_x, + primal_cart_normal_y, + primal_cart_normal_z, + array_ns=array_ns, + ) + llb = 0 - e_flx_avg = np.zeros([e2c.shape[0], 5]) - index = np.arange(llb, c2e.shape[0]) - inv_neighbor_id = -np.ones([c2e.shape[0] - llb, 3], dtype=int) + e_flx_avg = array_ns.zeros([e2c.shape[0], 5]) + index = array_ns.arange(llb, c2e.shape[0]) + inv_neighbor_id = -array_ns.ones([c2e.shape[0] - llb, 3], dtype=int) for i in range(c2e2c.shape[1]): for j in range(c2e2c.shape[1]): - inv_neighbor_id[:, j] = np.where( - np.logical_and(c2e2c[c2e2c[llb:, j], i] == index, c2e2c[llb:, j] >= 0), + inv_neighbor_id[:, j] = array_ns.where( + array_ns.logical_and(c2e2c[c2e2c[llb:, j], i] == index, c2e2c[llb:, j] >= 0), i, inv_neighbor_id[:, j], ) llb = horizontal_start_p3 - index = np.arange(llb, e2c.shape[0]) + index = array_ns.arange(llb, e2c.shape[0]) for j in range(c2e.shape[1]): for i in range(2): - e_flx_avg[llb:, i + 1] = np.where( + e_flx_avg[llb:, i + 1] = array_ns.where( owner_mask[llb:], - np.where( + array_ns.where( c2e[e2c[llb:, 0], j] == index, c_bln_avg[e2c[llb:, 1], inv_neighbor_id[e2c[llb:, 0], j] + 1] - * geofac_div[e2c[llb:, 0], np.mod(i + j + 1, 3)] + * geofac_div[e2c[llb:, 0], array_ns.mod(i + j + 1, 3)] / geofac_div[e2c[llb:, 1], inv_neighbor_id[e2c[llb:, 0], j]], e_flx_avg[llb:, i + 1], ), e_flx_avg[llb:, i + 1], ) - e_flx_avg[llb:, i + 3] = np.where( + e_flx_avg[llb:, i + 3] = array_ns.where( owner_mask[llb:], - np.where( + array_ns.where( c2e[e2c[llb:, 0], j] == index, c_bln_avg[e2c[llb:, 0], 1 + j] - * geofac_div[e2c[llb:, 1], np.mod(inv_neighbor_id[e2c[llb:, 0], j] + i + 1, 3)] + * geofac_div[ + e2c[llb:, 1], array_ns.mod(inv_neighbor_id[e2c[llb:, 0], j] + i + 1, 3) + ] / geofac_div[e2c[llb:, 0], j], e_flx_avg[llb:, i + 3], ), e_flx_avg[llb:, i + 3], ) - iie = -np.ones([e2c.shape[0], 4], dtype=int) - iie[:, 0] = np.where(e2c[e2c2e[:, 0], 0] == e2c[:, 0], 2, -1) - iie[:, 0] = np.where( - np.logical_and(e2c[e2c2e[:, 0], 1] == e2c[:, 0], iie[:, 0] != 2), 4, iie[:, 0] + iie = -array_ns.ones([e2c.shape[0], 4], dtype=int) + iie[:, 0] = array_ns.where(e2c[e2c2e[:, 0], 0] == e2c[:, 0], 2, -1) + iie[:, 0] = array_ns.where( + array_ns.logical_and(e2c[e2c2e[:, 0], 1] == e2c[:, 0], iie[:, 0] != 2), 4, iie[:, 0] ) - iie[:, 1] = np.where(e2c[e2c2e[:, 1], 0] == e2c[:, 0], 1, -1) - iie[:, 1] = np.where( - np.logical_and(e2c[e2c2e[:, 1], 1] == e2c[:, 0], iie[:, 1] != 1), 3, iie[:, 1] + iie[:, 1] = array_ns.where(e2c[e2c2e[:, 1], 0] == e2c[:, 0], 1, -1) + iie[:, 1] = array_ns.where( + array_ns.logical_and(e2c[e2c2e[:, 1], 1] == e2c[:, 0], iie[:, 1] != 1), 3, iie[:, 1] ) - iie[:, 2] = np.where(e2c[e2c2e[:, 2], 0] == e2c[:, 1], 2, -1) - iie[:, 2] = np.where( - np.logical_and(e2c[e2c2e[:, 2], 1] == e2c[:, 1], iie[:, 2] != 2), 4, iie[:, 2] + iie[:, 2] = array_ns.where(e2c[e2c2e[:, 2], 0] == e2c[:, 1], 2, -1) + iie[:, 2] = array_ns.where( + array_ns.logical_and(e2c[e2c2e[:, 2], 1] == e2c[:, 1], iie[:, 2] != 2), 4, iie[:, 2] ) - iie[:, 3] = np.where(e2c[e2c2e[:, 3], 0] == e2c[:, 1], 1, -1) - iie[:, 3] = np.where( - np.logical_and(e2c[e2c2e[:, 3], 1] == e2c[:, 1], iie[:, 3] != 1), 3, iie[:, 3] + iie[:, 3] = array_ns.where(e2c[e2c2e[:, 3], 0] == e2c[:, 1], 1, -1) + iie[:, 3] = array_ns.where( + array_ns.logical_and(e2c[e2c2e[:, 3], 1] == e2c[:, 1], iie[:, 3] != 1), 3, iie[:, 3] ) llb = horizontal_start_p4 - index = np.arange(llb, e2c.shape[0]) + index = array_ns.arange(llb, e2c.shape[0]) for i in range(c2e.shape[1]): - e_flx_avg[llb:, 0] = np.where( + # INVALID_INDEX + if i <= gm.GridFile.INVALID_INDEX: + continue + e_flx_avg[llb:, 0] = array_ns.where( owner_mask[llb:], - np.where( + array_ns.where( c2e[e2c[llb:, 0], i] == index, 0.5 * ( @@ -775,9 +795,9 @@ def compute_e_flx_avg( + geofac_div[e2c[llb:, 1], inv_neighbor_id[e2c[llb:, 0], i]] * c_bln_avg[e2c[llb:, 0], i + 1] - e_flx_avg[e2c2e[llb:, 0], iie[llb:, 0]] - * geofac_div[e2c[llb:, 0], np.mod(i + 1, 3)] + * geofac_div[e2c[llb:, 0], array_ns.mod(i + 1, 3)] - e_flx_avg[e2c2e[llb:, 1], iie[llb:, 1]] - * geofac_div[e2c[llb:, 0], np.mod(i + 2, 3)] + * geofac_div[e2c[llb:, 0], array_ns.mod(i + 2, 3)] ) / geofac_div[e2c[llb:, 0], i] + ( @@ -786,9 +806,13 @@ def compute_e_flx_avg( + geofac_div[e2c[llb:, 0], i] * c_bln_avg[e2c[llb:, 1], inv_neighbor_id[e2c[llb:, 0], i] + 1] - e_flx_avg[e2c2e[llb:, 2], iie[llb:, 2]] - * geofac_div[e2c[llb:, 1], np.mod(inv_neighbor_id[e2c[llb:, 0], i] + 1, 3)] + * geofac_div[ + e2c[llb:, 1], array_ns.mod(inv_neighbor_id[e2c[llb:, 0], i] + 1, 3) + ] - e_flx_avg[e2c2e[llb:, 3], iie[llb:, 3]] - * geofac_div[e2c[llb:, 1], np.mod(inv_neighbor_id[e2c[llb:, 0], i] + 2, 3)] + * geofac_div[ + e2c[llb:, 1], array_ns.mod(inv_neighbor_id[e2c[llb:, 0], i] + 2, 3) + ] ) / geofac_div[e2c[llb:, 1], inv_neighbor_id[e2c[llb:, 0], i]] ), @@ -801,12 +825,12 @@ def compute_e_flx_avg( for i in range(4): checksum = ( checksum - + np.sum(primal_cart_normal * primal_cart_normal[e2c2e[:, i], :], axis=1) + + array_ns.sum(primal_cart_normal * primal_cart_normal[e2c2e[:, i], :], axis=1) * e_flx_avg[:, 1 + i] ) for i in range(5): - e_flx_avg[llb:, i] = np.where( + e_flx_avg[llb:, i] = array_ns.where( owner_mask[llb:], e_flx_avg[llb:, i] / checksum[llb:], e_flx_avg[llb:, i] ) @@ -931,22 +955,23 @@ def compute_e_bln_c_s( def compute_pos_on_tplane_e_x_y( grid_sphere_radius: ta.wpfloat, - primal_normal_v1: np.ndarray, - primal_normal_v2: np.ndarray, - dual_normal_v1: np.ndarray, - dual_normal_v2: np.ndarray, - cells_lon: np.ndarray, - cells_lat: np.ndarray, - edges_lon: np.ndarray, - edges_lat: np.ndarray, - vertex_lon: np.ndarray, - vertex_lat: np.ndarray, - owner_mask: np.ndarray, - e2c: np.ndarray, - e2v: np.ndarray, - e2c2e: np.ndarray, + primal_normal_v1: data_alloc.NDArray, + primal_normal_v2: data_alloc.NDArray, + dual_normal_v1: data_alloc.NDArray, + dual_normal_v2: data_alloc.NDArray, + cells_lon: data_alloc.NDArray, + cells_lat: data_alloc.NDArray, + edges_lon: data_alloc.NDArray, + edges_lat: data_alloc.NDArray, + vertex_lon: data_alloc.NDArray, + vertex_lat: data_alloc.NDArray, + owner_mask: data_alloc.NDArray, + e2c: data_alloc.NDArray, + e2v: data_alloc.NDArray, + e2c2e: data_alloc.NDArray, horizontal_start: np.int32, -) -> np.ndarray: + array_ns: ModuleType = np, +) -> data_alloc.NDArray: """ Compute pos_on_tplane_e_x_y. get geographical coordinates of edge midpoint @@ -979,9 +1004,9 @@ def compute_pos_on_tplane_e_x_y( pos_on_tplane_e_y: // """ llb = horizontal_start - pos_on_tplane_e = np.zeros([e2c.shape[0], 8, 2]) - xyloc_plane_n1 = np.zeros([2, e2c.shape[0]]) - xyloc_plane_n2 = np.zeros([2, e2c.shape[0]]) + pos_on_tplane_e = array_ns.zeros([e2c.shape[0], 8, 2]) + xyloc_plane_n1 = array_ns.zeros([2, e2c.shape[0]]) + xyloc_plane_n2 = array_ns.zeros([2, e2c.shape[0]]) xyloc_plane_n1[0, llb:], xyloc_plane_n1[1, llb:] = proj.gnomonic_proj( edges_lon[llb:], edges_lat[llb:], cells_lon[e2c[llb:, 0]], cells_lat[e2c[llb:, 0]] ) @@ -989,8 +1014,8 @@ def compute_pos_on_tplane_e_x_y( edges_lon[llb:], edges_lat[llb:], cells_lon[e2c[llb:, 1]], cells_lat[e2c[llb:, 1]] ) - xyloc_quad = np.zeros([4, 2, e2c.shape[0]]) - xyloc_plane_quad = np.zeros([4, 2, e2c.shape[0]]) + xyloc_quad = array_ns.zeros([4, 2, e2c.shape[0]]) + xyloc_plane_quad = array_ns.zeros([4, 2, e2c.shape[0]]) for ne in range(4): xyloc_quad[ne, 0, llb:] = edges_lon[e2c2e[llb:, ne]] xyloc_quad[ne, 1, llb:] = edges_lat[e2c2e[llb:, ne]] @@ -998,8 +1023,8 @@ def compute_pos_on_tplane_e_x_y( edges_lon[llb:], edges_lat[llb:], xyloc_quad[ne, 0, llb:], xyloc_quad[ne, 1, llb:] ) - xyloc_ve = np.zeros([2, 2, e2c.shape[0]]) - xyloc_plane_ve = np.zeros([2, 2, e2c.shape[0]]) + xyloc_ve = array_ns.zeros([2, 2, e2c.shape[0]]) + xyloc_plane_ve = array_ns.zeros([2, 2, e2c.shape[0]]) for nv in range(2): xyloc_ve[nv, 0, llb:] = vertex_lon[e2v[llb:, nv]] xyloc_ve[nv, 1, llb:] = vertex_lat[e2v[llb:, nv]] @@ -1007,7 +1032,7 @@ def compute_pos_on_tplane_e_x_y( edges_lon[llb:], edges_lat[llb:], xyloc_ve[nv, 0, llb:], xyloc_ve[nv, 1, llb:] ) - pos_on_tplane_e[llb:, 0, 0] = np.where( + pos_on_tplane_e[llb:, 0, 0] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1016,7 +1041,7 @@ def compute_pos_on_tplane_e_x_y( ), pos_on_tplane_e[llb:, 0, 0], ) - pos_on_tplane_e[llb:, 0, 1] = np.where( + pos_on_tplane_e[llb:, 0, 1] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1025,7 +1050,7 @@ def compute_pos_on_tplane_e_x_y( ), pos_on_tplane_e[llb:, 0, 1], ) - pos_on_tplane_e[llb:, 1, 0] = np.where( + pos_on_tplane_e[llb:, 1, 0] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1034,7 +1059,7 @@ def compute_pos_on_tplane_e_x_y( ), pos_on_tplane_e[llb:, 1, 0], ) - pos_on_tplane_e[llb:, 1, 1] = np.where( + pos_on_tplane_e[llb:, 1, 1] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1045,7 +1070,7 @@ def compute_pos_on_tplane_e_x_y( ) for ne in range(4): - pos_on_tplane_e[llb:, 2 + ne, 0] = np.where( + pos_on_tplane_e[llb:, 2 + ne, 0] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1054,7 +1079,7 @@ def compute_pos_on_tplane_e_x_y( ), pos_on_tplane_e[llb:, 2 + ne, 0], ) - pos_on_tplane_e[llb:, 2 + ne, 1] = np.where( + pos_on_tplane_e[llb:, 2 + ne, 1] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1065,7 +1090,7 @@ def compute_pos_on_tplane_e_x_y( ) for nv in range(2): - pos_on_tplane_e[llb:, 6 + nv, 0] = np.where( + pos_on_tplane_e[llb:, 6 + nv, 0] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1074,7 +1099,7 @@ def compute_pos_on_tplane_e_x_y( ), pos_on_tplane_e[llb:, 6 + nv, 0], ) - pos_on_tplane_e[llb:, 6 + nv, 1] = np.where( + pos_on_tplane_e[llb:, 6 + nv, 1] = array_ns.where( owner_mask[llb:], grid_sphere_radius * ( @@ -1084,10 +1109,10 @@ def compute_pos_on_tplane_e_x_y( pos_on_tplane_e[llb:, 6 + nv, 1], ) - pos_on_tplane_e_x = np.reshape( - pos_on_tplane_e[:, 0:2, 0], (np.size(pos_on_tplane_e[:, 0:2, 0])) + pos_on_tplane_e_x = array_ns.reshape( + pos_on_tplane_e[:, 0:2, 0], (array_ns.size(pos_on_tplane_e[:, 0:2, 0])) ) - pos_on_tplane_e_y = np.reshape( - pos_on_tplane_e[:, 0:2, 1], (np.size(pos_on_tplane_e[:, 0:2, 1])) + pos_on_tplane_e_y = array_ns.reshape( + pos_on_tplane_e[:, 0:2, 1], (array_ns.size(pos_on_tplane_e[:, 0:2, 1])) ) return pos_on_tplane_e_x, pos_on_tplane_e_y diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 4b45133d6..cd5404385 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -643,7 +643,6 @@ def _check_union( ) -> bool: members = get_args(union) # fix for unions with only one member, which implicitly are not Union but fallback to the type - # fix for unions with only one member, which implicitly are not Union but fallback to the type if not members: members = (union,) annotation = parameter_definition.annotation diff --git a/model/common/tests/interpolation_tests/test_interpolation_factory.py b/model/common/tests/interpolation_tests/test_interpolation_factory.py index b43c70ee9..02604491b 100644 --- a/model/common/tests/interpolation_tests/test_interpolation_factory.py +++ b/model/common/tests/interpolation_tests/test_interpolation_factory.py @@ -242,6 +242,63 @@ def test_get_mass_conserving_cell_average_weight( assert test_helpers.dallclose(field_ref.asnumpy(), field.asnumpy(), rtol=rtol) +## FIXME: does not validate +# -> connectivity order between reference from serialbox and computed value is different +@pytest.mark.parametrize( + "grid_file, experiment, rtol", + [ + (dt_utils.REGIONAL_EXPERIMENT, dt_utils.REGIONAL_EXPERIMENT, 5e-9), + ], +) +@pytest.mark.datatest +def test_e_flx_avg(interpolation_savepoint, grid_file, experiment, backend, rtol): + # TODO (any): This test does not work on gpu backend because the field operator is run with embedded backend + if data_alloc.is_cupy_device(backend): + pytest.skip("skipping: gpu backend is unsupported") + field_ref = interpolation_savepoint.e_flx_avg() + factory = get_interpolation_factory(backend, experiment, grid_file) + grid = factory.grid + field = factory.get(attrs.E_FLX_AVG) + assert field.shape == (grid.num_edges, grid.connectivities[dims.E2C2EODim].shape[1]) + # FIXME: e2c2e constructed from grid file has different ordering than the serialized one + assert_reordered(field.asnumpy(), field_ref.asnumpy(), rtol=5e-2) + + +@pytest.mark.parametrize( + "grid_file, experiment, rtol", + [ + (dt_utils.REGIONAL_EXPERIMENT, dt_utils.REGIONAL_EXPERIMENT, 5e-9), + (dt_utils.R02B04_GLOBAL, dt_utils.GLOBAL_EXPERIMENT, 1e-11), + ], +) +@pytest.mark.datatest +def test_e_bln_c_s(interpolation_savepoint, grid_file, experiment, backend, rtol): + field_ref = interpolation_savepoint.e_bln_c_s() + factory = get_interpolation_factory(backend, experiment, grid_file) + grid = factory.grid + field = factory.get(attrs.E_BLN_C_S) + assert field.shape == (grid.num_cells, C2E_SIZE) + assert test_helpers.dallclose(field_ref.asnumpy(), field.asnumpy(), rtol=rtol) + + +@pytest.mark.parametrize( + "grid_file, experiment, rtol", + [ + (dt_utils.REGIONAL_EXPERIMENT, dt_utils.REGIONAL_EXPERIMENT, 5e-9), + (dt_utils.R02B04_GLOBAL, dt_utils.GLOBAL_EXPERIMENT, 1e-11), + ], +) +@pytest.mark.datatest +def test_pos_on_tplane_e_x_y(interpolation_savepoint, grid_file, experiment, backend, rtol): + field_ref_1 = interpolation_savepoint.pos_on_tplane_e_x() + field_ref_2 = interpolation_savepoint.pos_on_tplane_e_y() + factory = get_interpolation_factory(backend, experiment, grid_file) + field_1 = factory.get(attrs.POS_ON_TPLANE_E_X) + field_2 = factory.get(attrs.POS_ON_TPLANE_E_Y) + assert test_helpers.dallclose(field_ref_1.asnumpy(), field_1.asnumpy(), rtol=rtol) + assert test_helpers.dallclose(field_ref_2.asnumpy(), field_2.asnumpy(), atol=1e-8) + + @pytest.mark.parametrize( "grid_file, experiment, rtol", [ diff --git a/model/common/tests/interpolation_tests/test_interpolation_fields.py b/model/common/tests/interpolation_tests/test_interpolation_fields.py index 83b377370..0118d1883 100644 --- a/model/common/tests/interpolation_tests/test_interpolation_fields.py +++ b/model/common/tests/interpolation_tests/test_interpolation_fields.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -import numpy as np import pytest import icon4py.model.common.dimension as dims @@ -256,9 +255,6 @@ def test_compute_e_flx_avg(grid_savepoint, interpolation_savepoint, icon_grid, b primal_cart_normal_x = grid_savepoint.primal_cart_normal_x().asnumpy() primal_cart_normal_y = grid_savepoint.primal_cart_normal_y().asnumpy() primal_cart_normal_z = grid_savepoint.primal_cart_normal_z().asnumpy() - primal_cart_normal = np.transpose( - np.stack((primal_cart_normal_x, primal_cart_normal_y, primal_cart_normal_z)) - ) e2c = data_alloc.as_numpy(icon_grid.connectivities[dims.E2CDim]) c2e = data_alloc.as_numpy(icon_grid.connectivities[dims.C2EDim]) c2e2c = data_alloc.as_numpy(icon_grid.connectivities[dims.C2E2CDim]) @@ -269,7 +265,9 @@ def test_compute_e_flx_avg(grid_savepoint, interpolation_savepoint, icon_grid, b c_bln_avg, geofac_div, owner_mask, - primal_cart_normal, + primal_cart_normal_x, + primal_cart_normal_y, + primal_cart_normal_z, e2c, c2e, c2e2c,