Skip to content

Commit

Permalink
fix non conforming names in serialbox_utils.py (#518)
Browse files Browse the repository at this point in the history
use field type alias for prognostic_state.py
also fix imports in changed files
  • Loading branch information
halungge authored Aug 8, 2024
1 parent 0eda8a3 commit 0be1aba
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 78 deletions.
20 changes: 9 additions & 11 deletions model/common/src/icon4py/model/common/states/prognostic_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
from dataclasses import dataclass

from gt4py.next import as_field
from gt4py.next.common import Field

from icon4py.model.common import field_type_aliases as fa
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
from icon4py.model.common import dimension as dims, field_type_aliases as fa, type_alias as ta


@dataclass
Expand All @@ -27,14 +25,14 @@ class PrognosticState:
Corresponds to ICON t_nh_prog
"""

rho: fa.CellKField[float] # density, rho(nproma, nlev, nblks_c) [kg/m^3]
w: fa.CellKField[float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[
[EdgeDim, KDim], float
rho: fa.CellKField[ta.wpfloat] # density, rho(nproma, nlev, nblks_c) [kg/m^3]
w: fa.CellKField[ta.wpfloat] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: fa.EdgeKField[
ta.wpfloat
] # horizontal wind normal to edges, vn(nproma, nlev, nblks_e) [m/s]
exner: fa.CellKField[float] # exner function, exner(nrpoma, nlev, nblks_c)
theta_v: fa.CellKField[float] # virtual temperature, (nproma, nlev, nlbks_c) [K]
exner: fa.CellKField[ta.wpfloat] # exner function, exner(nrpoma, nlev, nblks_c)
theta_v: fa.CellKField[ta.wpfloat] # virtual temperature, (nproma, nlev, nlbks_c) [K]

@property
def w_1(self) -> fa.CellField[float]:
return as_field((CellDim,), self.w.ndarray[:, 0])
def w_1(self) -> fa.CellField[ta.wpfloat]:
return as_field((dims.CellDim,), self.w.ndarray[:, 0])
Original file line number Diff line number Diff line change
Expand Up @@ -1346,10 +1346,10 @@ def pressure_ifc(self):
def pressure_sfc(self):
return self._get_field("output_diag_pressure_sfc", dims.CellDim)

def zonal_Wind(self):
def zonal_wind(self):
return self._get_field("output_diag_u", dims.CellDim, dims.KDim)

def meridional_Wind(self):
def meridional_wind(self):
return self._get_field("output_diag_v", dims.CellDim, dims.KDim)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,25 @@

import pytest

from icon4py.model.common.constants import CPD_O_RD, GRAV_O_RD, P0REF
from icon4py.model.common.diagnostic_calculations.stencils.diagnose_pressure import (
diagnose_pressure,
import icon4py.model.common.grid.horizontal as h_grid
from icon4py.model.common import constants as constants, dimension as dims
from icon4py.model.common.diagnostic_calculations.stencils import (
diagnose_pressure as pressure,
diagnose_surface_pressure as surface_pressure,
diagnose_temperature as temperature,
)
from icon4py.model.common.diagnostic_calculations.stencils.diagnose_surface_pressure import (
diagnose_surface_pressure,
from icon4py.model.common.interpolation.stencils import (
edge_2_cell_vector_rbf_interpolation as rbf,
)
from icon4py.model.common.diagnostic_calculations.stencils.diagnose_temperature import (
diagnose_temperature,
)
from icon4py.model.common.dimension import CellDim, KDim
from icon4py.model.common.grid.horizontal import HorizontalMarkerIndex
from icon4py.model.common.interpolation.stencils.edge_2_cell_vector_rbf_interpolation import (
edge_2_cell_vector_rbf_interpolation,
)
from icon4py.model.common.states.diagnostic_state import DiagnosticState
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.datatest_utils import JABW_EXPERIMENT
from icon4py.model.common.test_utils.helpers import dallclose, zero_field
from icon4py.model.common.states import diagnostic_state as diagnostics, prognostic_state
from icon4py.model.common.test_utils import datatest_utils as dt_utils, helpers as helpers


@pytest.mark.datatest
@pytest.mark.parametrize(
"experiment",
[
JABW_EXPERIMENT,
dt_utils.JABW_EXPERIMENT,
],
)
def test_diagnose_temperature(
Expand All @@ -49,33 +42,37 @@ def test_diagnose_temperature(
):
sp = data_provider.from_savepoint_jabw_final()
icon_diagnostics_output_sp = data_provider.from_savepoint_jabw_diagnostic()
prognostic_state_now = PrognosticState(
prognostic_state_now = prognostic_state.PrognosticState(
rho=sp.rho(),
w=None,
vn=sp.vn(),
exner=sp.exner(),
theta_v=sp.theta_v(),
)
diagnostic_state = DiagnosticState(
temperature=zero_field(icon_grid, CellDim, KDim, dtype=float),
pressure=zero_field(icon_grid, CellDim, KDim, dtype=float),
pressure_ifc=zero_field(icon_grid, CellDim, KDim, dtype=float, extend={KDim: 1}),
u=zero_field(icon_grid, CellDim, KDim, dtype=float),
v=zero_field(icon_grid, CellDim, KDim, dtype=float),
diagnostic_state = diagnostics.DiagnosticState(
temperature=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
pressure=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
pressure_ifc=helpers.zero_field(
icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}
),
u=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
v=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
)

diagnose_temperature(
temperature.diagnose_temperature(
prognostic_state_now.theta_v,
prognostic_state_now.exner,
diagnostic_state.temperature,
icon_grid.get_start_index(CellDim, HorizontalMarkerIndex.interior(CellDim)),
icon_grid.get_end_index(CellDim, HorizontalMarkerIndex.end(CellDim)),
icon_grid.get_start_index(
dims.CellDim, h_grid.HorizontalMarkerIndex.interior(dims.CellDim)
),
icon_grid.get_end_index(dims.CellDim, h_grid.HorizontalMarkerIndex.end(dims.CellDim)),
0,
icon_grid.num_levels,
offset_provider={},
)

assert dallclose(
assert helpers.dallclose(
diagnostic_state.temperature.asnumpy(),
icon_diagnostics_output_sp.temperature().asnumpy(),
)
Expand All @@ -85,7 +82,7 @@ def test_diagnose_temperature(
@pytest.mark.parametrize(
"experiment",
[
JABW_EXPERIMENT,
dt_utils.JABW_EXPERIMENT,
],
)
def test_diagnose_meridional_and_zonal_winds(
Expand All @@ -95,28 +92,32 @@ def test_diagnose_meridional_and_zonal_winds(
):
sp = data_provider.from_savepoint_jabw_final()
icon_diagnostics_output_sp = data_provider.from_savepoint_jabw_diagnostic()
prognostic_state_now = PrognosticState(
prognostic_state_now = prognostic_state.PrognosticState(
rho=sp.rho(),
w=None,
vn=sp.vn(),
exner=sp.exner(),
theta_v=sp.theta_v(),
)
diagnostic_state = DiagnosticState(
temperature=zero_field(icon_grid, CellDim, KDim, dtype=float),
pressure=zero_field(icon_grid, CellDim, KDim, dtype=float),
pressure_ifc=zero_field(icon_grid, CellDim, KDim, dtype=float, extend={KDim: 1}),
u=zero_field(icon_grid, CellDim, KDim, dtype=float),
v=zero_field(icon_grid, CellDim, KDim, dtype=float),
diagnostic_state = diagnostics.DiagnosticState(
temperature=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
pressure=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
pressure_ifc=helpers.zero_field(
icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}
),
u=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
v=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
)

rbv_vec_coeff_c1 = data_provider.from_interpolation_savepoint().rbf_vec_coeff_c1()
rbv_vec_coeff_c2 = data_provider.from_interpolation_savepoint().rbf_vec_coeff_c2()
grid_idx_cell_start_plus1 = icon_grid.get_end_index(
CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim) + 1
dims.CellDim, h_grid.HorizontalMarkerIndex.lateral_boundary(dims.CellDim) + 1
)
grid_idx_cell_end = icon_grid.get_end_index(CellDim, HorizontalMarkerIndex.end(CellDim))
edge_2_cell_vector_rbf_interpolation(
grid_idx_cell_end = icon_grid.get_end_index(
dims.CellDim, h_grid.HorizontalMarkerIndex.end(dims.CellDim)
)
rbf.edge_2_cell_vector_rbf_interpolation(
prognostic_state_now.vn,
rbv_vec_coeff_c1,
rbv_vec_coeff_c2,
Expand All @@ -131,14 +132,14 @@ def test_diagnose_meridional_and_zonal_winds(
},
)

assert dallclose(
assert helpers.dallclose(
diagnostic_state.u.asnumpy(),
icon_diagnostics_output_sp.zonal_Wind().asnumpy(),
icon_diagnostics_output_sp.zonal_wind().asnumpy(),
)

assert dallclose(
assert helpers.dallclose(
diagnostic_state.v.asnumpy(),
icon_diagnostics_output_sp.meridional_Wind().asnumpy(),
icon_diagnostics_output_sp.meridional_wind().asnumpy(),
atol=1.0e-13,
)

Expand All @@ -147,7 +148,7 @@ def test_diagnose_meridional_and_zonal_winds(
@pytest.mark.parametrize(
"experiment",
[
JABW_EXPERIMENT,
dt_utils.JABW_EXPERIMENT,
],
)
def test_diagnose_pressure(
Expand All @@ -157,63 +158,69 @@ def test_diagnose_pressure(
):
sp = data_provider.from_savepoint_jabw_final()
icon_diagnostics_output_sp = data_provider.from_savepoint_jabw_diagnostic()
prognostic_state_now = PrognosticState(
prognostic_state_now = prognostic_state.PrognosticState(
rho=sp.rho(),
w=None,
vn=sp.vn(),
exner=sp.exner(),
theta_v=sp.theta_v(),
)
diagnostic_state = DiagnosticState(
diagnostic_state = diagnostics.DiagnosticState(
temperature=sp.temperature(),
pressure=zero_field(icon_grid, CellDim, KDim, dtype=float),
pressure_ifc=zero_field(icon_grid, CellDim, KDim, dtype=float, extend={KDim: 1}),
u=zero_field(icon_grid, CellDim, KDim, dtype=float),
v=zero_field(icon_grid, CellDim, KDim, dtype=float),
pressure=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
pressure_ifc=helpers.zero_field(
icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}
),
u=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
v=helpers.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float),
)

diagnose_surface_pressure(
surface_pressure.diagnose_surface_pressure(
prognostic_state_now.exner,
diagnostic_state.temperature,
data_provider.from_metrics_savepoint().ddqz_z_full(),
diagnostic_state.pressure_ifc,
CPD_O_RD,
P0REF,
GRAV_O_RD,
constants.CPD_O_RD,
constants.P0REF,
constants.GRAV_O_RD,
horizontal_start=icon_grid.get_start_index(
CellDim, HorizontalMarkerIndex.interior(CellDim)
dims.CellDim, h_grid.HorizontalMarkerIndex.interior(dims.CellDim)
),
horizontal_end=icon_grid.get_end_index(
dims.CellDim, h_grid.HorizontalMarkerIndex.end(dims.CellDim)
),
horizontal_end=icon_grid.get_end_index(CellDim, HorizontalMarkerIndex.end(CellDim)),
vertical_start=icon_grid.num_levels,
vertical_end=icon_grid.num_levels + 1,
offset_provider={"Koff": KDim},
offset_provider={"Koff": dims.KDim},
)

diagnose_pressure(
pressure.diagnose_pressure(
data_provider.from_metrics_savepoint().ddqz_z_full(),
diagnostic_state.temperature,
diagnostic_state.pressure_sfc,
diagnostic_state.pressure,
diagnostic_state.pressure_ifc,
GRAV_O_RD,
icon_grid.get_start_index(CellDim, HorizontalMarkerIndex.interior(CellDim)),
icon_grid.get_end_index(CellDim, HorizontalMarkerIndex.end(CellDim)),
constants.GRAV_O_RD,
icon_grid.get_start_index(
dims.CellDim, h_grid.HorizontalMarkerIndex.interior(dims.CellDim)
),
icon_grid.get_end_index(dims.CellDim, h_grid.HorizontalMarkerIndex.end(dims.CellDim)),
0,
icon_grid.num_levels,
offset_provider={},
)

assert dallclose(
assert helpers.dallclose(
diagnostic_state.pressure_sfc.asnumpy(),
icon_diagnostics_output_sp.pressure_sfc().asnumpy(),
)

assert dallclose(
assert helpers.dallclose(
diagnostic_state.pressure_ifc.asnumpy(),
icon_diagnostics_output_sp.pressure_ifc().asnumpy(),
)

assert dallclose(
assert helpers.dallclose(
diagnostic_state.pressure.asnumpy(),
icon_diagnostics_output_sp.pressure().asnumpy(),
)

0 comments on commit 0be1aba

Please sign in to comment.