Skip to content

Commit

Permalink
Move precision definition to ndsl.dsl.__init__
Browse files Browse the repository at this point in the history
Detect bad order of imports, e.g., gt4py import before NDSL
Adapt code to use the newly defined NDSL_GLOBAL_PRECISION
Set internal GT4Py literal precision definition
  • Loading branch information
FlorianDeconinck committed Dec 24, 2024
1 parent 76f53c8 commit aee7ade
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 28 deletions.
1 change: 1 addition & 0 deletions ndsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import dsl # isort:skip
from .comm.communicator import CubedSphereCommunicator, TileCommunicator
from .comm.local_comm import LocalComm
from .comm.mpi import MPIComm
Expand Down
23 changes: 19 additions & 4 deletions ndsl/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
import gt4py.cartesian.config
# Literal precision for both GT4Py & NDSL
import os
import sys

from ndsl.comm.mpi import MPI

gt4py_config_module = "gt4py.cartesian.config"
if gt4py_config_module in sys.modules:
raise RuntimeError(
"`GT4Py` config imported before `ndsl` imported."
" Please import `ndsl.dsl` or any `ndsl` module "
" before any `gt4py` imports."
)
NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64"))
os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION)


# Set cache names for default gt backends workflow
import gt4py.cartesian.config # noqa: E402

from ndsl.comm.mpi import MPI # noqa: E402


if MPI is not None:
Expand All @@ -9,5 +26,3 @@
gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get(
"GT_CACHE_DIR_NAME", f".gt_cache_{MPI.COMM_WORLD.Get_rank():06}"
)

__version__ = "0.2.0"
4 changes: 2 additions & 2 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from ndsl.comm.communicator import Communicator
from ndsl.comm.partitioner import Partitioner
from ndsl.dsl import NDSL_GLOBAL_PRECISION
from ndsl.dsl.caches.cache_location import identify_code_path
from ndsl.dsl.caches.codepath import FV3CodePath
from ndsl.dsl.gt4py_utils import is_gpu_backend
from ndsl.dsl.typing import floating_point_precision
from ndsl.optional_imports import cupy as cp


Expand Down Expand Up @@ -264,7 +264,7 @@ def __init__(
"compiler", "cuda", "syncdebug", value=dace_debug_env_var
)

if floating_point_precision() == 32:
if NDSL_GLOBAL_PRECISION == 32:
# When using 32-bit float, we flip the default dtypes to be all
# C, e.g. 32 bit.
dace.Config.set(
Expand Down
39 changes: 17 additions & 22 deletions ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Tuple, Union, cast
from typing import Tuple, TypeAlias, Union, cast

import gt4py.cartesian.gtscript as gtscript
import numpy as np

from ndsl.dsl import NDSL_GLOBAL_PRECISION


# A Field
Field = gtscript.Field
Expand All @@ -21,36 +22,30 @@
# Union of valid data types (from gt4py.cartesian.gtscript)
DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64]


def floating_point_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))


# We redefine the type as a way to distinguish
# the model definition of a float to other usage of the
# common numpy type in the rest of the code.
NDSL_32BIT_FLOAT_TYPE = np.float32
NDSL_64BIT_FLOAT_TYPE = np.float64
NDSL_32BIT_FLOAT_TYPE: TypeAlias = np.float32
NDSL_32BIT_INT_TYPE: TypeAlias = np.int32
NDSL_64BIT_FLOAT_TYPE: TypeAlias = np.float64
NDSL_64BIT_INT_TYPE: TypeAlias = np.int64


def global_set_floating_point_precision():
def global_set_floating_point_precision() -> Tuple[TypeAlias, TypeAlias]:
"""Set the global floating point precision for all reference
to Float in the codebase. Defaults to 64 bit."""
global Float
precision_in_bit = floating_point_precision()
if precision_in_bit == 64:
return NDSL_64BIT_FLOAT_TYPE
elif precision_in_bit == 32:
return NDSL_32BIT_FLOAT_TYPE
else:
NotImplementedError(
f"{precision_in_bit} bit precision not implemented or tested"
)
global Float, Int
if NDSL_GLOBAL_PRECISION == 64:
return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE
elif NDSL_GLOBAL_PRECISION == 32:
return NDSL_32BIT_FLOAT_TYPE, NDSL_32BIT_INT_TYPE
raise NotImplementedError(
f"{NDSL_GLOBAL_PRECISION} bit precision not implemented or tested"
)


# Default float and int types
Float = global_set_floating_point_precision()
Int = np.int_
Float, Int = global_set_floating_point_precision()
Bool = np.bool_

FloatField = Field[gtscript.IJK, Float]
Expand Down

0 comments on commit aee7ade

Please sign in to comment.