diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index fcb71e4e6d..37df68b022 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Callable import numpy as np @@ -79,12 +79,25 @@ } -def format_builtin(bultin: str, *args: Any) -> str: - if bultin in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[bultin] +def builtin_if(*args: Any) -> str: + cond, true_val, false_val = args + return f"{true_val} if {cond} else {false_val}" + + +GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { + "if_": builtin_if, +} + + +def format_builtin(builtin: str, *args: Any) -> str: + if builtin in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin] + return fmt.format(*args) + elif builtin in GENERAL_BUILTIN_MAPPING: + expr_func = GENERAL_BUILTIN_MAPPING[builtin] + return expr_func(*args) else: - raise NotImplementedError(f"'{bultin}' not implemented.") - return fmt.format(*args) + raise NotImplementedError(f"'{builtin}' not implemented.") class PythonCodegen(codegen.TemplatedGenerator): @@ -103,12 +116,6 @@ def _visit_deref(self, node: gtir.FunCall) -> str: return self.visit(node.args[0]) raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - def _visit_numeric_builtin(self, node: gtir.FunCall) -> str: - assert isinstance(node.fun, gtir.SymRef) - fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - return fmt.format(*args) - def visit_FunCall(self, node: gtir.FunCall) -> str: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index a4d04511fa..9fadb6ada4 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -1312,3 +1312,38 @@ def test_gtir_let_lambda_with_cond(): b = np.empty_like(a) sdfg(pred=np.bool_(s), x=a, y=b, **FSYMBOLS) assert np.allclose(b, a if s else a * 2) + + +def test_gtir_if_values(): + domain = im.call("cartesian_domain")( + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") + ) + testee = gtir.Program( + id="if_values", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.op_as_fieldop("if_", domain)( + im.op_as_fieldop("less", domain)("x", "y"), "x", "y" + ), + domain=domain, + target=gtir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + c = np.empty_like(a) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(x=a, y=b, z=c, **FSYMBOLS) + assert np.allclose(c, np.where(a < b, a, b))