Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into fix_reverse_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Feb 26, 2024
2 parents ddcc272 + d9004cd commit 5c81c03
Show file tree
Hide file tree
Showing 22 changed files with 379 additions and 134 deletions.
5 changes: 5 additions & 0 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,14 @@ def keep_line(line):
def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.SDFG):
self = cls()
with dace.config.temporary_config():
# To prevent conflicts with 3rd-party usage of the DaCe config, always make sure that any
# changes are made under the temporary_config manager
if gt_config.GT4PY_USE_HIP:
dace.config.Config.set("compiler", "cuda", "backend", value="hip")
dace.config.Config.set("compiler", "cuda", "max_concurrent_streams", value=-1)
dace.config.Config.set(
"compiler", "cuda", "default_block_size", value=gt_config.DACE_DEFAULT_BLOCK_SIZE
)
dace.config.Config.set("compiler", "cpu", "openmp_sections", value=False)
code_objects = sdfg.generate_code()
is_gpu = "CUDA" in {co.title for co in code_objects}
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/cartesian/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def stencil_id(self) -> StencilID:
"api_annotations": f"[{', '.join(self._extract_api_annotations())}]",
**self._externals,
}
if self.builder.backend.name == "dace:gpu":
fingerprint["default_block_size"] = gt_config.DACE_DEFAULT_BLOCK_SIZE

# typeignore because attrclass StencilID has generated constructor
return StencilID( # type: ignore
Expand Down
13 changes: 8 additions & 5 deletions src/gt4py/cartesian/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import multiprocessing
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import gridtools_cpp

Expand Down Expand Up @@ -49,6 +49,10 @@
GT_CPP_TEMPLATE_DEPTH: int = 1024

# Settings dict
GT4PY_EXTRA_COMPILE_ARGS: str = os.environ.get("GT4PY_EXTRA_COMPILE_ARGS", "")
extra_compile_args: List[str] = (
list(GT4PY_EXTRA_COMPILE_ARGS.split(" ")) if GT4PY_EXTRA_COMPILE_ARGS else []
)
build_settings: Dict[str, Any] = {
"boost_include_path": os.path.join(BOOST_ROOT, "include"),
"cuda_bin_path": os.path.join(CUDA_ROOT, "bin"),
Expand All @@ -57,10 +61,7 @@
"gt_include_path": os.environ.get("GT_INCLUDE_PATH", GT_INCLUDE_PATH),
"openmp_cppflags": os.environ.get("OPENMP_CPPFLAGS", "-fopenmp").split(),
"openmp_ldflags": os.environ.get("OPENMP_LDFLAGS", "-fopenmp").split(),
"extra_compile_args": {
"cxx": [],
"cuda": [],
},
"extra_compile_args": {"cxx": extra_compile_args, "cuda": extra_compile_args},
"extra_link_args": [],
"parallel_jobs": multiprocessing.cpu_count(),
"cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH),
Expand All @@ -83,3 +84,5 @@
code_settings: Dict[str, Any] = {"root_package_name": "_GT_"}

os.environ.setdefault("DACE_CONFIG", os.path.join(os.path.abspath("."), ".dace.conf"))

DACE_DEFAULT_BLOCK_SIZE: str = os.environ.get("DACE_DEFAULT_BLOCK_SIZE", "64,8,1")
12 changes: 11 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class Dimension:
def __str__(self):
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
    """Pair this dimension with ``val``, returning the tuple ``(self, val)``."""
    return (self, val)


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -272,7 +275,10 @@ def unit_range(r: RangeLike) -> UnitRange:
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange
NamedSlice: TypeAlias = (
slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
)
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex]
RelativeIndexSequence: TypeAlias = tuple[
Expand Down Expand Up @@ -307,6 +313,10 @@ def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
)


def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedSlice]:
    """Tell whether ``obj`` is a slice whose two bounds are both named indices.

    Args:
        obj: Any index specification.

    Returns:
        ``True`` if ``obj`` is a ``slice`` and both ``obj.start`` and ``obj.stop``
        satisfy ``is_named_index``; ``False`` otherwise. ``obj.step`` is not inspected.
    """
    # The guard narrows to `NamedSlice` (a `slice`), not to `NamedRange` (a tuple):
    # a positive answer means callers may access `.start` / `.stop` on `obj`.
    return isinstance(obj, slice) and is_named_index(obj.start) and is_named_index(obj.stop)


