From 7dea36afead21fe68d25195c637ec69588bd3e9c Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:31:03 +0100 Subject: [PATCH 1/4] feat[next]: new domain slice syntax (#1453) New domain slice syntax, e.g. f[I(-1):I(5)] --- src/gt4py/next/common.py | 12 +++++- src/gt4py/next/embedded/common.py | 29 +++++++++++++ src/gt4py/next/embedded/nd_array_field.py | 1 + src/gt4py/next/iterator/tracing.py | 4 +- .../unit_tests/embedded_tests/test_common.py | 35 ++++++++++++++- .../embedded_tests/test_nd_array_field.py | 43 +++++++++++++++++++ 6 files changed, 120 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f4e35b5533..e55e13a38d 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -75,6 +75,9 @@ class Dimension: def __str__(self): return f"{self.value}[{self.kind}]" + def __call__(self, val: int) -> NamedIndex: + return self, val + class Infinity(enum.Enum): """Describes an unbounded `UnitRange`.""" @@ -272,7 +275,10 @@ def unit_range(r: RangeLike) -> UnitRange: NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType -AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +NamedSlice: TypeAlias = ( + slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ +) +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] RelativeIndexSequence: TypeAlias = tuple[ @@ -307,6 +313,10 @@ def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: ) +def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]: + return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop)) + + def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 94efe4d61d..f9201da247 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -146,3 +146,32 @@ def _find_index_of_dim( if dim == d: return i return None + + +def canonicalize_any_index_sequence( + index: common.AnyIndexSpec, +) -> common.AnyIndexSpec: + # TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice` + new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index + if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index): + new_index = tuple([_named_slice_to_named_range(i) for i in new_index]) # type: ignore[arg-type, assignment] # all i's are slices as per if statement + return new_index + + +def _named_slice_to_named_range( + idx: common.NamedSlice, +) -> common.NamedRange | common.NamedSlice: + assert hasattr(idx, "start") and hasattr(idx, "stop") + if common.is_named_slice(idx): + idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = idx.start[0], idx.start[1], idx.stop[0], idx.stop[1] # type: ignore[attr-defined] + if idx_start_0 != idx_stop_0: + raise IndexError( + f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." + ) + assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int) + return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1)) + if common.is_named_index(idx.start) and idx.stop is None: + raise IndexError(f"Upper bound needs to be specified for {idx}.") + if common.is_named_index(idx.stop) and idx.start is None: + raise IndexError(f"Lower bound needs to be specified for {idx}.") + return idx diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1bdb7161ec..38aab09df1 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -301,6 +301,7 @@ def __invert__(self) -> NdArrayField: def _slice( self, index: common.AnyIndexSpec ) -> tuple[common.Domain, common.RelativeIndexSequence]: + index = embedded_common.canonicalize_any_index_sequence(index) new_domain = embedded_common.sub_domain(self.domain, index) index_sequence = common.as_any_index_sequence(index) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 05ebd02352..1c9887cd22 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -147,6 +147,8 @@ def __call__(self, *args): def make_node(o): if isinstance(o, Node): return o + if isinstance(o, common.Dimension): + return AxisLiteral(value=o.value) if callable(o): if o.__name__ == "": return lambdadef(o) @@ -156,8 +158,6 @@ def make_node(o): return OffsetLiteral(value=o.value) if isinstance(o, core_defs.Scalar): return im.literal_from_value(o) - if isinstance(o, common.Dimension): - return AxisLiteral(value=o.value) if isinstance(o, tuple): return _f("make_tuple", *(make_node(arg) for arg in o)) if o is None: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index de511fdabb..91f15ee936 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -19,7 +19,12 @@ from gt4py.next import common from gt4py.next.common import UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions -from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain +from gt4py.next.embedded.common import ( + _slice_range, + canonicalize_any_index_sequence, + iterate_domain, + sub_domain, +) @pytest.mark.parametrize( @@ -147,3 +152,31 @@ def test_iterate_domain(): testee = list(iterate_domain(domain)) assert testee == ref + + +@pytest.mark.parametrize( + "slices, expected", + [ + [slice(I(3), I(4)), ((I, common.UnitRange(3, 4)),)], + [ + (slice(J(3), J(6)), slice(I(3), I(5))), + ((J, common.UnitRange(3, 6)), (I, common.UnitRange(3, 5))), + ], + [slice(I(1), J(7)), IndexError], + [ + slice(I(1), None), + IndexError, + ], + [ + slice(None, K(8)), + IndexError, + ], + ], +) +def test_slicing(slices, expected): + if expected is IndexError: + with pytest.raises(IndexError): + canonicalize_any_index_sequence(slices) + else: + testee = canonicalize_any_index_sequence(slices) + assert testee == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 66189cf9eb..49f74a566b 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -463,6 +463,49 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): assert indexed_field.domain.dims == expected_dimensions +def test_absolute_indexing_dim_sliced(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + indexed_field_1 = field[JDim(8) : JDim(10), IDim(5) : IDim(9)] + expected = field[(IDim, UnitRange(5, 9)), (JDim, UnitRange(8, 10))] + + assert common.is_field(indexed_field_1) + assert indexed_field_1 == expected + + +def test_absolute_indexing_dim_sliced_single_slice(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + indexed_field_1 = field[KDim(11)] + indexed_field_2 = field[(KDim, 11)] + + assert common.is_field(indexed_field_1) + assert indexed_field_1 == indexed_field_2 + + +def test_absolute_indexing_wrong_dim_sliced(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + + with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'JDim' and 'IDim'."): + field[JDim(8) : IDim(10)] + + +def test_absolute_indexing_empty_dim_sliced(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + with pytest.raises(IndexError, match="Lower bound needs to be specified"): + field[: IDim(10)] + + def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) From 4228624d5e7818eecab1ec2b085118fb30fa97b0 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Mon, 26 Feb 2024 10:39:16 +0100 Subject: [PATCH 2/4] feat[next]: config improvements (#1461) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## New * feature: add gtfn backend factory param for using temporaries ## Changed * feature: make verbose exceptions a config option * refactor: reduce import time config dependencies * refactor: remove 'set_verbose_exceptions', use `config.VERBOSE_EXCEPTIONS = ...` instead * docs: document cmake build types enum * tests: update gtfn factory tests --------- Co-authored-by: Rico Häuselmann --- src/gt4py/next/config.py | 59 ++++++++++++------- src/gt4py/next/errors/__init__.py | 2 - src/gt4py/next/errors/excepthook.py | 33 +---------- .../next/iterator/transforms/pass_manager.py | 4 +- src/gt4py/next/otf/compilation/cache.py | 3 +- .../next/program_processors/runners/gtfn.py | 17 +++--- .../test_temporaries_with_sizes.py | 2 +- .../iterator_tests/test_vertical_advection.py | 5 +- tests/next_tests/unit_tests/conftest.py | 2 +- .../errors_tests/test_excepthook.py | 23 -------- .../runners_tests/test_gtfn.py | 23 +++++++- tests/next_tests/unit_tests/test_config.py | 52 ++++++++++++++++ 12 files changed, 130 insertions(+), 95 deletions(-) create mode 100644 tests/next_tests/unit_tests/test_config.py diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 74bf56d6e8..682d5254e5 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -26,22 +26,34 @@ class BuildCacheLifetime(enum.Enum): class CMakeBuildType(enum.Enum): - def _generate_next_value_(name, start, count, last_values): - return "".join(part.capitalize() for part in name.split("_")) + """ + CMake build types enum. - DEBUG = enum.auto() - RELEASE = enum.auto() - REL_WITH_DEB_INFO = enum.auto() - MIN_SIZE_REL = enum.auto() + Member values have to be valid CMake syntax. + """ + DEBUG = "Debug" + RELEASE = "Release" + REL_WITH_DEB_INFO = "RelWithDebInfo" + MIN_SIZE_REL = "MinSizeRel" -def env_flag_to_bool(flag_value: str) -> bool: - """Like in gt4py.cartesian, env vars for flags should be set to '0' or '1'.""" + +def env_flag_to_bool(name: str, default: bool) -> bool: + """Recognize true or false signaling string values.""" + flag_value = None + if name in os.environ: + flag_value = os.environ[name].lower() match flag_value: - case "0" | "1": - return bool(int(flag_value)) + case None: + return default + case "0" | "false" | "off": + return False + case "1" | "true" | "on": + return True case _: - raise ValueError("GT4Py flag environment variables must have value '0' or '1'.") + raise ValueError( + "Invalid GT4Py environment flag value: use '0 | false | off' or '1 | true | on'." + ) _PREFIX: Final[str] = "GT4PY" @@ -49,11 +61,18 @@ def env_flag_to_bool(flag_value: str) -> bool: #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(os.environ.get(f"{_PREFIX}_DEBUG", "0")) +DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) + + +#: Verbose flag for DSL compilation errors +VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( + f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False +) + #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT -BUILD_CACHE_DIR: Final[pathlib.Path] = ( +BUILD_CACHE_DIR: pathlib.Path = ( pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) / "gt4py_cache" ) @@ -62,14 +81,12 @@ def env_flag_to_bool(flag_value: str) -> bool: #: Whether generated code projects should be kept around between runs. #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs -BUILD_CACHE_LIFETIME: Final[BuildCacheLifetime] = getattr( - BuildCacheLifetime, - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper(), -) +BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ + os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() +] #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. -CMAKE_BUILD_TYPE: Final[CMakeBuildType] = getattr( - CMakeBuildType, - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper(), -) +CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ + os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() +] diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index dd48d6f0f9..c965332929 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -17,7 +17,6 @@ from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions. excepthook, ) -from .excepthook import set_verbose_exceptions from .exceptions import ( DSLError, InvalidParameterAnnotationError, @@ -37,5 +36,4 @@ "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", - "set_verbose_exceptions", ] diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index f1dd18e1b4..673eaca757 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -20,39 +20,12 @@ other errors. """ -import os import sys -import warnings from typing import Callable -from . import exceptions, formatting - - -def _get_verbose_exceptions_envvar() -> bool: - """Detect if the user enabled verbose exceptions in the environment variables.""" - env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" - if env_var_name in os.environ: - false_values = ["0", "false", "off"] - true_values = ["1", "true", "on"] - value = os.environ[env_var_name].lower() - if value in false_values: - return False - elif value in true_values: - return True - else: - values = ", ".join([*false_values, *true_values]) - msg = f"the 'GT4PY_VERBOSE_EXCEPTIONS' environment variable must be one of {values} (case insensitive)" - warnings.warn(msg) - return False - +from gt4py.next import config -_verbose_exceptions: bool = _get_verbose_exceptions_envvar() - - -def set_verbose_exceptions(enabled: bool = False) -> None: - """Programmatically set whether to use verbose printing for uncaught errors.""" - global _verbose_exceptions - _verbose_exceptions = enabled +from . import exceptions, formatting def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) -> list[str]: @@ -77,7 +50,7 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException also printed. """ if isinstance(value, exceptions.DSLError): - exc_strs = _format_uncaught_error(value, _verbose_exceptions) + exc_strs = _format_uncaught_error(value, config.VERBOSE_EXCEPTIONS) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index cd8ebb5516..e8f836ddaf 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -38,7 +38,7 @@ @enum.unique class LiftMode(enum.Enum): FORCE_INLINE = enum.auto() - FORCE_TEMPORARIES = enum.auto() + USE_TEMPORARIES = enum.auto() SIMPLE_HEURISTIC = enum.auto() @@ -47,7 +47,7 @@ def _inline_lifts(ir, lift_mode): return InlineLifts().visit(ir) elif lift_mode == LiftMode.SIMPLE_HEURISTIC: return InlineLifts(simple_inline_heuristic.is_eligible_for_inlining).visit(ir) - elif lift_mode == LiftMode.FORCE_TEMPORARIES: + elif lift_mode == LiftMode.USE_TEMPORARIES: return InlineLifts( flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index ee5ec650e0..810952d0ef 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -27,7 +27,6 @@ _session_cache_dir = tempfile.TemporaryDirectory(prefix="gt4py_session_") _session_cache_dir_path = pathlib.Path(_session_cache_dir.name) -_persistent_cache_dir_path = config.BUILD_CACHE_DIR def _serialize_param(parameter: interface.Parameter) -> str: @@ -72,7 +71,7 @@ def get_cache_folder( case config.BuildCacheLifetime.SESSION: base_path = _session_cache_dir_path case config.BuildCacheLifetime.PERSISTENT: - base_path = _persistent_cache_dir_path + base_path = config.BUILD_CACHE_DIR case _: raise ValueError("Unsupported caching lifetime.") diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 04af4a5283..4a65f6d049 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -23,7 +23,8 @@ import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash from gt4py.next import common, config -from gt4py.next.iterator.transforms import LiftMode, global_tmps +from gt4py.next.iterator import transforms +from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -150,6 +151,7 @@ class Meta: class Params: name_device = "cpu" name_cached = "" + name_temps = "" name_postfix = "" gpu = factory.Trait( allocator=next_allocators.StandardGPUFieldBufferAllocator(), @@ -165,13 +167,18 @@ class Params: ), name_cached="_cached", ) + use_temporaries = factory.Trait( + otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, + otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, + name_temps="_with_temporaries", + ) device_type = core_defs.DeviceType.CPU hash_function = compilation_hash otf_workflow = factory.SubFactory( GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) name = factory.LazyAttribute( - lambda o: f"run_gtfn_{o.name_device}{o.name_cached}{o.name_postfix}" + lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) executor = factory.LazyAttribute( @@ -189,11 +196,7 @@ class Params: run_gtfn_cached = GTFNBackendFactory(cached=True) -run_gtfn_with_temporaries = GTFNBackendFactory( - name_postfix="_with_temporaries", - otf_workflow__translation__lift_mode=LiftMode.FORCE_TEMPORARIES, - otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, -) +run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) run_gtfn_gpu = GTFNBackendFactory(gpu=True) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 13d8f7711e..a264843f49 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -111,7 +111,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): itir_with_tmp = apply_common_transforms( testee.itir, - lift_mode=LiftMode.FORCE_TEMPORARIES, + lift_mode=LiftMode.USE_TEMPORARIES, offset_provider=mesh_descriptor.offset_provider, ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index f2a6505a7e..b28c98ab38 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -127,10 +127,7 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): and lift_mode == LiftMode.FORCE_INLINE ): pytest.skip("gtfn does only support lifted scans when using temporaries") - if ( - program_processor == gtfn.run_gtfn_with_temporaries - or lift_mode == LiftMode.FORCE_TEMPORARIES - ): + if program_processor == gtfn.run_gtfn_with_temporaries or lift_mode == LiftMode.USE_TEMPORARIES: pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index d3f9bdb761..17418a9ca6 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -38,7 +38,7 @@ @pytest.fixture( params=[ transforms.LiftMode.FORCE_INLINE, - transforms.LiftMode.FORCE_TEMPORARIES, + transforms.LiftMode.USE_TEMPORARIES, transforms.LiftMode.SIMPLE_HEURISTIC, ], ids=lambda p: f"lift_mode={p.name}", diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py index 526844d730..d9f29b99d5 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -12,11 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import os - from gt4py import eve from gt4py.next import errors -from gt4py.next.errors import excepthook def test_format_uncaught_error(): @@ -35,23 +32,3 @@ def test_format_uncaught_error(): assert str_usermode.find("Traceback") < 0 assert str_usermode.find("cause") < 0 assert str_usermode.find("ValueError") < 0 - - -def test_get_verbose_exceptions(): - env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" - - # Make sure to save and restore the environment variable, we don't want to - # affect other tests running in the same process. - saved = os.environ.get(env_var_name, None) - try: - os.environ[env_var_name] = "False" - assert excepthook._get_verbose_exceptions_envvar() is False - os.environ[env_var_name] = "True" - assert excepthook._get_verbose_exceptions_envvar() is True - os.environ[env_var_name] = "invalid value" # Should emit a warning too - assert excepthook._get_verbose_exceptions_envvar() is False - del os.environ[env_var_name] - assert excepthook._get_verbose_exceptions_envvar() is False - finally: - if saved is not None: - os.environ[env_var_name] = saved diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 897ddcfc08..755cddcf5a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -26,11 +26,13 @@ import gt4py._core.definitions as core_defs from gt4py.next import allocators, config +from gt4py.next.iterator import transforms +from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import workflow from gt4py.next.program_processors.runners import gtfn -def test_backend_factory_set_device(): +def test_backend_factory_trait_device(): cpu_version = gtfn.GTFNBackendFactory(gpu=False, cached=False) gpu_version = gtfn.GTFNBackendFactory(gpu=True, cached=False) @@ -51,12 +53,29 @@ def test_backend_factory_set_device(): assert allocators.is_field_allocator_for(gpu_version.allocator, core_defs.DeviceType.CUDA) -def test_backend_factory_set_cached(): +def test_backend_factory_trait_cached(): cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) assert isinstance(cached_version.executor.otf_workflow, workflow.CachedStep) assert cached_version.executor.__name__ == "run_gtfn_cpu_cached" +def test_backend_factory_trait_temporaries(): + inline_version = gtfn.GTFNBackendFactory(cached=False) + temps_version = gtfn.GTFNBackendFactory(cached=False, use_temporaries=True) + + assert inline_version.executor.otf_workflow.translation.lift_mode is None + assert ( + temps_version.executor.otf_workflow.translation.lift_mode + is transforms.LiftMode.USE_TEMPORARIES + ) + + assert inline_version.executor.otf_workflow.translation.temporary_extraction_heuristics is None + assert ( + temps_version.executor.otf_workflow.translation.temporary_extraction_heuristics + is global_tmps.SimpleTemporaryExtractionHeuristics + ) + + def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) session_version = gtfn.GTFNBackendFactory() diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py new file mode 100644 index 0000000000..0b906207c3 --- /dev/null +++ b/tests/next_tests/unit_tests/test_config.py @@ -0,0 +1,52 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os + +import pytest + +from gt4py.next import config + + +@pytest.fixture +def env_var(): + """Just in case another test will ever use that environment variable.""" + env_var_name = "GT4PY_TEST_ENV_VAR" + saved = os.environ.get(env_var_name, None) + yield env_var_name + if saved is not None: + os.environ[env_var_name] = saved + + +@pytest.mark.parametrize("value", ["False", "false", "0", "off"]) +def test_env_flag_to_bool_false(env_var, value): + os.environ[env_var] = value + assert config.env_flag_to_bool(env_var, default=True) is False + + +@pytest.mark.parametrize("value", ["True", "true", "1", "on"]) +def test_env_flag_to_bool_true(env_var, value): + os.environ[env_var] = value + assert config.env_flag_to_bool(env_var, default=False) is True + + +def test_env_flag_to_bool_invalid(env_var): + os.environ[env_var] = "invalid value" + with pytest.raises(ValueError): + config.env_flag_to_bool(env_var, default=False) + + +def test_env_flag_to_bool_unset(env_var): + del os.environ[env_var] + assert config.env_flag_to_bool(env_var, default=False) is False From b86a34780eb634e548c2c3a86696e8e61ab410f7 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 26 Feb 2024 15:17:46 +0100 Subject: [PATCH 3/4] fix[next][dace]: Fix translation of if statement from tasklet to inter-state condition (#1469) The bug addressed by this PR is that if-nodes were translated to tasklets. Tasklets assume that all inputs are evaluated. For if-nodes, we need to enforce exclusive execution of one of the two branches. That means that only one of the two arguments will be evaluated at runtime. We achieve this by implementing the true/false branches as separate states and checking the if-statement as condition on the inter-state edge. --- .../runners/dace_iterator/itir_to_tasklet.py | 144 ++++++++++++++---- 1 file changed, 114 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index cf6d7ab047..6ab371bf2b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import copy import dataclasses import itertools from collections.abc import Sequence @@ -566,37 +567,120 @@ def builtin_can_deref( def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - di = dace_debuginfo(node, transformer.context.body.debuginfo) - args = transformer.visit(node_args) - assert len(args) == 3 - if_node = args[0][0] if isinstance(args[0], list) else args[0] - - # the argument could be a list of elements on each branch representing the result of `make_tuple` - # however, the normal case is to find one value expression - assert len(args[1]) == len(args[2]) - if_expr_args = [ - (a[0] if isinstance(a, list) else a, b[0] if isinstance(b, list) else b) - for a, b in zip(args[1], args[2]) - ] - - # in case of tuple arguments, generate one if-tasklet for each element of the output tuple - if_expr_values = [] - for a, b in if_expr_args: - assert a.dtype == b.dtype - expr_args = [ - (arg, f"{arg.value.data}_v") - for arg in (if_node, a, b) - if not isinstance(arg, SymbolExpr) - ] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" - for arg in (if_node, a, b) - ] - expr = "({1} if {0} else {2})".format(*internals) - if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if", dace_debuginfo=di) - if_expr_values.append(if_expr[0]) + assert len(node_args) == 3 + sdfg = transformer.context.body + current_state = transformer.context.state + is_start_state = sdfg.start_block == current_state + + # build an empty state to join true and false branches + join_state = sdfg.add_state_before(current_state, "join") + + def build_if_state(arg, state): + symbol_map = copy.deepcopy(transformer.context.symbol_map) + node_context = Context(sdfg, state, symbol_map) + node_taskgen = PythonTaskletCodegen( + transformer.offset_provider, node_context, transformer.node_types + ) + return node_taskgen.visit(arg) + + # represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state + stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state) + stmt_node = build_if_state(node_args[0], stmt_state)[0] + assert isinstance(stmt_node, ValueExpr) + assert stmt_node.dtype == dace.dtypes.bool + assert sdfg.arrays[stmt_node.value.data].shape == (1,) + + # visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state + tbr_state = sdfg.add_state("true_branch") + sdfg.add_edge( + stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") + ) + sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) + tbr_values = build_if_state(node_args[1], tbr_state) + # + fbr_state = sdfg.add_state("false_branch") + sdfg.add_edge( + stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") + ) + sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) + fbr_values = build_if_state(node_args[2], fbr_state) + + assert isinstance(stmt_node, ValueExpr) + assert stmt_node.dtype == dace.dtypes.bool + # make the result of the if-statement evaluation available inside current state + ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype) + + # we distinguish between select if-statements, where both true and false branches are symbolic expressions, + # and therefore do not require exclusive branch execution, and regular if-statements where at least one branch + # is a value expression, which has to be evaluated at runtime with conditional state transition + result_values = [] + assert len(tbr_values) == len(fbr_values) + for tbr_value, fbr_value in zip(tbr_values, fbr_values): + assert isinstance(tbr_value, (SymbolExpr, ValueExpr)) + assert isinstance(fbr_value, (SymbolExpr, ValueExpr)) + assert tbr_value.dtype == fbr_value.dtype + + if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)): + # both branches return symbolic expressions, therefore the if-node can be translated + # to a select-tasklet inside current state + # TODO: use select-memlet when it becomes available in dace + code = f"{tbr_value.value} if _cond else {fbr_value.value}" + if_expr = transformer.add_expr_tasklet( + [(ctx_stmt_node, "_cond")], code, tbr_value.dtype, "if_select" + )[0] + result_values.append(if_expr) + else: + # at least one of the two branches contains a value expression, which should be evaluated + # only if the corresponding true/false condition is satisfied + desc = sdfg.arrays[ + tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data + ] + var = unique_var_name() + if isinstance(desc, dace.data.Scalar): + sdfg.add_scalar(var, desc.dtype, transient=True) + else: + sdfg.add_array(var, desc.shape, desc.dtype, transient=True) + + # write result to transient data container and access it in the original state + for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]: + val_node = state.add_access(var) + if isinstance(expr, ValueExpr): + state.add_nedge( + expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc) + ) + else: + assert desc.shape == (1,) + state.add_edge( + state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"), + "_out", + val_node, + None, + dace.Memlet(var, "0"), + ) + result_values.append(ValueExpr(current_state.add_access(var), desc.dtype)) + + if tbr_state.is_empty() and fbr_state.is_empty(): + # if all branches are symbolic expressions, the true/false and join states can be removed + # as well as the conditional state transition + sdfg.remove_nodes_from([join_state, tbr_state, fbr_state]) + sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge()) + elif tbr_state.is_empty(): + # use direct edge from if-statement to join state for true branch + tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition + sdfg.edges_between(stmt_state, join_state)[0].contition = tbr_condition + sdfg.remove_node(tbr_state) + elif fbr_state.is_empty(): + # use direct edge from if-statement to join state for false branch + fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition + sdfg.edges_between(stmt_state, join_state)[0].contition = fbr_condition + sdfg.remove_node(fbr_state) + else: + # remove direct edge from if-statement to join state + sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0]) + # the if-statement condition is not used in current state + current_state.remove_node(ctx_stmt_node.value) - return if_expr_values + return result_values def builtin_list_get( From d9004cd0aa106cfa268247855ec8fdb653bad797 Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Mon, 26 Feb 2024 16:32:28 +0100 Subject: [PATCH 4/4] feat[cartesian]: Setting extra compile args and DaCe block size via env variables (#1462) * Set default block size using env variable. * Fix nvcc flags. * Set extra compile args using env variable. * Rename env variable. * Rename env variables. * Run pre-commit. * Use typing.List for compatibility with py3.8. * Fix dace compilation. * Address review comments. * Shift statement outside if block. --- src/gt4py/cartesian/backend/dace_backend.py | 5 +++++ src/gt4py/cartesian/caching.py | 2 ++ src/gt4py/cartesian/config.py | 13 ++++++++----- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b02c765ad7..0bfdec791f 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -544,9 +544,14 @@ def keep_line(line): def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.SDFG): self = cls() with dace.config.temporary_config(): + # To prevent conflict with 3rd party usage of DaCe config always make sure that any + # changes be under the temporary_config manager if gt_config.GT4PY_USE_HIP: dace.config.Config.set("compiler", "cuda", "backend", value="hip") dace.config.Config.set("compiler", "cuda", "max_concurrent_streams", value=-1) + dace.config.Config.set( + "compiler", "cuda", "default_block_size", value=gt_config.DACE_DEFAULT_BLOCK_SIZE + ) dace.config.Config.set("compiler", "cpu", "openmp_sections", value=False) code_objects = sdfg.generate_code() is_gpu = "CUDA" in {co.title for co in code_objects} diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index 421a37b271..4d716a6c79 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -312,6 +312,8 @@ def stencil_id(self) -> StencilID: "api_annotations": f"[{', '.join(self._extract_api_annotations())}]", **self._externals, } + if self.builder.backend.name == "dace:gpu": + fingerprint["default_block_size"] = gt_config.DACE_DEFAULT_BLOCK_SIZE # typeignore because attrclass StencilID has generated constructor return StencilID( # type: ignore diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index 23bf36de6c..ad031b80c2 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -14,7 +14,7 @@ import multiprocessing import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import gridtools_cpp @@ -49,6 +49,10 @@ GT_CPP_TEMPLATE_DEPTH: int = 1024 # Settings dict +GT4PY_EXTRA_COMPILE_ARGS: str = os.environ.get("GT4PY_EXTRA_COMPILE_ARGS", "") +extra_compile_args: List[str] = ( + list(GT4PY_EXTRA_COMPILE_ARGS.split(" ")) if GT4PY_EXTRA_COMPILE_ARGS else [] +) build_settings: Dict[str, Any] = { "boost_include_path": os.path.join(BOOST_ROOT, "include"), "cuda_bin_path": os.path.join(CUDA_ROOT, "bin"), @@ -57,10 +61,7 @@ "gt_include_path": os.environ.get("GT_INCLUDE_PATH", GT_INCLUDE_PATH), "openmp_cppflags": os.environ.get("OPENMP_CPPFLAGS", "-fopenmp").split(), "openmp_ldflags": os.environ.get("OPENMP_LDFLAGS", "-fopenmp").split(), - "extra_compile_args": { - "cxx": [], - "cuda": [], - }, + "extra_compile_args": {"cxx": extra_compile_args, "cuda": extra_compile_args}, "extra_link_args": [], "parallel_jobs": multiprocessing.cpu_count(), "cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH), @@ -83,3 +84,5 @@ code_settings: Dict[str, Any] = {"root_package_name": "_GT_"} os.environ.setdefault("DACE_CONFIG", os.path.join(os.path.abspath("."), ".dace.conf")) + +DACE_DEFAULT_BLOCK_SIZE: str = os.environ.get("DACE_DEFAULT_BLOCK_SIZE", "64,8,1")