From aee7ade6ec04a60a429393671217108700e67c06 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Dec 2024 11:56:23 -0500 Subject: [PATCH] Move precision definition to `ndsl.dsl.__init__` Detect bad order of imports, e.g., gt4py import before NDSL Adapt code to use the newly defined NDSL_GLOBAL_PRECISION Set internal GT4Py literal precision definition --- ndsl/__init__.py | 1 + ndsl/dsl/__init__.py | 23 +++++++++++++++++---- ndsl/dsl/dace/dace_config.py | 4 ++-- ndsl/dsl/typing.py | 39 ++++++++++++++++-------------------- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index a2f771cd..015024a0 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,3 +1,4 @@ +from . import dsl # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index ed44420a..62db51a8 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -1,6 +1,23 @@ -import gt4py.cartesian.config +# Literal precision for both GT4Py & NDSL +import os +import sys -from ndsl.comm.mpi import MPI + +gt4py_config_module = "gt4py.cartesian.config" +if gt4py_config_module in sys.modules: + raise RuntimeError( + "`GT4Py` config imported before `ndsl` imported." + " Please import `ndsl.dsl` or any `ndsl` module " + " before any `gt4py` imports." + ) +NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64")) +os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION) + + +# Set cache names for default gt backends workflow +import gt4py.cartesian.config # noqa: E402 + +from ndsl.comm.mpi import MPI # noqa: E402 if MPI is not None: @@ -9,5 +26,3 @@ gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get( "GT_CACHE_DIR_NAME", f".gt_cache_{MPI.COMM_WORLD.Get_rank():06}" ) - -__version__ = "0.2.0" diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 7f1c1477..b7f03b0a 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -8,10 +8,10 @@ from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner +from ndsl.dsl import NDSL_GLOBAL_PRECISION from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend -from ndsl.dsl.typing import floating_point_precision from ndsl.optional_imports import cupy as cp @@ -264,7 +264,7 @@ def __init__( "compiler", "cuda", "syncdebug", value=dace_debug_env_var ) - if floating_point_precision() == 32: + if NDSL_GLOBAL_PRECISION == 32: # When using 32-bit float, we flip the default dtypes to be all # C, e.g. 32 bit. dace.Config.set( diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 3b6ba44c..568c699a 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -1,9 +1,10 @@ -import os -from typing import Tuple, Union, cast +from typing import Tuple, TypeAlias, Union, cast import gt4py.cartesian.gtscript as gtscript import numpy as np +from ndsl.dsl import NDSL_GLOBAL_PRECISION + # A Field Field = gtscript.Field @@ -21,36 +22,30 @@ # Union of valid data types (from gt4py.cartesian.gtscript) DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64] - -def floating_point_precision() -> int: - return int(os.getenv("PACE_FLOAT_PRECISION", "64")) - - # We redefine the type as a way to distinguish # the model definition of a float to other usage of the # common numpy type in the rest of the code. -NDSL_32BIT_FLOAT_TYPE = np.float32 -NDSL_64BIT_FLOAT_TYPE = np.float64 +NDSL_32BIT_FLOAT_TYPE: TypeAlias = np.float32 +NDSL_32BIT_INT_TYPE: TypeAlias = np.int32 +NDSL_64BIT_FLOAT_TYPE: TypeAlias = np.float64 +NDSL_64BIT_INT_TYPE: TypeAlias = np.int64 -def global_set_floating_point_precision(): +def global_set_floating_point_precision() -> Tuple[TypeAlias, TypeAlias]: """Set the global floating point precision for all reference to Float in the codebase. Defaults to 64 bit.""" - global Float - precision_in_bit = floating_point_precision() - if precision_in_bit == 64: - return NDSL_64BIT_FLOAT_TYPE - elif precision_in_bit == 32: - return NDSL_32BIT_FLOAT_TYPE - else: - NotImplementedError( - f"{precision_in_bit} bit precision not implemented or tested" - ) + global Float, Int + if NDSL_GLOBAL_PRECISION == 64: + return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE + elif NDSL_GLOBAL_PRECISION == 32: + return NDSL_32BIT_FLOAT_TYPE, NDSL_32BIT_INT_TYPE + raise NotImplementedError( + f"{NDSL_GLOBAL_PRECISION} bit precision not implemented or tested" + ) # Default float and int types -Float = global_set_floating_point_precision() -Int = np.int_ +Float, Int = global_set_floating_point_precision() Bool = np.bool_ FloatField = Field[gtscript.IJK, Float]