def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:
return (
is_int_index(v)
Expand Down
59 changes: 38 additions & 21 deletions src/gt4py/next/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,53 @@ class BuildCacheLifetime(enum.Enum):


class CMakeBuildType(enum.Enum):
    """
    CMake build types enum.

    Member values have to be valid CMake syntax.
    """

    # Values are spelled out explicitly (instead of being derived from the member
    # names) so that the exact string used as the CMake build type is visible here.
    DEBUG = "Debug"
    RELEASE = "Release"
    REL_WITH_DEB_INFO = "RelWithDebInfo"
    MIN_SIZE_REL = "MinSizeRel"

def env_flag_to_bool(flag_value: str) -> bool:
"""Like in gt4py.cartesian, env vars for flags should be set to '0' or '1'."""

def env_flag_to_bool(name: str, default: bool) -> bool:
"""Recognize true or false signaling string values."""
flag_value = None
if name in os.environ:
flag_value = os.environ[name].lower()
match flag_value:
case "0" | "1":
return bool(int(flag_value))
case None:
return default
case "0" | "false" | "off":
return False
case "1" | "true" | "on":
return True
case _:
raise ValueError("GT4Py flag environment variables must have value '0' or '1'.")
raise ValueError(
"Invalid GT4Py environment flag value: use '0 | false | off' or '1 | true | on'."
)


_PREFIX: Final[str] = "GT4PY"

#: Master debug flag
#: Changes defaults for all the other options to be as helpful for debugging as possible.
#: Does not override values set in environment variables.
DEBUG: Final[bool] = env_flag_to_bool(os.environ.get(f"{_PREFIX}_DEBUG", "0"))
DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False)


#: Verbose flag for DSL compilation errors
VERBOSE_EXCEPTIONS: bool = env_flag_to_bool(
f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False
)


#: Where generated code projects should be persisted.
#: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT
BUILD_CACHE_DIR: Final[pathlib.Path] = (
BUILD_CACHE_DIR: pathlib.Path = (
pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir()))
/ "gt4py_cache"
)
Expand All @@ -62,14 +81,12 @@ def env_flag_to_bool(flag_value: str) -> bool:
#: Whether generated code projects should be kept around between runs.
#: - SESSION: generated code projects get destroyed when the interpreter shuts down
#: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs
BUILD_CACHE_LIFETIME: Final[BuildCacheLifetime] = getattr(
BuildCacheLifetime,
os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper(),
)
BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[
os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper()
]

