From c2ef68b4bb5239fa55308c7142ec9107a16bb154 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 30 Aug 2024 08:12:54 +0200 Subject: [PATCH 1/5] enable scalar math operations in embedded --- src/gt4py/next/ffront/fbuiltins.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index cd0daffb49..fae8b91dd0 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -9,6 +9,7 @@ import dataclasses import functools import inspect +import math from builtins import bool, float, int, tuple from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast @@ -227,9 +228,11 @@ def astype( def _make_unary_math_builtin(name: str) -> None: def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) - # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: ERA001 [commented-out-code] - # return getattr(math, name)(value)# noqa: ERA001 [commented-out-code] - raise NotImplementedError() + assert core_defs.is_scalar_type( + value + ) # default implementation for scalars, Fields are handled via dispatch + + return getattr(math, name)(value) impl.__name__ = name globals()[name] = BuiltInFunction(impl) From 9169dc199e96f6b9992f032f7b7c0c00a1694bea Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 30 Aug 2024 08:18:50 +0200 Subject: [PATCH 2/5] fix comment --- src/gt4py/next/ffront/fbuiltins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index fae8b91dd0..df00ceb510 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -227,7 +227,7 @@ def astype( def _make_unary_math_builtin(name: str) -> None: def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: - # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) + # TODO(havogt): enable tests in `test_math_builtin_execution.py` assert core_defs.is_scalar_type( value ) # default implementation for scalars, Fields are handled via dispatch From ce53049fc93d314902540da1beb3b77d7a859c91 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Mon, 9 Sep 2024 09:47:06 +0200 Subject: [PATCH 3/5] Fix access to arc-functions without standard names --- src/gt4py/next/ffront/fbuiltins.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index df00ceb510..be8dc247ef 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -226,13 +226,18 @@ def astype( def _make_unary_math_builtin(name: str) -> None: + if name.startswith("arc") and not hasattr(math, name): + _math_builtin = getattr(math, f"a{name[3:]}") + else: + _math_builtin = getattr(math, name) + def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: # TODO(havogt): enable tests in `test_math_builtin_execution.py` assert core_defs.is_scalar_type( value ) # default implementation for scalars, Fields are handled via dispatch - return getattr(math, name)(value) + return _math_builtin(value) impl.__name__ = name globals()[name] = BuiltInFunction(impl) From 84043f3e67a09691dc171f0aec96b38eb42e8474 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Mon, 9 Sep 2024 12:45:36 +0200 Subject: [PATCH 4/5] Fix bug --- src/gt4py/next/ffront/fbuiltins.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index be8dc247ef..c1282eb167 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -226,10 +226,14 @@ def astype( def _make_unary_math_builtin(name: str) -> None: - if name.startswith("arc") and not hasattr(math, name): + if hasattr(math, name): + _math_builtin = getattr(math, name) + elif name.startswith("arc"): _math_builtin = getattr(math, f"a{name[3:]}") + elif name in __builtins__: + _math_builtin = __builtins__[name] else: - _math_builtin = getattr(math, name) + raise AssertionError(f"Invalid find builtin '{name}'.") def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: # TODO(havogt): enable tests in `test_math_builtin_execution.py` From 7c68a096f7cc981ac539a4ca28c300f0390ec4e8 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Mon, 9 Sep 2024 13:36:29 +0200 Subject: [PATCH 5/5] Adding custom implementations for builtins --- src/gt4py/next/ffront/fbuiltins.py | 74 ++++++++++++++++-------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index c1282eb167..f5381b3c72 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -11,7 +11,7 @@ import inspect import math from builtins import bool, float, int, tuple -from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast +from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np from numpy import float32, float64, int32, int64 @@ -197,43 +197,47 @@ def astype( return core_defs.dtype(type_).scalar_type(value) -UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"] - -UNARY_MATH_FP_BUILTIN_NAMES = [ - "sin", - "cos", - "tan", - "arcsin", - "arccos", - "arctan", - "sinh", - "cosh", - "tanh", - "arcsinh", - "arccosh", - "arctanh", - "sqrt", - "exp", - "log", - "gamma", - "cbrt", - "floor", - "ceil", - "trunc", -] - -UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES = ["isfinite", "isinf", "isnan"] +_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs} +UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()] + +_UNARY_MATH_FP_BUILTIN_IMPL: Final = { + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "arcsin": math.asin, + "arccos": math.acos, + "arctan": math.atan, + "sinh": math.sinh, + "cosh": math.cosh, + "tanh": math.tanh, + "arcsinh": math.asinh, + "arccosh": math.acosh, + "arctanh": math.atanh, + "sqrt": math.sqrt, + "exp": math.exp, + "log": math.log, + "gamma": math.gamma, + "cbrt": math.cbrt if hasattr(math, "cbrt") else np.cbrt, # match.cbrt() only added in 3.11 + "floor": math.floor, + "ceil": math.ceil, + "trunc": math.trunc, +} +UNARY_MATH_FP_BUILTIN_NAMES: Final = [*_UNARY_MATH_FP_BUILTIN_IMPL.keys()] + +_UNARY_MATH_FP_PREDICATE_BUILTIN_IMPL: Final = { + "isfinite": math.isfinite, + "isinf": math.isinf, + "isnan": math.isnan, +} +UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES: Final = [*_UNARY_MATH_FP_PREDICATE_BUILTIN_IMPL.keys()] def _make_unary_math_builtin(name: str) -> None: - if hasattr(math, name): - _math_builtin = getattr(math, name) - elif name.startswith("arc"): - _math_builtin = getattr(math, f"a{name[3:]}") - elif name in __builtins__: - _math_builtin = __builtins__[name] - else: - raise AssertionError(f"Invalid find builtin '{name}'.") + _math_builtin = ( + _UNARY_MATH_NUMBER_BUILTIN_IMPL + | _UNARY_MATH_FP_BUILTIN_IMPL + | _UNARY_MATH_FP_PREDICATE_BUILTIN_IMPL + )[name] def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: # TODO(havogt): enable tests in `test_math_builtin_execution.py`