#: Build type to be used when CMake is used to compile generated code.
#: Might have no effect when CMake is not used as part of the toolchain.
CMAKE_BUILD_TYPE: Final[CMakeBuildType] = getattr(
CMakeBuildType,
os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper(),
)
CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[
os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper()
]
29 changes: 29 additions & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,32 @@ def _find_index_of_dim(
if dim == d:
return i
return None


def canonicalize_any_index_sequence(
    index: common.AnyIndexSpec,
) -> common.AnyIndexSpec:
    """Rewrite an index made up only of slices via `_named_slice_to_named_range`.

    A single slice is first wrapped into a one-element tuple. If the resulting
    index is a tuple in which every entry is a slice, each entry is passed through
    `_named_slice_to_named_range` and the converted tuple is returned. Any other
    index is returned unchanged.
    """
    # TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice`
    candidate: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index
    if not isinstance(candidate, tuple):
        return candidate
    if not all(isinstance(entry, slice) for entry in candidate):
        return candidate
    # Every entry is a slice at this point, as established by the guard above.
    return tuple(_named_slice_to_named_range(entry) for entry in candidate)  # type: ignore[arg-type, return-value]


def _named_slice_to_named_range(
    idx: common.NamedSlice,
) -> common.NamedRange | common.NamedSlice:
    """Convert a slice bounded by two named indices into a `(dimension, UnitRange)` pair.

    Slices for which `common.is_named_slice` is false are returned unchanged, unless
    exactly one bound is a named index while the other bound is `None`.

    Raises:
        IndexError: If start and stop refer to different dimensions, or if one bound
            is a named index and the other one is missing.
    """
    assert hasattr(idx, "start") and hasattr(idx, "stop")
    lower, upper = idx.start, idx.stop

    if not common.is_named_slice(idx):
        # Half-open named slices are rejected; everything else is passed through.
        if common.is_named_index(lower) and upper is None:
            raise IndexError(f"Upper bound needs to be specified for {idx}.")
        if common.is_named_index(upper) and lower is None:
            raise IndexError(f"Lower bound needs to be specified for {idx}.")
        return idx

    start_dim, start_value = lower[0], lower[1]
    stop_dim, stop_value = upper[0], upper[1]
    if start_dim != stop_dim:
        raise IndexError(
            f"Dimensions slicing mismatch between '{start_dim.value}' and '{stop_dim.value}'."
        )
    assert isinstance(start_value, int) and isinstance(stop_value, int)
    return (start_dim, common.UnitRange(start_value, stop_value))
1 change: 1 addition & 0 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def __invert__(self) -> NdArrayField:
def _slice(
self, index: common.AnyIndexSpec
) -> tuple[common.Domain, common.RelativeIndexSequence]:
index = embedded_common.canonicalize_any_index_sequence(index)
new_domain = embedded_common.sub_domain(self.domain, index)

index_sequence = common.as_any_index_sequence(index)
Expand Down
2 changes: 0 additions & 2 deletions src/gt4py/next/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions.
excepthook,
)
from .excepthook import set_verbose_exceptions
from .exceptions import (
DSLError,
InvalidParameterAnnotationError,
Expand All @@ -37,5 +36,4 @@
"MissingArgumentError",
"UndefinedSymbolError",
"UnsupportedPythonFeatureError",
"set_verbose_exceptions",
]
33 changes: 3 additions & 30 deletions src/gt4py/next/errors/excepthook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,12 @@
other errors.
"""

import os
import sys
import warnings
from typing import Callable

from . import exceptions, formatting


def _get_verbose_exceptions_envvar() -> bool:
"""Detect if the user enabled verbose exceptions in the environment variables."""
env_var_name = "GT4PY_VERBOSE_EXCEPTIONS"
if env_var_name in os.environ:
false_values = ["0", "false", "off"]
true_values = ["1", "true", "on"]
value = os.environ[env_var_name].lower()
if value in false_values:
return False
elif value in true_values:
return True
else:
values = ", ".join([*false_values, *true_values])
msg = f"the 'GT4PY_VERBOSE_EXCEPTIONS' environment variable must be one of {values} (case insensitive)"
warnings.warn(msg)
return False

from gt4py.next import config

_verbose_exceptions: bool = _get_verbose_exceptions_envvar()


def set_verbose_exceptions(enabled: bool = False) -> None:
"""Programmatically set whether to use verbose printing for uncaught errors."""
global _verbose_exceptions
_verbose_exceptions = enabled
from . import exceptions, formatting


def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) -> list[str]:
Expand All @@ -77,7 +50,7 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException
also printed.
"""
if isinstance(value, exceptions.DSLError):
exc_strs = _format_uncaught_error(value, _verbose_exceptions)
exc_strs = _format_uncaught_error(value, config.VERBOSE_EXCEPTIONS)
print("".join(exc_strs), file=sys.stderr)
else:
fallback(type_, value, tb)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __call__(self, *args):
def make_node(o):
if isinstance(o, Node):
return o
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value)
if callable(o):
if o.__name__ == "<lambda>":
return lambdadef(o)
Expand All @@ -156,8 +158,6 @@ def make_node(o):
return OffsetLiteral(value=o.value)
if isinstance(o, core_defs.Scalar):
return im.literal_from_value(o)
if isinstance(o, common.Dimension):
return AxisLiteral(value=o.value)
if isinstance(o, tuple):
return _f("make_tuple", *(make_node(arg) for arg in o))
if o is None:
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
@enum.unique
class LiftMode(enum.Enum):
FORCE_INLINE = enum.auto()
FORCE_TEMPORARIES = enum.auto()
USE_TEMPORARIES = enum.auto()
SIMPLE_HEURISTIC = enum.auto()


Expand All @@ -47,7 +47,7 @@ def _inline_lifts(ir, lift_mode):
return InlineLifts().visit(ir)
elif lift_mode == LiftMode.SIMPLE_HEURISTIC:
return InlineLifts(simple_inline_heuristic.is_eligible_for_inlining).visit(ir)
elif lift_mode == LiftMode.FORCE_TEMPORARIES:
elif lift_mode == LiftMode.USE_TEMPORARIES:
return InlineLifts(
flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT
| InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet.
Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/otf/compilation/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
_session_cache_dir = tempfile.TemporaryDirectory(prefix="gt4py_session_")

_session_cache_dir_path = pathlib.Path(_session_cache_dir.name)
_persistent_cache_dir_path = config.BUILD_CACHE_DIR


def _serialize_param(parameter: interface.Parameter) -> str:
Expand Down Expand Up @@ -72,7 +71,7 @@ def get_cache_folder(
case config.BuildCacheLifetime.SESSION:
base_path = _session_cache_dir_path
case config.BuildCacheLifetime.PERSISTENT:
base_path = _persistent_cache_dir_path
base_path = config.BUILD_CACHE_DIR
case _:
raise ValueError("Unsupported caching lifetime.")

Expand Down
Loading

0 comments on commit 5c81c03

Please sign in to comment.