From aeff778f418f83420395817f13b7ee32c77e2263 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sun, 30 Jun 2024 10:08:40 +0000 Subject: [PATCH 01/29] nothing --- .../ffront_tests/test_foast_to_gtir.py | 640 ++++++++++++++++++ 1 file changed, 640 insertions(+) create mode 100644 tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py new file mode 100644 index 0000000000..d45b23c68c --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -0,0 +1,640 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the +# lowering can (and often does) make all of them fail. Once we have embedded field view we want to +# switch to executing the different cases here; once with a regular backend (i.e. including +# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering +# should be correct. + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import gt4py.next as gtx +from gt4py.next import float32, float64, int32, int64, neighbor_sum +from gt4py.next.ffront import type_specifications as ts_ffront +from gt4py.next.ffront.ast_passes import single_static_assign as ssa +from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts, type_translation +from gt4py.next.iterator.type_system import type_specifications as it_ts + + +IDim = gtx.Dimension("IDim") +Edge = gtx.Dimension("Edge") +Vertex = gtx.Dimension("Vertex") +Cell = gtx.Dimension("Cell") +V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. + + +def debug_itir(tree): + """Compare tree snippets while debugging.""" + from devtools import debug + + from gt4py.eve.codegen import format_python_source + from gt4py.next.program_processors.runners.roundtrip import EmbeddedDSL + + debug(format_python_source(EmbeddedDSL.apply(tree))) + + +from gt4py.eve import PreserveLocationVisitor, NodeTranslator +from gt4py.eve import utils +import dataclasses +from gt4py.next.ffront import field_operator_ast as foast +from typing import Optional, Any + + +@dataclasses.dataclass +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): + """ + Lower FieldOperator AST (FOAST) to GTIR. + """ + + uid_generator: utils.UIDGenerator = dataclasses.field(default_factory=utils.UIDGenerator) + + @classmethod + def apply(cls, node: foast.LocatedNode) -> itir.Expr: + return cls().visit(node) + + def visit_FunctionDefinition( + self, node: foast.FunctionDefinition, **kwargs + ) -> itir.FunctionDefinition: + params = self.visit(node.params) + return itir.FunctionDefinition( + id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) + ) + + def visit_BlockStmt( + self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any + ) -> itir.Expr: + for stmt in reversed(node.stmts): + inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) + assert inner_expr + assert isinstance(inner_expr, itir.Node) + return inner_expr + + def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: + return im.sym(node.id) + + def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + return im.ref(node.id) + + def visit_Return( + self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any + ) -> itir.Expr: + return self.visit(node.value, **kwargs) + + def visit_Node(self, node: foast.Node, **kwargs: Any): + raise NotImplementedError( + f"Translation of '{node}' of type '{type(node)}' not implemented." + ) + + +def test_copy(): + def copy_field(inp: gtx.Field[[TDim], float64]): + return inp + + parsed = FieldOperatorParser.apply_to_function(copy_field) + lowered = FieldOperatorLowering.apply(parsed) + print(lowered) + + assert lowered.id == "copy_field" + assert lowered.expr == im.ref("inp") + + +# def test_scalar_arg(): +# def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: +# return alpha * bar + +# parsed = FieldOperatorParser.apply_to_function(scalar_arg) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("multiplies")( +# "alpha", "bar" +# ) # no difference to non-scalar arg + +# assert lowered.expr == reference + + +# def test_multicopy(): +# def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): +# return inp1, inp2 + +# parsed = FieldOperatorParser.apply_to_function(multicopy) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.make_tuple("inp1", "inp2") + +# assert lowered.expr == reference + + +# def test_arithmetic(): +# def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): +# return inp1 + inp2 + +# parsed = FieldOperatorParser.apply_to_function(arithmetic) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") + +# assert lowered.expr == reference + + +# def test_shift(): +# Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) + +# def shift_by_one(inp: gtx.Field[[IDim], float64]): +# return inp(Ioff[1]) + +# parsed = FieldOperatorParser.apply_to_function(shift_by_one) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") + +# assert lowered.expr == reference + + +# def test_negative_shift(): +# Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) + +# def shift_by_one(inp: gtx.Field[[IDim], float64]): +# return inp(Ioff[-1]) + +# parsed = FieldOperatorParser.apply_to_function(shift_by_one) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") + +# assert lowered.expr == reference + + +# def test_temp_assignment(): +# def copy_field(inp: gtx.Field[[TDim], float64]): +# tmp = inp +# inp = tmp +# tmp2 = inp +# return tmp2 + +# parsed = FieldOperatorParser.apply_to_function(copy_field) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.let(ssa.unique_name("tmp", 0), "inp")( +# im.let( +# ssa.unique_name("inp", 0), +# ssa.unique_name("tmp", 0), +# )( +# im.let( +# ssa.unique_name("tmp2", 0), +# ssa.unique_name("inp", 0), +# )(ssa.unique_name("tmp2", 0)) +# ) +# ) + +# assert lowered.expr == reference + + +# def test_unary_ops(): +# def unary(inp: gtx.Field[[TDim], float64]): +# tmp = +inp +# tmp = -tmp +# return tmp + +# parsed = FieldOperatorParser.apply_to_function(unary) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.let( +# ssa.unique_name("tmp", 0), +# im.promote_to_lifted_stencil("plus")( +# im.promote_to_const_iterator(im.literal("0", "float64")), "inp" +# ), +# )( +# im.let( +# ssa.unique_name("tmp", 1), +# im.promote_to_lifted_stencil("minus")( +# im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) +# ), +# )(ssa.unique_name("tmp", 1)) +# ) + +# assert lowered.expr == reference + + +# def test_unpacking(): +# """Unpacking assigns should get separated.""" + +# def unpacking( +# inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] +# ) -> gtx.Field[[TDim], float64]: +# tmp1, tmp2 = inp1, inp2 # noqa +# return tmp1 + +# parsed = FieldOperatorParser.apply_to_function(unpacking) +# lowered = FieldOperatorLowering.apply(parsed) + +# tuple_expr = im.make_tuple("inp1", "inp2") +# tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") +# tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") + +# reference = im.let("__tuple_tmp_0", tuple_expr)( +# im.let( +# ssa.unique_name("tmp1", 0), +# tuple_access_0, +# )( +# im.let( +# ssa.unique_name("tmp2", 0), +# tuple_access_1, +# )(ssa.unique_name("tmp1", 0)) +# ) +# ) + +# assert lowered.expr == reference + + +# def test_annotated_assignment(): +# pytest.xfail("Annotated assignments are not properly supported at the moment.") + +# def copy_field(inp: gtx.Field[[TDim], float64]): +# tmp: gtx.Field[[TDim], float64] = inp +# return tmp + +# parsed = FieldOperatorParser.apply_to_function(copy_field) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) + +# assert lowered.expr == reference + + +# def test_call(): +# # create something that appears to the lowering like a field operator. +# # we could also create an actual field operator, but we want to avoid +# # using such heavy constructs for testing the lowering. +# field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) +# identity = SimpleNamespace( +# __gt_type__=lambda: ts_ffront.FieldOperatorType( +# definition=ts.FunctionType( +# pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type +# ) +# ) +# ) + +# def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: +# return identity(inp) + +# parsed = FieldOperatorParser.apply_to_function(call) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.call("identity")("inp") + +# assert lowered.expr == reference + + +# def test_temp_tuple(): +# """Returning a temp tuple should work.""" + +# def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): +# tmp = a, b +# return tmp + +# parsed = FieldOperatorParser.apply_to_function(temp_tuple) +# lowered = FieldOperatorLowering.apply(parsed) + +# tuple_expr = im.make_tuple("a", "b") +# reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) + +# assert lowered.expr == reference + + +# def test_unary_not(): +# def unary_not(cond: gtx.Field[[TDim], "bool"]): +# return not cond + +# parsed = FieldOperatorParser.apply_to_function(unary_not) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("not_")("cond") + +# assert lowered.expr == reference + + +# def test_binary_plus(): +# def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): +# return a + b + +# parsed = FieldOperatorParser.apply_to_function(plus) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("plus")("a", "b") + +# assert lowered.expr == reference + + +# def test_add_scalar_literal_to_field(): +# def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: +# return 2.0 + a + +# parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("plus")( +# im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" +# ) + +# assert lowered.expr == reference + + +# def test_add_scalar_literals(): +# def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: +# tmp = int32(1) + int32("1") +# return a + tmp + +# parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.let( +# ssa.unique_name("tmp", 0), +# im.promote_to_lifted_stencil("plus")( +# im.promote_to_const_iterator(im.literal("1", "int32")), +# im.promote_to_const_iterator(im.literal("1", "int32")), +# ), +# )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) + +# assert lowered.expr == reference + + +# def test_binary_mult(): +# def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): +# return a * b + +# parsed = FieldOperatorParser.apply_to_function(mult) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("multiplies")("a", "b") + +# assert lowered.expr == reference + + +# def test_binary_minus(): +# def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): +# return a - b + +# parsed = FieldOperatorParser.apply_to_function(minus) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("minus")("a", "b") + +# assert lowered.expr == reference + + +# def test_binary_div(): +# def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): +# return a / b + +# parsed = FieldOperatorParser.apply_to_function(division) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("divides")("a", "b") + +# assert lowered.expr == reference + + +# def test_binary_and(): +# def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): +# return a & b + +# parsed = FieldOperatorParser.apply_to_function(bit_and) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("and_")("a", "b") + +# assert lowered.expr == reference + + +# def test_scalar_and(): +# def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: +# return a & False + +# parsed = FieldOperatorParser.apply_to_function(scalar_and) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("and_")( +# "a", im.promote_to_const_iterator(im.literal("False", "bool")) +# ) + +# assert lowered.expr == reference + + +# def test_binary_or(): +# def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): +# return a | b + +# parsed = FieldOperatorParser.apply_to_function(bit_or) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("or_")("a", "b") + +# assert lowered.expr == reference + + +# def test_compare_scalars(): +# def comp_scalars() -> bool: +# return 3 > 4 + +# parsed = FieldOperatorParser.apply_to_function(comp_scalars) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("greater")( +# im.promote_to_const_iterator(im.literal("3", "int32")), +# im.promote_to_const_iterator(im.literal("4", "int32")), +# ) + +# assert lowered.expr == reference + + +# def test_compare_gt(): +# def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): +# return a > b + +# parsed = FieldOperatorParser.apply_to_function(comp_gt) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("greater")("a", "b") + +# assert lowered.expr == reference + + +# def test_compare_lt(): +# def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): +# return a < b + +# parsed = FieldOperatorParser.apply_to_function(comp_lt) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("less")("a", "b") + +# assert lowered.expr == reference + + +# def test_compare_eq(): +# def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): +# return a == b + +# parsed = FieldOperatorParser.apply_to_function(comp_eq) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("eq")("a", "b") + +# assert lowered.expr == reference + + +# def test_compare_chain(): +# def compare_chain( +# a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] +# ) -> gtx.Field[[IDim], bool]: +# return a > b > c + +# parsed = FieldOperatorParser.apply_to_function(compare_chain) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil("and_")( +# im.promote_to_lifted_stencil("greater")("a", "b"), +# im.promote_to_lifted_stencil("greater")("b", "c"), +# ) + +# assert lowered.expr == reference + + +# def test_reduction_lowering_simple(): +# def reduction(edge_f: gtx.Field[[Edge], float64]): +# return neighbor_sum(edge_f(V2E), axis=V2EDim) + +# parsed = FieldOperatorParser.apply_to_function(reduction) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.promote_to_lifted_stencil( +# im.call( +# im.call("reduce")( +# "plus", +# im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), +# ) +# ) +# )(im.lifted_neighbors("V2E", "edge_f")) + +# assert lowered.expr == reference + + +# def test_reduction_lowering_expr(): +# def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): +# e1_nbh = e1(V2E) +# return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) + +# parsed = FieldOperatorParser.apply_to_function(reduction) +# lowered = FieldOperatorLowering.apply(parsed) + +# mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( +# im.promote_to_lifted_stencil("make_const_list")( +# im.promote_to_const_iterator(im.literal("1.1", "float64")) +# ), +# im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), +# ) + +# reference = im.let( +# ssa.unique_name("e1_nbh", 0), +# im.lifted_neighbors("V2E", "e1"), +# )( +# im.promote_to_lifted_stencil( +# im.call( +# im.call("reduce")( +# "plus", +# im.deref( +# im.promote_to_const_iterator(im.literal(value="0", typename="float64")) +# ), +# ) +# ) +# )(mapped) +# ) + +# assert lowered.expr == reference + + +# def test_builtin_int_constructors(): +# def int_constrs() -> tuple[int32, int32, int64, int32, int64]: +# return 1, int32(1), int64(1), int32("1"), int64("1") + +# parsed = FieldOperatorParser.apply_to_function(int_constrs) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.make_tuple( +# im.promote_to_const_iterator(im.literal("1", "int32")), +# im.promote_to_const_iterator(im.literal("1", "int32")), +# im.promote_to_const_iterator(im.literal("1", "int64")), +# im.promote_to_const_iterator(im.literal("1", "int32")), +# im.promote_to_const_iterator(im.literal("1", "int64")), +# ) + +# assert lowered.expr == reference + + +# def test_builtin_float_constructors(): +# def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: +# return ( +# 0.1, +# float(0.1), +# float32(0.1), +# float64(0.1), +# float(".1"), +# float32(".1"), +# float64(".1"), +# ) + +# parsed = FieldOperatorParser.apply_to_function(float_constrs) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.make_tuple( +# im.promote_to_const_iterator(im.literal("0.1", "float64")), +# im.promote_to_const_iterator(im.literal("0.1", "float64")), +# im.promote_to_const_iterator(im.literal("0.1", "float32")), +# im.promote_to_const_iterator(im.literal("0.1", "float64")), +# im.promote_to_const_iterator(im.literal(".1", "float64")), +# im.promote_to_const_iterator(im.literal(".1", "float32")), +# im.promote_to_const_iterator(im.literal(".1", "float64")), +# ) + +# assert lowered.expr == reference + + +# def test_builtin_bool_constructors(): +# def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: +# return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") + +# parsed = FieldOperatorParser.apply_to_function(bool_constrs) +# lowered = FieldOperatorLowering.apply(parsed) + +# reference = im.make_tuple( +# im.promote_to_const_iterator(im.literal(str(True), "bool")), +# im.promote_to_const_iterator(im.literal(str(False), "bool")), +# im.promote_to_const_iterator(im.literal(str(True), "bool")), +# im.promote_to_const_iterator(im.literal(str(False), "bool")), +# im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), +# im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), +# im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), +# im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), +# ) + +# assert lowered.expr == reference From 69195965b5c6a795474b999337a9c8f9a595f67b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sun, 30 Jun 2024 14:32:58 +0000 Subject: [PATCH 02/29] basic opbinary --- src/gt4py/next/ffront/foast_to_gtir.py | 434 ++++++++++++++++++ src/gt4py/next/iterator/ir_utils/ir_makers.py | 38 +- .../ffront_tests/test_foast_to_gtir.py | 74 +-- 3 files changed, 482 insertions(+), 64 deletions(-) create mode 100644 src/gt4py/next/ffront/foast_to_gtir.py diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py new file mode 100644 index 0000000000..fd4bb7d370 --- /dev/null +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -0,0 +1,434 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from typing import Any, Callable, Optional + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.eve.extended_typing import Never +from gt4py.eve.utils import UIDGenerator +from gt4py.next.ffront import ( + dialect_ast_enums, + fbuiltins, + field_operator_ast as foast, + lowering_utils, + stages as ffront_stages, + type_specifications as ts_ffront, +) +from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES +from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES +from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_info, type_specifications as ts + + +# def foast_to_itir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: +# return FieldOperatorLowering.apply(inp.foast_node) + + +# def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: +# if not type_info.contains_local_field(node.type): +# return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) +# return lambda x: x + + +@dataclasses.dataclass +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): + """ + Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). + + The strategy is to lower every expression to lifted stencils, + i.e. taking iterators and returning iterator. + + Examples + -------- + >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser + >>> from gt4py.next import Field, Dimension, float64 + >>> + >>> IDim = Dimension("IDim") + >>> def fieldop(inp: Field[[IDim], "float64"]): + ... return inp + >>> + >>> parsed = FieldOperatorParser.apply_to_function(fieldop) + >>> lowered = FieldOperatorLowering.apply(parsed) + >>> type(lowered) + + >>> lowered.id + SymbolName('fieldop') + >>> lowered.params # doctest: +ELLIPSIS + [Sym(id=SymbolName('inp'))] + """ + + uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) + + @classmethod + def apply(cls, node: foast.LocatedNode) -> itir.Expr: + return cls().visit(node) + + def visit_FunctionDefinition( + self, node: foast.FunctionDefinition, **kwargs: Any + ) -> itir.FunctionDefinition: + params = self.visit(node.params) + return itir.FunctionDefinition( + id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) + ) # `expr` is a lifted stencil + + # def visit_FieldOperator( + # self, node: foast.FieldOperator, **kwargs: Any + # ) -> itir.FunctionDefinition: + # func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + + # new_body = func_definition.expr + + # return itir.FunctionDefinition( + # id=func_definition.id, params=func_definition.params, expr=new_body + # ) + + # def visit_ScanOperator( + # self, node: foast.ScanOperator, **kwargs: Any + # ) -> itir.FunctionDefinition: + # # note: we don't need the axis here as this is handled by the program + # # decorator + # assert isinstance(node.type, ts_ffront.ScanOperatorType) + + # # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. + # # In iterator IR we didn't properly specify if this is legal, + # # however after lift-inlining the expressions are transformed back to literals. + # forward = im.deref(self.visit(node.forward, **kwargs)) + # init = lowering_utils.process_elements( + # im.deref, self.visit(node.init, **kwargs), node.init.type + # ) + + # # lower definition function + # func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + # new_body = im.let( + # func_definition.params[0].id, + # # promote carry to iterator of tuples + # # (this is the only place in the lowering were a variable is captured in a lifted lambda) + # lowering_utils.to_tuples_of_iterator( + # im.promote_to_const_iterator(func_definition.params[0].id), + # [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] + # ), + # )( + # # the function itself returns a tuple of iterators, deref element-wise + # lowering_utils.process_elements( + # im.deref, func_definition.expr, node.type.definition.returns + # ) + # ) + + # stencil_args: list[itir.Expr] = [] + # assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args + # for param, arg_type in zip( + # func_definition.params[1:], + # [*node.type.definition.pos_or_kw_args.values()][1:], + # strict=True, + # ): + # if isinstance(arg_type, ts.TupleType): + # # convert into iterator of tuples + # stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) + + # new_body = im.let( + # param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) + # )(new_body) + # else: + # stencil_args.append(im.ref(param.id)) + + # definition = itir.Lambda(params=func_definition.params, expr=new_body) + + # body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) + + # return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) + + def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: + raise AssertionError("Statements must always be visited in the context of a function.") + + def visit_Return( + self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any + ) -> itir.Expr: + return self.visit(node.value, **kwargs) + + def visit_BlockStmt( + self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any + ) -> itir.Expr: + for stmt in reversed(node.stmts): + inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) + assert inner_expr + return inner_expr + + # def visit_IfStmt( + # self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any + # ) -> itir.Expr: + # # the lowered if call doesn't need to be lifted as the condition can only originate + # # from a scalar value (and not a field) + # assert ( + # isinstance(node.condition.type, ts.ScalarType) + # and node.condition.type.kind == ts.ScalarKind.BOOL + # ) + + # cond = self.visit(node.condition, **kwargs) + + # return_kind: StmtReturnKind = deduce_stmt_return_kind(node) + + # common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols + + # if return_kind is StmtReturnKind.NO_RETURN: + # # pack the common symbols into a tuple + # common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) + + # # apply both branches and extract the common symbols through the prepared tuple + # true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) + # false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) + + # # unpack the common symbols' tuple for `inner_expr` + # for i, sym in enumerate(common_symbols.keys()): + # inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) + + # # here we assume neither branch returns + # return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( + # inner_expr + # ) + # elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: + # common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) + # common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) + + # # wrap the inner expression in a lambda function. note that this increases the + # # operation count if both branches are evaluated. + # inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") + # inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) + # inner_expr = im.call(inner_expr_name)(*common_symrefs) + + # true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) + # false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) + + # return im.let(inner_expr_name, inner_expr_evaluator)( + # im.if_(im.deref(cond), true_branch, false_branch) + # ) + + # assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN + + # # note that we do not duplicate `inner_expr` here since if both branches + # # return, `inner_expr` is ignored. + # true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) + # false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) + + # return im.if_(im.deref(cond), true_branch, false_branch) + + def visit_Assign( + self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any + ) -> itir.Expr: + return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( + inner_expr + ) + + def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: + return im.sym(node.id) + + def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + return im.ref(node.id) + + # def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: + # return im.tuple_get(node.index, self.visit(node.value, **kwargs)) + + # def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: + # return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + + # def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: + # # TODO(tehrengruber): extend iterator ir to support unary operators + # dtype = type_info.extract_dtype(node.type) + # if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: + # if dtype.kind != ts.ScalarKind.BOOL: + # raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") + # return self._map("not_", node.operand) + + # return self._map( + # node.op.value, + # foast.Constant(value="0", type=dtype, location=node.location), + # node.operand, + # ) + + def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: + return self._map(node.op.value, node.left, node.right) + + # def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: + # op = "if_" + # args = (node.condition, node.true_expr, node.false_expr) + # lowered_args: list[itir.Expr] = [ + # lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) + # for arg in args + # ] + # if any(type_info.contains_local_field(arg.type) for arg in args): + # lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] + # op = im.call("map_")(op) + + # return lowering_utils.to_tuples_of_iterator( + # im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type + # ) + + # def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: + # return self._map(node.op.value, node.left, node.right) + + # def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + # match node.args[0]: + # case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): + # shift_offset = im.shift(offset_name, offset_index) + # case foast.Name(id=offset_name): + # return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) + # case foast.Call(func=foast.Name(id="as_offset")): + # func_args = node.args[0] + # offset_dim = func_args.args[0] + # assert isinstance(offset_dim, foast.Name) + # shift_offset = im.shift( + # offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs)) + # ) + # case _: + # raise FieldOperatorLoweringError("Unexpected shift arguments!") + # return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( + # self.visit(node.func, **kwargs) + # ) + + # def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + # if type_info.type_class(node.func.type) is ts.FieldType: + # return self._visit_shift(node, **kwargs) + # elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: + # return self._visit_math_built_in(node, **kwargs) + # elif isinstance(node.func, foast.Name) and node.func.id in ( + # FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES + # ): + # visitor = getattr(self, f"_visit_{node.func.id}") + # return visitor(node, **kwargs) + # elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: + # return self._visit_type_constr(node, **kwargs) + # elif isinstance( + # node.func.type, + # (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), + # ): + # # ITIR has no support for keyword arguments. Instead, we concatenate both positional + # # and keyword arguments and use the unique order as given in the function signature. + # lowered_args, lowered_kwargs = type_info.canonicalize_arguments( + # node.func.type, + # self.visit(node.args, **kwargs), + # self.visit(node.kwargs, **kwargs), + # use_signature_ordering=True, + # ) + # result = im.call(self.visit(node.func, **kwargs))( + # *lowered_args, *lowered_kwargs.values() + # ) + + # # scan operators return an iterator of tuples, transform into tuples of iterator again + # if isinstance(node.func.type, ts_ffront.ScanOperatorType): + # result = lowering_utils.to_tuples_of_iterator( + # result, node.func.type.definition.returns + # ) + + # return result + + # raise AssertionError( + # f"Call to object of type '{type(node.func.type).__name__}' not understood." + # ) + + # def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + # assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) + # obj, new_type = node.args[0], node.args[1].id + # return lowering_utils.process_elements( + # lambda x: im.promote_to_lifted_stencil( + # im.lambda_("it")(im.call("cast_")("it", str(new_type))) + # )(x), + # self.visit(obj, **kwargs), + # obj.type, + # ) + + # def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + # condition, true_value, false_value = node.args + + # lowered_condition = self.visit(condition, **kwargs) + # return lowering_utils.process_elements( + # lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv), + # [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], + # node.type, + # ) + + # _visit_concat_where = _visit_where + + # def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + # return self.visit(node.args[0], **kwargs) + + # def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + # return self._map(self.visit(node.func, **kwargs), *node.args) + + # def _make_reduction_expr( + # self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any + # ) -> itir.Expr: + # # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) + # it = self.visit(node.args[0], **kwargs) + # assert isinstance(node.kwargs["axis"].type, ts.DimensionType) + # val = im.call(im.call("reduce")(op, im.deref(init_expr))) + # return im.promote_to_lifted_stencil(val)(it) + + # def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + # dtype = type_info.extract_dtype(node.type) + # return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) + + # def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + # dtype = type_info.extract_dtype(node.type) + # min_value, _ = type_info.arithmetic_bounds(dtype) + # init_expr = self._make_literal(str(min_value), dtype) + # return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) + + # def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + # dtype = type_info.extract_dtype(node.type) + # _, max_value = type_info.arithmetic_bounds(dtype) + # init_expr = self._make_literal(str(max_value), dtype) + # return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) + + # def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + # if isinstance(node.args[0], foast.Constant): + # node_kind = self.visit(node.type).kind.name.lower() + # target_type = fbuiltins.BUILTINS[node_kind] + # source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] + # if target_type is bool and source_type is not bool: + # return im.promote_to_const_iterator( + # im.literal(str(bool(source_type(node.args[0].value))), "bool") + # ) + # return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) + # raise FieldOperatorLoweringError( + # f"Encountered a type cast, which is not supported: {node}." + # ) + + # def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: + # # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; + # # the following constructs work if they are removed by inlining. + # if isinstance(type_, ts.TupleType): + # return im.make_tuple( + # *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) + # ) + # elif isinstance(type_, ts.ScalarType): + # typename = type_.kind.name.lower() + # return im.promote_to_const_iterator(im.literal(str(val), typename)) + # raise ValueError(f"Unsupported literal type '{type_}'.") + + # def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: + # return self._make_literal(node.value, node.type) + + def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + lowered_args = [self.visit(arg, **kwargs) for arg in args] + if any(type_info.contains_local_field(arg.type) for arg in args): + raise NotImplementedError("TODO: Local fields not supported") + # lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] + # op = im.call("map_")(op) + + return im.op_as_fieldop(im.call(op))(*lowered_args) + + +class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 40bfc0ab75..c2ded65fc1 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import typing -from typing import Callable, Iterable, Union +from typing import Callable, Iterable, Union, Optional from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -85,6 +85,8 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti return ref(literal_or_expr) elif core_defs.is_scalar_type(literal_or_expr): return literal_from_value(literal_or_expr) + elif literal_or_expr is None: + return itir.NoneLiteral() assert isinstance(literal_or_expr, itir.Expr) return literal_or_expr @@ -253,6 +255,13 @@ def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) +def as_fieldop(stencil, domain=None): + """Creates a field_operator from a stencil.""" + args = [stencil] + if domain is not None: + args.append(domain) + return call(call("as_fieldop")(*args)) + class let: """ @@ -397,6 +406,33 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def op_as_fieldop( + op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None +) -> Callable[..., itir.FunCall]: + """ + Promotes a function `op` to a field_operator. + + `op` is a function from values to value. + + Returns: + A function from Fields to Field. + + Examples + -------- + >>> str(op_as_fieldop("op")("a", "b")) + '(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)' + """ + if isinstance(op, (str, itir.SymRef, itir.Lambda)): + op = call(op) + + def _impl(*its: itir.Expr) -> itir.FunCall: + args = [ + f"__arg{i}" for i in range(len(its)) + ] # TODO: `op` must not contain `SymRef(id="__argX")` + return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its) + + return _impl + def map_(op): """Create a `map_` call.""" diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index d45b23c68c..4950a97b3e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -28,6 +28,7 @@ from gt4py.next.ffront import type_specifications as ts_ffront from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.ffront.foast_to_gtir import FieldOperatorLowering from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts, type_translation @@ -53,66 +54,14 @@ def debug_itir(tree): debug(format_python_source(EmbeddedDSL.apply(tree))) -from gt4py.eve import PreserveLocationVisitor, NodeTranslator -from gt4py.eve import utils -import dataclasses -from gt4py.next.ffront import field_operator_ast as foast -from typing import Optional, Any -@dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): - """ - Lower FieldOperator AST (FOAST) to GTIR. - """ - - uid_generator: utils.UIDGenerator = dataclasses.field(default_factory=utils.UIDGenerator) - - @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) - - def visit_FunctionDefinition( - self, node: foast.FunctionDefinition, **kwargs - ) -> itir.FunctionDefinition: - params = self.visit(node.params) - return itir.FunctionDefinition( - id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) - ) - - def visit_BlockStmt( - self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - for stmt in reversed(node.stmts): - inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) - assert inner_expr - assert isinstance(inner_expr, itir.Node) - return inner_expr - - def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - return im.sym(node.id) - - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: - return im.ref(node.id) - - def visit_Return( - self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return self.visit(node.value, **kwargs) - - def visit_Node(self, node: foast.Node, **kwargs: Any): - raise NotImplementedError( - f"Translation of '{node}' of type '{type(node)}' not implemented." - ) - - -def test_copy(): +def test_return(): def copy_field(inp: gtx.Field[[TDim], float64]): return inp parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) - print(lowered) assert lowered.id == "copy_field" assert lowered.expr == im.ref("inp") @@ -125,9 +74,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): # parsed = FieldOperatorParser.apply_to_function(scalar_arg) # lowered = FieldOperatorLowering.apply(parsed) -# reference = im.promote_to_lifted_stencil("multiplies")( -# "alpha", "bar" -# ) # no difference to non-scalar arg +# reference = im.op_as_fieldop("multiplies")("alpha", "bar") # no difference to non-scalar arg # assert lowered.expr == reference @@ -144,16 +91,17 @@ def copy_field(inp: gtx.Field[[TDim], float64]): # assert lowered.expr == reference -# def test_arithmetic(): -# def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): -# return inp1 + inp2 +def test_arithmetic(): + def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): + return inp1 + inp2 -# parsed = FieldOperatorParser.apply_to_function(arithmetic) -# lowered = FieldOperatorLowering.apply(parsed) + parsed = FieldOperatorParser.apply_to_function(arithmetic) + lowered = FieldOperatorLowering.apply(parsed) + print(lowered) -# reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") + reference = im.op_as_fieldop("plus")("inp1", "inp2") -# assert lowered.expr == reference + assert lowered.expr == reference # def test_shift(): From 87883a6a60821785d406e00de20ea3ad5b2f665c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 3 Jul 2024 16:28:57 +0200 Subject: [PATCH 03/29] all tests updated --- src/gt4py/next/ffront/foast_to_gtir.py | 238 +++---- src/gt4py/next/iterator/ir_utils/ir_makers.py | 16 +- .../ffront_tests/test_foast_to_gtir.py | 672 +++++++++--------- 3 files changed, 469 insertions(+), 457 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index fd4bb7d370..2177ae4369 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -22,13 +22,10 @@ dialect_ast_enums, fbuiltins, field_operator_ast as foast, - lowering_utils, - stages as ffront_stages, type_specifications as ts_ffront, ) from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES -from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind +from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts @@ -38,10 +35,10 @@ # return FieldOperatorLowering.apply(inp.foast_node) -# def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: -# if not type_info.contains_local_field(node.type): -# return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) -# return lambda x: x +def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: + if not type_info.contains_local_field(node.type): + return lambda x: im.op_as_fieldop("make_const_list")(x) + return lambda x: x @dataclasses.dataclass @@ -119,7 +116,7 @@ def visit_FunctionDefinition( # # (this is the only place in the lowering were a variable is captured in a lifted lambda) # lowering_utils.to_tuples_of_iterator( # im.promote_to_const_iterator(func_definition.params[0].id), - # [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] + # [*node.type.definition.pos_or_kw_args.values()][0], # [unnecessary-iterable-allocation-for-first-element] # ), # )( # # the function itself returns a tuple of iterators, deref element-wise @@ -238,19 +235,21 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: return im.ref(node.id) - # def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: - # return im.tuple_get(node.index, self.visit(node.value, **kwargs)) + def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: + return im.tuple_get(node.index, self.visit(node.value, **kwargs)) - # def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: - # return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: + return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - # def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: - # # TODO(tehrengruber): extend iterator ir to support unary operators - # dtype = type_info.extract_dtype(node.type) - # if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: - # if dtype.kind != ts.ScalarKind.BOOL: - # raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - # return self._map("not_", node.operand) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: + # TODO(tehrengruber): extend iterator ir to support unary operators + dtype = type_info.extract_dtype(node.type) + if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: + if dtype.kind != ts.ScalarKind.BOOL: + raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") + return self._map("not_", node.operand) + + raise NotImplementedError("TODO neg/pos") # return self._map( # node.op.value, @@ -276,67 +275,68 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: # im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type # ) - # def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - # return self._map(node.op.value, node.left, node.right) - - # def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - # match node.args[0]: - # case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): - # shift_offset = im.shift(offset_name, offset_index) - # case foast.Name(id=offset_name): - # return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # case foast.Call(func=foast.Name(id="as_offset")): - # func_args = node.args[0] - # offset_dim = func_args.args[0] - # assert isinstance(offset_dim, foast.Name) - # shift_offset = im.shift( - # offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs)) - # ) - # case _: - # raise FieldOperatorLoweringError("Unexpected shift arguments!") - # return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( - # self.visit(node.func, **kwargs) - # ) - - # def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - # if type_info.type_class(node.func.type) is ts.FieldType: - # return self._visit_shift(node, **kwargs) - # elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - # return self._visit_math_built_in(node, **kwargs) - # elif isinstance(node.func, foast.Name) and node.func.id in ( - # FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES - # ): - # visitor = getattr(self, f"_visit_{node.func.id}") - # return visitor(node, **kwargs) - # elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: - # return self._visit_type_constr(node, **kwargs) - # elif isinstance( - # node.func.type, - # (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - # ): - # # ITIR has no support for keyword arguments. Instead, we concatenate both positional - # # and keyword arguments and use the unique order as given in the function signature. - # lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - # node.func.type, - # self.visit(node.args, **kwargs), - # self.visit(node.kwargs, **kwargs), - # use_signature_ordering=True, - # ) - # result = im.call(self.visit(node.func, **kwargs))( - # *lowered_args, *lowered_kwargs.values() - # ) - - # # scan operators return an iterator of tuples, transform into tuples of iterator again - # if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # result = lowering_utils.to_tuples_of_iterator( - # result, node.func.type.definition.returns - # ) + def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: + return self._map(node.op.value, node.left, node.right) - # return result + def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + match node.args[0]: + case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): + shift_offset = im.shift(offset_name, offset_index) + case foast.Name(id=offset_name): + return im.as_fieldop_neighbors(str(offset_name), self.visit(node.func, **kwargs)) + # case foast.Call(func=foast.Name(id="as_offset")): + # func_args = node.args[0] + # offset_dim = func_args.args[0] + # assert isinstance(offset_dim, foast.Name) + # shift_offset = im.shift( + # offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs)) + # ) + case _: + raise FieldOperatorLoweringError("Unexpected shift arguments!") + return im.as_fieldop(im.lambda_("it")(im.deref(shift_offset("it"))))( + self.visit(node.func, **kwargs) + ) - # raise AssertionError( - # f"Call to object of type '{type(node.func.type).__name__}' not understood." - # ) + def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + if type_info.type_class(node.func.type) is ts.FieldType: + return self._visit_shift(node, **kwargs) + # elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: + # return self._visit_math_built_in(node, **kwargs) + elif isinstance(node.func, foast.Name) and node.func.id in ( + FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES + ): + visitor = getattr(self, f"_visit_{node.func.id}") + return visitor(node, **kwargs) + elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: + return self._visit_type_constr(node, **kwargs) + elif isinstance( + node.func.type, + (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), + ): + # ITIR has no support for keyword arguments. Instead, we concatenate both positional + # and keyword arguments and use the unique order as given in the function signature. + lowered_args, lowered_kwargs = type_info.canonicalize_arguments( + node.func.type, + self.visit(node.args, **kwargs), + self.visit(node.kwargs, **kwargs), + use_signature_ordering=True, + ) + result = im.call(self.visit(node.func, **kwargs))( + *lowered_args, *lowered_kwargs.values() + ) + + # scan operators return an iterator of tuples, transform into tuples of iterator again + if isinstance(node.func.type, ts_ffront.ScanOperatorType): + raise NotImplementedError("TODO") + # result = lowering_utils.to_tuples_of_iterator( + # result, node.func.type.definition.returns + # ) + + return result + + raise AssertionError( + f"Call to object of type '{type(node.func.type).__name__}' not understood." + ) # def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: # assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) @@ -367,18 +367,18 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: # def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: # return self._map(self.visit(node.func, **kwargs), *node.args) - # def _make_reduction_expr( - # self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any - # ) -> itir.Expr: - # # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) - # it = self.visit(node.args[0], **kwargs) - # assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - # val = im.call(im.call("reduce")(op, im.deref(init_expr))) - # return im.promote_to_lifted_stencil(val)(it) + def _make_reduction_expr( + self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any + ) -> itir.Expr: + # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) + it = self.visit(node.args[0], **kwargs) + assert isinstance(node.kwargs["axis"].type, ts.DimensionType) + val = im.call(im.call("reduce")(op, init_expr)) + return im.op_as_fieldop(val)(it) - # def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - # dtype = type_info.extract_dtype(node.type) - # return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) + def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + dtype = type_info.extract_dtype(node.type) + return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) # def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: # dtype = type_info.extract_dtype(node.type) @@ -392,41 +392,39 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: # init_expr = self._make_literal(str(max_value), dtype) # return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) - # def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - # if isinstance(node.args[0], foast.Constant): - # node_kind = self.visit(node.type).kind.name.lower() - # target_type = fbuiltins.BUILTINS[node_kind] - # source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] - # if target_type is bool and source_type is not bool: - # return im.promote_to_const_iterator( - # im.literal(str(bool(source_type(node.args[0].value))), "bool") - # ) - # return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) - # raise FieldOperatorLoweringError( - # f"Encountered a type cast, which is not supported: {node}." - # ) - - # def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # # the following constructs work if they are removed by inlining. - # if isinstance(type_, ts.TupleType): - # return im.make_tuple( - # *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - # ) - # elif isinstance(type_, ts.ScalarType): - # typename = type_.kind.name.lower() - # return im.promote_to_const_iterator(im.literal(str(val), typename)) - # raise ValueError(f"Unsupported literal type '{type_}'.") + def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + if isinstance(node.args[0], foast.Constant): + node_kind = self.visit(node.type).kind.name.lower() + target_type = fbuiltins.BUILTINS[node_kind] + source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] + if target_type is bool and source_type is not bool: + return im.literal(str(bool(source_type(node.args[0].value))), "bool") + return im.literal(str(node.args[0].value), node_kind) + raise FieldOperatorLoweringError( + f"Encountered a type cast, which is not supported: {node}." + ) - # def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: - # return self._make_literal(node.value, node.type) + def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: + # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; + # the following constructs work if they are removed by inlining. + if isinstance(type_, ts.TupleType): + raise NotImplementedError("TODO") + return im.make_tuple( + *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) + ) + elif isinstance(type_, ts.ScalarType): + typename = type_.kind.name.lower() + return im.literal(str(val), typename) + raise ValueError(f"Unsupported literal type '{type_}'.") + + def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: + return self._make_literal(node.value, node.type) def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: lowered_args = [self.visit(arg, **kwargs) for arg in args] if any(type_info.contains_local_field(arg.type) for arg in args): - raise NotImplementedError("TODO: Local fields not supported") - # lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - # op = im.call("map_")(op) + lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] + op = im.call("map_")(op) return im.op_as_fieldop(im.call(op))(*lowered_args) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index c2ded65fc1..cd84025535 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import typing -from typing import Callable, Iterable, Union, Optional +from typing import Callable, Iterable, Optional, Union from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -255,6 +255,7 @@ def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) + def as_fieldop(stencil, domain=None): """Creates a field_operator from a stencil.""" args = [stencil] @@ -369,6 +370,18 @@ def lifted_neighbors(offset, it) -> itir.Expr: return lift(lambda_("it")(neighbors(offset, "it")))(it) +def as_fieldop_neighbors(offset, it) -> itir.Expr: + """ + Create a fieldop for neighbors call. + + Examples + -------- + >>> str(as_fieldop_neighbors("off", "a")) + '(⇑(λ(it) → neighbors(offₒ, it)))(a)' + """ + return as_fieldop(lambda_("it")(neighbors(offset, "it")))(it) + + def promote_to_const_iterator(expr: str | itir.Expr) -> itir.Expr: """ Create a lifted nullary lambda that captures `expr`. @@ -406,6 +419,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl + def op_as_fieldop( op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None ) -> Callable[..., itir.FunCall]: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 4950a97b3e..1ccdd85fbe 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -27,12 +27,12 @@ from gt4py.next import float32, float64, int32, int64, neighbor_sum from gt4py.next.ffront import type_specifications as ts_ffront from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.ffront.foast_to_gtir import FieldOperatorLowering +from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation from gt4py.next.iterator.type_system import type_specifications as it_ts +from gt4py.next.type_system import type_specifications as ts, type_translation IDim = gtx.Dimension("IDim") @@ -54,8 +54,6 @@ def debug_itir(tree): debug(format_python_source(EmbeddedDSL.apply(tree))) - - def test_return(): def copy_field(inp: gtx.Field[[TDim], float64]): return inp @@ -67,96 +65,85 @@ def copy_field(inp: gtx.Field[[TDim], float64]): assert lowered.expr == im.ref("inp") -# def test_scalar_arg(): -# def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: -# return alpha * bar - -# parsed = FieldOperatorParser.apply_to_function(scalar_arg) -# lowered = FieldOperatorLowering.apply(parsed) - -# reference = im.op_as_fieldop("multiplies")("alpha", "bar") # no difference to non-scalar arg - -# assert lowered.expr == reference - - -# def test_multicopy(): -# def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): -# return inp1, inp2 +def test_scalar_arg(): + def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: + return alpha * bar -# parsed = FieldOperatorParser.apply_to_function(multicopy) -# lowered = FieldOperatorLowering.apply(parsed) + # TODO document that scalar arguments of `as_fieldop(stencil)` are promoted to 0-d fields + parsed = FieldOperatorParser.apply_to_function(scalar_arg) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.make_tuple("inp1", "inp2") + reference = im.op_as_fieldop("multiplies")("alpha", "bar") # no difference to non-scalar arg -# assert lowered.expr == reference + assert lowered.expr == reference -def test_arithmetic(): - def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1 + inp2 +def test_multicopy(): + def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): + return inp1, inp2 - parsed = FieldOperatorParser.apply_to_function(arithmetic) + parsed = FieldOperatorParser.apply_to_function(multicopy) lowered = FieldOperatorLowering.apply(parsed) - print(lowered) - reference = im.op_as_fieldop("plus")("inp1", "inp2") + reference = im.make_tuple("inp1", "inp2") assert lowered.expr == reference -# def test_shift(): -# Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) +def test_cartesian_shift(): + Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) -# def shift_by_one(inp: gtx.Field[[IDim], float64]): -# return inp(Ioff[1]) + def foo(inp: gtx.Field[[IDim], float64]): + return inp(Ioff[1]) -# parsed = FieldOperatorParser.apply_to_function(shift_by_one) -# lowered = FieldOperatorLowering.apply(parsed) + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") + reference = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") -# assert lowered.expr == reference + assert lowered.expr == reference -# def test_negative_shift(): -# Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) +def test_negative_cartesian_shift(): + Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) -# def shift_by_one(inp: gtx.Field[[IDim], float64]): -# return inp(Ioff[-1]) + def foo(inp: gtx.Field[[IDim], float64]): + return inp(Ioff[-1]) -# parsed = FieldOperatorParser.apply_to_function(shift_by_one) -# lowered = FieldOperatorLowering.apply(parsed) + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") + reference = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") -# assert lowered.expr == reference + assert lowered.expr == reference -# def test_temp_assignment(): -# def copy_field(inp: gtx.Field[[TDim], float64]): -# tmp = inp -# inp = tmp -# tmp2 = inp -# return tmp2 +def test_temp_assignment(): + def copy_field(inp: gtx.Field[[TDim], float64]): + tmp = inp + inp = tmp + tmp2 = inp + return tmp2 -# parsed = FieldOperatorParser.apply_to_function(copy_field) -# lowered = FieldOperatorLowering.apply(parsed) + parsed = FieldOperatorParser.apply_to_function(copy_field) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.let(ssa.unique_name("tmp", 0), "inp")( -# im.let( -# ssa.unique_name("inp", 0), -# ssa.unique_name("tmp", 0), -# )( -# im.let( -# ssa.unique_name("tmp2", 0), -# ssa.unique_name("inp", 0), -# )(ssa.unique_name("tmp2", 0)) -# ) -# ) + reference = im.let(ssa.unique_name("tmp", 0), "inp")( + im.let( + ssa.unique_name("inp", 0), + ssa.unique_name("tmp", 0), + )( + im.let( + ssa.unique_name("tmp2", 0), + ssa.unique_name("inp", 0), + )(ssa.unique_name("tmp2", 0)) + ) + ) -# assert lowered.expr == reference + assert lowered.expr == reference +# TODO (introduce neg/pos) # def test_unary_ops(): # def unary(inp: gtx.Field[[TDim], float64]): # tmp = +inp @@ -179,39 +166,40 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64 # ), # )(ssa.unique_name("tmp", 1)) # ) +# print(reference) # assert lowered.expr == reference -# def test_unpacking(): -# """Unpacking assigns should get separated.""" +def test_unpacking(): + """Unpacking assigns should get separated.""" -# def unpacking( -# inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] -# ) -> gtx.Field[[TDim], float64]: -# tmp1, tmp2 = inp1, inp2 # noqa -# return tmp1 + def unpacking( + inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] + ) -> gtx.Field[[TDim], float64]: + tmp1, tmp2 = inp1, inp2 # noqa + return tmp1 -# parsed = FieldOperatorParser.apply_to_function(unpacking) -# lowered = FieldOperatorLowering.apply(parsed) - -# tuple_expr = im.make_tuple("inp1", "inp2") -# tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") -# tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") + parsed = FieldOperatorParser.apply_to_function(unpacking) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.let("__tuple_tmp_0", tuple_expr)( -# im.let( -# ssa.unique_name("tmp1", 0), -# tuple_access_0, -# )( -# im.let( -# ssa.unique_name("tmp2", 0), -# tuple_access_1, -# )(ssa.unique_name("tmp1", 0)) -# ) -# ) + tuple_expr = im.make_tuple("inp1", "inp2") + tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") + tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") + + reference = im.let("__tuple_tmp_0", tuple_expr)( + im.let( + ssa.unique_name("tmp1", 0), + tuple_access_0, + )( + im.let( + ssa.unique_name("tmp2", 0), + tuple_access_1, + )(ssa.unique_name("tmp1", 0)) + ) + ) -# assert lowered.expr == reference + assert lowered.expr == reference # def test_annotated_assignment(): @@ -229,360 +217,372 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64 # assert lowered.expr == reference -# def test_call(): -# # create something that appears to the lowering like a field operator. -# # we could also create an actual field operator, but we want to avoid -# # using such heavy constructs for testing the lowering. -# field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) -# identity = SimpleNamespace( -# __gt_type__=lambda: ts_ffront.FieldOperatorType( -# definition=ts.FunctionType( -# pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type -# ) -# ) -# ) +def test_call(): + # create something that appears to the lowering like a field operator. + # we could also create an actual field operator, but we want to avoid + # using such heavy constructs for testing the lowering. + field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) + identity = SimpleNamespace( + __gt_type__=lambda: ts_ffront.FieldOperatorType( + definition=ts.FunctionType( + pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type + ) + ) + ) -# def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: -# return identity(inp) + def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: + return identity(inp) -# parsed = FieldOperatorParser.apply_to_function(call) -# lowered = FieldOperatorLowering.apply(parsed) + parsed = FieldOperatorParser.apply_to_function(call) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.call("identity")("inp") + reference = im.call("identity")("inp") -# assert lowered.expr == reference + assert lowered.expr == reference -# def test_temp_tuple(): -# """Returning a temp tuple should work.""" +def test_return_constructed_tuple(): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): + return a, b -# def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): -# tmp = a, b -# return tmp + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(temp_tuple) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.make_tuple("a", "b") -# tuple_expr = im.make_tuple("a", "b") -# reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) + assert lowered.expr == reference -# assert lowered.expr == reference +def test_fieldop_with_tuple_arg(): + def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): + return a[0] -# def test_unary_not(): -# def unary_not(cond: gtx.Field[[TDim], "bool"]): -# return not cond + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(unary_not) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.tuple_get(0, "a") -# reference = im.promote_to_lifted_stencil("not_")("cond") + assert lowered.expr == reference -# assert lowered.expr == reference +def test_unary_not(): + def unary_not(cond: gtx.Field[[TDim], "bool"]): + return not cond -# def test_binary_plus(): -# def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): -# return a + b + parsed = FieldOperatorParser.apply_to_function(unary_not) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(plus) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.op_as_fieldop("not_")("cond") -# reference = im.promote_to_lifted_stencil("plus")("a", "b") + assert lowered.expr == reference -# assert lowered.expr == reference +def test_binary_plus(): + def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return a + b -# def test_add_scalar_literal_to_field(): -# def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: -# return 2.0 + a + parsed = FieldOperatorParser.apply_to_function(plus) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.op_as_fieldop("plus")("a", "b") -# reference = im.promote_to_lifted_stencil("plus")( -# im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" -# ) + assert lowered.expr == reference -# assert lowered.expr == reference +def test_add_scalar_literal_to_field(): + def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + return 2.0 + a -# def test_add_scalar_literals(): -# def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: -# tmp = int32(1) + int32("1") -# return a + tmp + parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.op_as_fieldop("plus")(im.literal("2.0", "float64"), "a") -# reference = im.let( -# ssa.unique_name("tmp", 0), -# im.promote_to_lifted_stencil("plus")( -# im.promote_to_const_iterator(im.literal("1", "int32")), -# im.promote_to_const_iterator(im.literal("1", "int32")), -# ), -# )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) + assert lowered.expr == reference -# assert lowered.expr == reference +def test_add_scalar_literals(): + def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: + tmp = int32(1) + int32("1") + return a + tmp -# def test_binary_mult(): -# def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): -# return a * b + parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(mult) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.let( + ssa.unique_name("tmp", 0), + im.op_as_fieldop("plus")( + im.literal("1", "int32"), + im.literal("1", "int32"), + ), + )(im.op_as_fieldop("plus")("a", ssa.unique_name("tmp", 0))) -# reference = im.promote_to_lifted_stencil("multiplies")("a", "b") + assert lowered.expr == reference -# assert lowered.expr == reference +def test_binary_mult(): + def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return a * b -# def test_binary_minus(): -# def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): -# return a - b + parsed = FieldOperatorParser.apply_to_function(mult) + lowered = FieldOperatorLowering.apply(parsed) -# parsed = FieldOperatorParser.apply_to_function(minus) -# lowered = FieldOperatorLowering.apply(parsed) + reference = im.op_as_fieldop("multiplies")("a", "b") + + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("minus")("a", "b") -# assert lowered.expr == reference +def test_binary_minus(): + def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return a - b + parsed = FieldOperatorParser.apply_to_function(minus) + lowered = FieldOperatorLowering.apply(parsed) -# def test_binary_div(): -# def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): -# return a / b + reference = im.op_as_fieldop("minus")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(division) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("divides")("a", "b") -# assert lowered.expr == reference +def test_binary_div(): + def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return a / b + parsed = FieldOperatorParser.apply_to_function(division) + lowered = FieldOperatorLowering.apply(parsed) -# def test_binary_and(): -# def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): -# return a & b + reference = im.op_as_fieldop("divides")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(bit_and) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("and_")("a", "b") -# assert lowered.expr == reference +def test_binary_and(): + def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): + return a & b + parsed = FieldOperatorParser.apply_to_function(bit_and) + lowered = FieldOperatorLowering.apply(parsed) -# def test_scalar_and(): -# def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: -# return a & False + reference = im.op_as_fieldop("and_")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(scalar_and) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("and_")( -# "a", im.promote_to_const_iterator(im.literal("False", "bool")) -# ) -# assert lowered.expr == reference +def test_scalar_and(): + def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: + return a & False + parsed = FieldOperatorParser.apply_to_function(scalar_and) + lowered = FieldOperatorLowering.apply(parsed) -# def test_binary_or(): -# def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): -# return a | b + reference = im.op_as_fieldop("and_")("a", im.literal("False", "bool")) -# parsed = FieldOperatorParser.apply_to_function(bit_or) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("or_")("a", "b") -# assert lowered.expr == reference +def test_binary_or(): + def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): + return a | b + parsed = FieldOperatorParser.apply_to_function(bit_or) + lowered = FieldOperatorLowering.apply(parsed) -# def test_compare_scalars(): -# def comp_scalars() -> bool: -# return 3 > 4 + reference = im.op_as_fieldop("or_")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(comp_scalars) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("greater")( -# im.promote_to_const_iterator(im.literal("3", "int32")), -# im.promote_to_const_iterator(im.literal("4", "int32")), -# ) -# assert lowered.expr == reference +def test_compare_scalars(): + def comp_scalars() -> bool: + return 3 > 4 + parsed = FieldOperatorParser.apply_to_function(comp_scalars) + lowered = FieldOperatorLowering.apply(parsed) -# def test_compare_gt(): -# def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): -# return a > b + reference = im.op_as_fieldop("greater")( + im.literal("3", "int32"), + im.literal("4", "int32"), + ) -# parsed = FieldOperatorParser.apply_to_function(comp_gt) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("greater")("a", "b") -# assert lowered.expr == reference +def test_compare_gt(): + def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return a > b + parsed = FieldOperatorParser.apply_to_function(comp_gt) + lowered = FieldOperatorLowering.apply(parsed) -# def test_compare_lt(): -# def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): -# return a < b + reference = im.op_as_fieldop("greater")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(comp_lt) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("less")("a", "b") -# assert lowered.expr == reference +def test_compare_lt(): + def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return a < b + parsed = FieldOperatorParser.apply_to_function(comp_lt) + lowered = FieldOperatorLowering.apply(parsed) -# def test_compare_eq(): -# def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): -# return a == b + reference = im.op_as_fieldop("less")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(comp_eq) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("eq")("a", "b") -# assert lowered.expr == reference +def test_compare_eq(): + def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): + return a == b + parsed = FieldOperatorParser.apply_to_function(comp_eq) + lowered = FieldOperatorLowering.apply(parsed) -# def test_compare_chain(): -# def compare_chain( -# a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] -# ) -> gtx.Field[[IDim], bool]: -# return a > b > c + reference = im.op_as_fieldop("eq")("a", "b") -# parsed = FieldOperatorParser.apply_to_function(compare_chain) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil("and_")( -# im.promote_to_lifted_stencil("greater")("a", "b"), -# im.promote_to_lifted_stencil("greater")("b", "c"), -# ) -# assert lowered.expr == reference +def test_compare_chain(): + def compare_chain( + a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] + ) -> gtx.Field[[IDim], bool]: + return a > b > c + parsed = FieldOperatorParser.apply_to_function(compare_chain) + lowered = FieldOperatorLowering.apply(parsed) -# def test_reduction_lowering_simple(): -# def reduction(edge_f: gtx.Field[[Edge], float64]): -# return neighbor_sum(edge_f(V2E), axis=V2EDim) + reference = im.op_as_fieldop("and_")( + im.op_as_fieldop("greater")("a", "b"), + im.op_as_fieldop("greater")("b", "c"), + ) -# parsed = FieldOperatorParser.apply_to_function(reduction) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# reference = im.promote_to_lifted_stencil( -# im.call( -# im.call("reduce")( -# "plus", -# im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), -# ) -# ) -# )(im.lifted_neighbors("V2E", "edge_f")) -# assert lowered.expr == reference +def test_premap_to_local_field(): + def foo(edge_f: gtx.Field[gtx.Dims[Edge], float64]): + return edge_f(V2E) + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) -# def test_reduction_lowering_expr(): -# def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): -# e1_nbh = e1(V2E) -# return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) + reference = im.as_fieldop_neighbors("V2E", "edge_f") -# parsed = FieldOperatorParser.apply_to_function(reduction) -# lowered = FieldOperatorLowering.apply(parsed) + assert lowered.expr == reference -# mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( -# im.promote_to_lifted_stencil("make_const_list")( -# im.promote_to_const_iterator(im.literal("1.1", "float64")) -# ), -# im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), -# ) -# reference = im.let( -# ssa.unique_name("e1_nbh", 0), -# im.lifted_neighbors("V2E", "e1"), -# )( -# im.promote_to_lifted_stencil( -# im.call( -# im.call("reduce")( -# "plus", -# im.deref( -# im.promote_to_const_iterator(im.literal(value="0", typename="float64")) -# ), -# ) -# ) -# )(mapped) -# ) +def test_reduction_lowering_simple(): + def reduction(edge_f: gtx.Field[[Edge], float64]): + return neighbor_sum(edge_f(V2E), axis=V2EDim) -# assert lowered.expr == reference + parsed = FieldOperatorParser.apply_to_function(reduction) + lowered = FieldOperatorLowering.apply(parsed) + reference = im.op_as_fieldop( + im.call( + im.call("reduce")( + "plus", + im.literal(value="0", typename="float64"), + ) + ) + )(im.as_fieldop_neighbors("V2E", "edge_f")) -# def test_builtin_int_constructors(): -# def int_constrs() -> tuple[int32, int32, int64, int32, int64]: -# return 1, int32(1), int64(1), int32("1"), int64("1") + assert lowered.expr == reference -# parsed = FieldOperatorParser.apply_to_function(int_constrs) -# lowered = FieldOperatorLowering.apply(parsed) -# reference = im.make_tuple( -# im.promote_to_const_iterator(im.literal("1", "int32")), -# im.promote_to_const_iterator(im.literal("1", "int32")), -# im.promote_to_const_iterator(im.literal("1", "int64")), -# im.promote_to_const_iterator(im.literal("1", "int32")), -# im.promote_to_const_iterator(im.literal("1", "int64")), -# ) +def test_reduction_lowering_expr(): + def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): + e1_nbh = e1(V2E) + return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) -# assert lowered.expr == reference + parsed = FieldOperatorParser.apply_to_function(reduction) + lowered = FieldOperatorLowering.apply(parsed) + mapped = im.op_as_fieldop(im.map_("multiplies"))( + im.op_as_fieldop("make_const_list")(im.literal("1.1", "float64")), + im.op_as_fieldop(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), + ) + + reference = im.let( + ssa.unique_name("e1_nbh", 0), + im.as_fieldop_neighbors("V2E", "e1"), + )( + im.op_as_fieldop( + im.call( + im.call("reduce")( + "plus", + im.literal(value="0", typename="float64"), + ) + ) + )(mapped) + ) -# def test_builtin_float_constructors(): -# def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: -# return ( -# 0.1, -# float(0.1), -# float32(0.1), -# float64(0.1), -# float(".1"), -# float32(".1"), -# float64(".1"), -# ) + assert lowered.expr == reference -# parsed = FieldOperatorParser.apply_to_function(float_constrs) -# lowered = FieldOperatorLowering.apply(parsed) -# reference = im.make_tuple( -# im.promote_to_const_iterator(im.literal("0.1", "float64")), -# im.promote_to_const_iterator(im.literal("0.1", "float64")), -# im.promote_to_const_iterator(im.literal("0.1", "float32")), -# im.promote_to_const_iterator(im.literal("0.1", "float64")), -# im.promote_to_const_iterator(im.literal(".1", "float64")), -# im.promote_to_const_iterator(im.literal(".1", "float32")), -# im.promote_to_const_iterator(im.literal(".1", "float64")), -# ) +def test_builtin_int_constructors(): + def int_constrs() -> tuple[int32, int32, int64, int32, int64]: + return 1, int32(1), int64(1), int32("1"), int64("1") -# assert lowered.expr == reference + parsed = FieldOperatorParser.apply_to_function(int_constrs) + lowered = FieldOperatorLowering.apply(parsed) + reference = im.make_tuple( + im.literal("1", "int32"), + im.literal("1", "int32"), + im.literal("1", "int64"), + im.literal("1", "int32"), + im.literal("1", "int64"), + ) -# def test_builtin_bool_constructors(): -# def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: -# return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") + assert lowered.expr == reference -# parsed = FieldOperatorParser.apply_to_function(bool_constrs) -# lowered = FieldOperatorLowering.apply(parsed) -# reference = im.make_tuple( -# im.promote_to_const_iterator(im.literal(str(True), "bool")), -# im.promote_to_const_iterator(im.literal(str(False), "bool")), -# im.promote_to_const_iterator(im.literal(str(True), "bool")), -# im.promote_to_const_iterator(im.literal(str(False), "bool")), -# im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), -# im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), -# im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), -# im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), -# ) +def test_builtin_float_constructors(): + def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: + return ( + 0.1, + float(0.1), + float32(0.1), + float64(0.1), + float(".1"), + float32(".1"), + float64(".1"), + ) -# assert lowered.expr == reference + parsed = FieldOperatorParser.apply_to_function(float_constrs) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.make_tuple( + im.literal("0.1", "float64"), + im.literal("0.1", "float64"), + im.literal("0.1", "float32"), + im.literal("0.1", "float64"), + im.literal(".1", "float64"), + im.literal(".1", "float32"), + im.literal(".1", "float64"), + ) + + assert lowered.expr == reference + + +def test_builtin_bool_constructors(): + def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: + return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") + + parsed = FieldOperatorParser.apply_to_function(bool_constrs) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.make_tuple( + im.literal(str(True), "bool"), + im.literal(str(False), "bool"), + im.literal(str(True), "bool"), + im.literal(str(False), "bool"), + im.literal(str(bool(0)), "bool"), + im.literal(str(bool(5)), "bool"), + im.literal(str(bool("True")), "bool"), + im.literal(str(bool("False")), "bool"), + ) + + assert lowered.expr == reference From 01a6e07dda97f50a62bd95940358cdd12a1e9372 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 5 Jul 2024 08:59:46 +0200 Subject: [PATCH 04/29] test cleanup --- .../ffront_tests/test_foast_to_gtir.py | 147 ++++++++---------- 1 file changed, 61 insertions(+), 86 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 1ccdd85fbe..ba6017bae0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -44,33 +44,23 @@ TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. -def debug_itir(tree): - """Compare tree snippets while debugging.""" - from devtools import debug - - from gt4py.eve.codegen import format_python_source - from gt4py.next.program_processors.runners.roundtrip import EmbeddedDSL - - debug(format_python_source(EmbeddedDSL.apply(tree))) - - def test_return(): - def copy_field(inp: gtx.Field[[TDim], float64]): + def foo(inp: gtx.Field[[TDim], float64]): return inp - parsed = FieldOperatorParser.apply_to_function(copy_field) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - assert lowered.id == "copy_field" + assert lowered.id == "foo" assert lowered.expr == im.ref("inp") def test_scalar_arg(): - def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: + def foo(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: return alpha * bar # TODO document that scalar arguments of `as_fieldop(stencil)` are promoted to 0-d fields - parsed = FieldOperatorParser.apply_to_function(scalar_arg) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("multiplies")("alpha", "bar") # no difference to non-scalar arg @@ -79,10 +69,10 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], def test_multicopy(): - def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): + def foo(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): return inp1, inp2 - parsed = FieldOperatorParser.apply_to_function(multicopy) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.make_tuple("inp1", "inp2") @@ -90,7 +80,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64] assert lowered.expr == reference -def test_cartesian_shift(): +def test_premap(): Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) def foo(inp: gtx.Field[[IDim], float64]): @@ -104,28 +94,14 @@ def foo(inp: gtx.Field[[IDim], float64]): assert lowered.expr == reference -def test_negative_cartesian_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def foo(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[-1]) - - parsed = FieldOperatorParser.apply_to_function(foo) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") - - assert lowered.expr == reference - - def test_temp_assignment(): - def copy_field(inp: gtx.Field[[TDim], float64]): + def foo(inp: gtx.Field[[TDim], float64]): tmp = inp inp = tmp tmp2 = inp return tmp2 - parsed = FieldOperatorParser.apply_to_function(copy_field) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.let(ssa.unique_name("tmp", 0), "inp")( @@ -166,7 +142,6 @@ def copy_field(inp: gtx.Field[[TDim], float64]): # ), # )(ssa.unique_name("tmp", 1)) # ) -# print(reference) # assert lowered.expr == reference @@ -174,13 +149,13 @@ def copy_field(inp: gtx.Field[[TDim], float64]): def test_unpacking(): """Unpacking assigns should get separated.""" - def unpacking( + def foo( inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] ) -> gtx.Field[[TDim], float64]: tmp1, tmp2 = inp1, inp2 # noqa return tmp1 - parsed = FieldOperatorParser.apply_to_function(unpacking) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) tuple_expr = im.make_tuple("inp1", "inp2") @@ -202,19 +177,19 @@ def unpacking( assert lowered.expr == reference -# def test_annotated_assignment(): -# pytest.xfail("Annotated assignments are not properly supported at the moment.") +def test_annotated_assignment(): + pytest.xfail("Annotated assignments are not properly supported at the moment.") -# def copy_field(inp: gtx.Field[[TDim], float64]): -# tmp: gtx.Field[[TDim], float64] = inp -# return tmp + def foo(inp: gtx.Field[[TDim], float64]): + tmp: gtx.Field[[TDim], float64] = inp + return tmp -# parsed = FieldOperatorParser.apply_to_function(copy_field) -# lowered = FieldOperatorLowering.apply(parsed) + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) -# reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) + reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) -# assert lowered.expr == reference + assert lowered.expr == reference def test_call(): @@ -230,10 +205,10 @@ def test_call(): ) ) - def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: + def foo(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: return identity(inp) - parsed = FieldOperatorParser.apply_to_function(call) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.call("identity")("inp") @@ -266,10 +241,10 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): def test_unary_not(): - def unary_not(cond: gtx.Field[[TDim], "bool"]): + def foo(cond: gtx.Field[[TDim], "bool"]): return not cond - parsed = FieldOperatorParser.apply_to_function(unary_not) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("not_")("cond") @@ -278,10 +253,10 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]): def test_binary_plus(): - def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a + b - parsed = FieldOperatorParser.apply_to_function(plus) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("plus")("a", "b") @@ -290,10 +265,10 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_add_scalar_literal_to_field(): - def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + def foo(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: return 2.0 + a - parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("plus")(im.literal("2.0", "float64"), "a") @@ -302,11 +277,11 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6 def test_add_scalar_literals(): - def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: + def foo(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: tmp = int32(1) + int32("1") return a + tmp - parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.let( @@ -321,10 +296,10 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3 def test_binary_mult(): - def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a * b - parsed = FieldOperatorParser.apply_to_function(mult) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("multiplies")("a", "b") @@ -333,10 +308,10 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_binary_minus(): - def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a - b - parsed = FieldOperatorParser.apply_to_function(minus) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("minus")("a", "b") @@ -345,10 +320,10 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_binary_div(): - def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a / b - parsed = FieldOperatorParser.apply_to_function(division) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("divides")("a", "b") @@ -357,10 +332,10 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_binary_and(): - def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): + def foo(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a & b - parsed = FieldOperatorParser.apply_to_function(bit_and) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("and_")("a", "b") @@ -369,10 +344,10 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): def test_scalar_and(): - def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: + def foo(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: return a & False - parsed = FieldOperatorParser.apply_to_function(scalar_and) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("and_")("a", im.literal("False", "bool")) @@ -381,10 +356,10 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: def test_binary_or(): - def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): + def foo(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a | b - parsed = FieldOperatorParser.apply_to_function(bit_or) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("or_")("a", "b") @@ -393,10 +368,10 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): def test_compare_scalars(): - def comp_scalars() -> bool: + def foo() -> bool: return 3 > 4 - parsed = FieldOperatorParser.apply_to_function(comp_scalars) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("greater")( @@ -408,10 +383,10 @@ def comp_scalars() -> bool: def test_compare_gt(): - def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a > b - parsed = FieldOperatorParser.apply_to_function(comp_gt) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("greater")("a", "b") @@ -432,10 +407,10 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_compare_eq(): - def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): + def foo(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): return a == b - parsed = FieldOperatorParser.apply_to_function(comp_eq) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("eq")("a", "b") @@ -444,12 +419,12 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): def test_compare_chain(): - def compare_chain( + def foo( a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] ) -> gtx.Field[[IDim], bool]: return a > b > c - parsed = FieldOperatorParser.apply_to_function(compare_chain) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("and_")( @@ -473,10 +448,10 @@ def foo(edge_f: gtx.Field[gtx.Dims[Edge], float64]): def test_reduction_lowering_simple(): - def reduction(edge_f: gtx.Field[[Edge], float64]): + def foo(edge_f: gtx.Field[[Edge], float64]): return neighbor_sum(edge_f(V2E), axis=V2EDim) - parsed = FieldOperatorParser.apply_to_function(reduction) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( @@ -492,11 +467,11 @@ def reduction(edge_f: gtx.Field[[Edge], float64]): def test_reduction_lowering_expr(): - def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): + def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): e1_nbh = e1(V2E) return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) - parsed = FieldOperatorParser.apply_to_function(reduction) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) mapped = im.op_as_fieldop(im.map_("multiplies"))( @@ -522,10 +497,10 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl def test_builtin_int_constructors(): - def int_constrs() -> tuple[int32, int32, int64, int32, int64]: + def foo() -> tuple[int32, int32, int64, int32, int64]: return 1, int32(1), int64(1), int32("1"), int64("1") - parsed = FieldOperatorParser.apply_to_function(int_constrs) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.make_tuple( @@ -540,7 +515,7 @@ def int_constrs() -> tuple[int32, int32, int64, int32, int64]: def test_builtin_float_constructors(): - def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: + def foo() -> tuple[float, float, float32, float64, float, float32, float64]: return ( 0.1, float(0.1), @@ -551,7 +526,7 @@ def float_constrs() -> tuple[float, float, float32, float64, float, float32, flo float64(".1"), ) - parsed = FieldOperatorParser.apply_to_function(float_constrs) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.make_tuple( @@ -568,10 +543,10 @@ def float_constrs() -> tuple[float, float, float32, float64, float, float32, flo def test_builtin_bool_constructors(): - def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: + def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") - parsed = FieldOperatorParser.apply_to_function(bool_constrs) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.make_tuple( From b949ebdfba369c70d02dbfcf19dcc9e1842b395a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 5 Jul 2024 13:13:06 +0200 Subject: [PATCH 05/29] add to_gtir path to past_to_itir --- src/gt4py/next/ffront/past_to_itir.py | 100 ++++++++++++++++---------- 1 file changed, 64 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index b2c038583f..d459b8d518 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -114,6 +114,7 @@ def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]: raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") +@dataclasses.dataclass class ProgramLowering( traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator ): @@ -150,6 +151,9 @@ class ProgramLowering( [Sym(id=SymbolName('inp')), Sym(id=SymbolName('out')), Sym(id=SymbolName('__inp_size_0')), Sym(id=SymbolName('__out_size_0'))] """ + grid_type: common.GridType + to_gtir: bool = False # TODO(havogt): remove after refactoring to GTIR + # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. @@ -159,11 +163,11 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, + to_gtir: bool = False, ) -> itir.FencilDefinition: - return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) - - def __init__(self, grid_type: common.GridType): - self.grid_type = grid_type + return cls(grid_type=grid_type, to_gtir=to_gtir).visit( + node, function_definitions=function_definitions + ) def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: """Generate symbols for each field param and dimension.""" @@ -192,7 +196,7 @@ def visit_Program( *, function_definitions: list[itir.FunctionDefinition], **kwargs: Any, - ) -> itir.FencilDefinition: + ) -> itir.FencilDefinition | itir.Program: # The ITIR does not support dynamically getting the size of a field. As # a workaround we add additional arguments to the fencil definition # containing the size of all fields. The caller of a program is (e.g. @@ -203,13 +207,28 @@ def visit_Program( if any("domain" not in body_entry.kwargs for body_entry in node.body): params = params + self._gen_size_params_from_program(node) - closures: list[itir.StencilClosure] = [] + closures_or_set_ats: list[ + itir.StencilClosure | itir.SetAt + ] = [] # TODO(havogt): fix naming after refactoring to GTIR for stmt in node.body: - closures.append(self._visit_stencil_call(stmt, **kwargs)) - - return itir.FencilDefinition( - id=node.id, function_definitions=function_definitions, params=params, closures=closures - ) + closures_or_set_ats.append(self._visit_stencil_call(stmt, **kwargs)) + if self.to_gtir: + assert all(isinstance(s, itir.SetAt) for s in closures_or_set_ats) + return itir.Program( + id=node.id, + function_definitions=function_definitions, + params=params, + declarations=[], + body=closures_or_set_ats, + ) + else: + assert all(isinstance(s, itir.StencilClosure) for s in closures_or_set_ats) + return itir.FencilDefinition( + id=node.id, + function_definitions=function_definitions, + params=params, + closures=closures_or_set_ats, + ) def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) @@ -229,35 +248,44 @@ def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClo lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) + if self.to_gtir: + return itir.SetAt( + expr=im.call(node.func.id)(*lowered_args, *lowered_kwargs.values()), + domain=lowered_domain, + target=output, + ) + else: + stencil_params = [] + stencil_args: list[itir.Expr] = [] + for i, arg in enumerate([*args, *node_kwargs]): + stencil_params.append(f"__stencil_arg{i}") + if isinstance(arg.type, ts.TupleType): + # convert into tuple of iterators + stencil_args.append( + lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) + ) + else: + stencil_args.append(im.ref(f"__stencil_arg{i}")) + + if isinstance(node.func.type, ts_ffront.ScanOperatorType): + # scan operators return an iterator of tuples, just deref directly + stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) + # field operators return a tuple of iterators, deref element-wise + stencil_body = lowering_utils.process_elements( + im.deref, + im.call(node.func.id)(*stencil_args), + node.func.type.definition.returns, + ) - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, im.call(node.func.id)(*stencil_args), node.func.type.definition.returns + return itir.StencilClosure( + domain=lowered_domain, + stencil=im.lambda_(*stencil_params)(stencil_body), + inputs=[*lowered_args, *lowered_kwargs.values()], + output=output, + location=node.location, ) - return itir.StencilClosure( - domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, - ) - def _visit_slice_bound( self, slice_bound: Optional[past.Constant], From be21b78fcc02ce1fb03e2edb02fc893a2a77d107 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 10 Jul 2024 10:17:57 +0200 Subject: [PATCH 06/29] update past_to_gtir_test --- src/gt4py/next/ffront/past_to_itir.py | 98 ++++---- .../ffront_tests/test_past_to_gtir.py | 217 ++++++++++++++++++ 2 files changed, 272 insertions(+), 43 deletions(-) create mode 100644 tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index d459b8d518..7162e05256 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -207,30 +207,25 @@ def visit_Program( if any("domain" not in body_entry.kwargs for body_entry in node.body): params = params + self._gen_size_params_from_program(node) - closures_or_set_ats: list[ - itir.StencilClosure | itir.SetAt - ] = [] # TODO(havogt): fix naming after refactoring to GTIR - for stmt in node.body: - closures_or_set_ats.append(self._visit_stencil_call(stmt, **kwargs)) if self.to_gtir: - assert all(isinstance(s, itir.SetAt) for s in closures_or_set_ats) + set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] return itir.Program( id=node.id, function_definitions=function_definitions, params=params, declarations=[], - body=closures_or_set_ats, + body=set_ats, ) else: - assert all(isinstance(s, itir.StencilClosure) for s in closures_or_set_ats) + closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] return itir.FencilDefinition( id=node.id, function_definitions=function_definitions, params=params, - closures=closures_or_set_ats, + closures=closures, ) - def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: + def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) @@ -248,44 +243,61 @@ def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClo lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - if self.to_gtir: - return itir.SetAt( - expr=im.call(node.func.id)(*lowered_args, *lowered_kwargs.values()), - domain=lowered_domain, - target=output, - ) - else: - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) - else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) + return itir.SetAt( + expr=im.call(node.func.id)(*lowered_args, *lowered_kwargs.values()), + domain=lowered_domain, + target=output, + ) - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, + def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: + assert isinstance(node.kwargs["out"].type, ts.TypeSpec) + assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) + + node_kwargs = {**node.kwargs} + domain = node_kwargs.pop("domain", None) + output, lowered_domain = self._visit_stencil_call_out_arg( + node_kwargs.pop("out"), domain, **kwargs + ) + + assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) + + args, node_kwargs = type_info.canonicalize_arguments( + node.func.type, node.args, node_kwargs, use_signature_ordering=True + ) + + lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) + + stencil_params = [] + stencil_args: list[itir.Expr] = [] + for i, arg in enumerate([*args, *node_kwargs]): + stencil_params.append(f"__stencil_arg{i}") + if isinstance(arg.type, ts.TupleType): + # convert into tuple of iterators + stencil_args.append( + lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) ) + else: + stencil_args.append(im.ref(f"__stencil_arg{i}")) - return itir.StencilClosure( - domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, + if isinstance(node.func.type, ts_ffront.ScanOperatorType): + # scan operators return an iterator of tuples, just deref directly + stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) + else: + # field operators return a tuple of iterators, deref element-wise + stencil_body = lowering_utils.process_elements( + im.deref, + im.call(node.func.id)(*stencil_args), + node.func.type.definition.returns, ) + return itir.StencilClosure( + domain=lowered_domain, + stencil=im.lambda_(*stencil_params)(stencil_body), + inputs=[*lowered_args, *lowered_kwargs.values()], + output=output, + location=node.location, + ) + def _visit_slice_bound( self, slice_bound: Optional[past.Constant], diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py new file mode 100644 index 0000000000..4dacba9546 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -0,0 +1,217 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import re + +import pytest + +import gt4py.eve as eve +import gt4py.next as gtx +from gt4py.eve.pattern_matching import ObjectPattern as P +from gt4py.next import errors +from gt4py.next.ffront.func_to_past import ProgramParser +from gt4py.next.ffront.past_to_itir import ProgramLowering +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts + +from next_tests.past_common_fixtures import ( + IDim, + copy_program_def, + copy_restrict_program_def, + float64, + identity_def, + invalid_call_sig_program_def, +) + + +@pytest.fixture +def gtir_identity_fundef(): + return itir.FunctionDefinition( + id="identity", + params=[itir.Sym(id="x")], + expr=im.as_fieldop("deref")("x"), + ) + + +def test_copy_lowering(copy_program_def, gtir_identity_fundef): + past_node = ProgramParser.apply_to_function(copy_program_def) + itir_node = ProgramLowering.apply( + past_node, + function_definitions=[gtir_identity_fundef], + grid_type=gtx.GridType.CARTESIAN, + to_gtir=True, + ) + set_at_pattern = P( + itir.SetAt, + domain=P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), + args=[ + P(itir.AxisLiteral, value="IDim"), + P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), + P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), + ], + ) + ], + ), + expr=P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("identity")), + args=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], + ), + target=P(itir.SymRef, id=eve.SymbolRef("out")), + ) + program_pattern = P( + itir.Program, + id=eve.SymbolName("copy_program"), + params=[ + P(itir.Sym, id=eve.SymbolName("in_field")), + P(itir.Sym, id=eve.SymbolName("out")), + P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), + P(itir.Sym, id=eve.SymbolName("__out_size_0")), + ], + body=[set_at_pattern], + ) + + program_pattern.match(itir_node, raise_exception=True) + + +def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef): + past_node = ProgramParser.apply_to_function(copy_restrict_program_def) + itir_node = ProgramLowering.apply( + past_node, + function_definitions=[gtir_identity_fundef], + grid_type=gtx.GridType.CARTESIAN, + to_gtir=True, + ) + set_at_pattern = P( + itir.SetAt, + domain=P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), + args=[ + P(itir.AxisLiteral, value="IDim"), + P( + itir.Literal, + value="1", + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), + ), + P( + itir.Literal, + value="2", + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), + ), + ], + ) + ], + ), + ) + fencil_pattern = P( + itir.Program, + id=eve.SymbolName("copy_restrict_program"), + params=[ + P(itir.Sym, id=eve.SymbolName("in_field")), + P(itir.Sym, id=eve.SymbolName("out")), + P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), + P(itir.Sym, id=eve.SymbolName("__out_size_0")), + ], + body=[set_at_pattern], + ) + + fencil_pattern.match(itir_node, raise_exception=True) + + +def test_tuple_constructed_in_out_with_slicing(make_tuple_op): + def tuple_program( + inp: gtx.Field[[IDim], float64], + out1: gtx.Field[[IDim], float64], + out2: gtx.Field[[IDim], float64], + ): + make_tuple_op(inp, out=(out1[1:], out2[1:])) + + parsed = ProgramParser.apply_to_function(tuple_program) + ProgramLowering.apply( + parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True + ) + + +@pytest.mark.xfail( + reason="slicing is only allowed if all fields are sliced in the same way." +) # see ADR 10 +def test_tuple_constructed_in_out_with_slicing(make_tuple_op): + def tuple_program( + inp: gtx.Field[[IDim], float64], + out1: gtx.Field[[IDim], float64], + out2: gtx.Field[[IDim], float64], + ): + make_tuple_op(inp, out=(out1[1:], out2)) + + parsed = ProgramParser.apply_to_function(tuple_program) + ProgramLowering.apply( + parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True + ) + + +@pytest.mark.xfail +def test_inout_prohibited(identity_def): + identity = gtx.field_operator(identity_def) + + def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): + identity(inout_field, out=inout_field) + + with pytest.raises( + ValueError, match=(r"Call to function with field as input and output not allowed.") + ): + ProgramLowering.apply( + ProgramParser.apply_to_function(inout_field_program), + function_definitions=[], + grid_type=gtx.GridType.CARTESIAN, + ) + + +def test_invalid_call_sig_program(invalid_call_sig_program_def): + with pytest.raises(errors.DSLError) as exc_info: + ProgramLowering.apply( + ProgramParser.apply_to_function(invalid_call_sig_program_def), + function_definitions=[], + grid_type=gtx.GridType.CARTESIAN, + to_gtir=True, + ) + + assert exc_info.match("Invalid call to 'identity'") + # TODO(tehrengruber): re-enable again when call signature check doesn't return + # immediately after missing `out` argument + # assert ( + # re.search( + # "Function takes 1 arguments, but 2 were given.", exc_info.value.__cause__.args[0] + # ) + # is not None + # ) + assert ( + re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) + is not None + ) From be0011285ea735ba6aa6ccf964934166afa9bfc1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 10 Jul 2024 17:36:42 +0200 Subject: [PATCH 07/29] gtir_embedded backend --- src/gt4py/next/ffront/foast_to_gtir.py | 21 ++++++++++--------- src/gt4py/next/ffront/past_to_itir.py | 7 ++++++- src/gt4py/next/otf/stages.py | 2 +- .../codegens/gtfn/gtfn_module.py | 3 ++- .../runners/dace_iterator/workflow.py | 3 ++- .../program_processors/runners/roundtrip.py | 15 +++++++++++++ tests/next_tests/definitions.py | 1 + .../ffront_tests/ffront_test_utils.py | 1 + 8 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2177ae4369..3cfc308dd6 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -22,6 +22,7 @@ dialect_ast_enums, fbuiltins, field_operator_ast as foast, + stages as ffront_stages, type_specifications as ts_ffront, ) from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES @@ -31,8 +32,8 @@ from gt4py.next.type_system import type_info, type_specifications as ts -# def foast_to_itir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: -# return FieldOperatorLowering.apply(inp.foast_node) +def foast_to_gtir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: + return FieldOperatorLowering.apply(inp.foast_node) def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: @@ -82,16 +83,16 @@ def visit_FunctionDefinition( id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) ) # `expr` is a lifted stencil - # def visit_FieldOperator( - # self, node: foast.FieldOperator, **kwargs: Any - # ) -> itir.FunctionDefinition: - # func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + def visit_FieldOperator( + self, node: foast.FieldOperator, **kwargs: Any + ) -> itir.FunctionDefinition: + func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - # new_body = func_definition.expr + new_body = func_definition.expr - # return itir.FunctionDefinition( - # id=func_definition.id, params=func_definition.params, expr=new_body - # ) + return itir.FunctionDefinition( + id=func_definition.id, params=func_definition.params, expr=new_body + ) # def visit_ScanOperator( # self, node: foast.ScanOperator, **kwargs: Any diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 7162e05256..63c7287bb2 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -39,6 +39,8 @@ @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): + to_gtir: bool = False + def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( @@ -54,7 +56,10 @@ def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] itir_program = ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + inp.past_node, + function_definitions=lowered_funcs, + grid_type=grid_type, + to_gtir=self.to_gtir, ) if config.DEBUG or "debug" in inp.kwargs: diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 5e2ac8aaca..cb0a614a52 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -34,7 +34,7 @@ class ProgramCall: """Iterator IR representaion of a program together with arguments to be passed to it.""" - program: itir.FencilDefinition + program: itir.FencilDefinition | itir.Program args: tuple[Any, ...] kwargs: dict[str, Any] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 384d74a6c2..85c3e0e75f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -234,7 +234,8 @@ def __call__( self, inp: stages.ProgramCall ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program: itir.FencilDefinition = inp.program + program = inp.program + assert isinstance(program, itir.FencilDefinition) # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 1a7a36b8c5..2d7f2f4dcc 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -83,7 +83,8 @@ def __call__( self, inp: stages.ProgramCall ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the ITIR definition.""" - program: itir.FencilDefinition = inp.program + program = inp.program + assert isinstance(program, itir.FencilDefinition) arg_types = [tt.from_value(arg) for arg in inp.args] sdfg = self.generate_sdfg( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 50ff92e4c1..01046d7b27 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -27,6 +27,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config +from gt4py.next.ffront import foast_to_gtir, past_to_itir, stages as ffront_stages from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms from gt4py.next.iterator.transforms import fencil_to_program from gt4py.next.otf import stages, workflow @@ -276,3 +277,17 @@ class Params: with_temporaries = next_backend.Backend( executor=executor_with_temporaries, allocator=next_allocators.StandardCPUFieldBufferAllocator() ) + +gtir = next_backend.Backend( + executor=executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + transforms_fop=next_backend.FieldopTransformWorkflow( + past_to_itir=past_to_itir.PastToItirFactory(to_gtir=True), + foast_to_itir=workflow.CachedStep( + step=foast_to_gtir.foast_to_gtir, hash_function=ffront_stages.fingerprint_stage + ), + ), + transforms_prog=next_backend.ProgramTransformWorkflow( + past_to_itir=past_to_itir.PastToItirFactory(to_gtir=True) + ), +) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 9bbeb02298..c75a006d93 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -55,6 +55,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): ) GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + GTIR_EMBEDDED = "gt4py.next.program_processors.runners.roundtrip.gtir" ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 840f6d6143..b2ad11783a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -68,6 +68,7 @@ def __call__(self, program, *args, **kwargs) -> None: @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, + next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, From 5d0d16ec2a53b115bdb99d14a5699c6274b4b797 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 5 Aug 2024 09:49:27 +0200 Subject: [PATCH 08/29] add broadcast, max_over, min_over --- src/gt4py/next/ffront/foast_to_gtir.py | 33 ++++----- .../ffront_tests/test_foast_to_gtir.py | 68 ++++++++++++++++++- 2 files changed, 78 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3cfc308dd6..e76b0186ec 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -362,8 +362,8 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: # _visit_concat_where = _visit_where - # def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - # return self.visit(node.args[0], **kwargs) + def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + return self.visit(node.args[0], **kwargs) # def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: # return self._map(self.visit(node.func, **kwargs), *node.args) @@ -381,17 +381,17 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) - # def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - # dtype = type_info.extract_dtype(node.type) - # min_value, _ = type_info.arithmetic_bounds(dtype) - # init_expr = self._make_literal(str(min_value), dtype) - # return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) + def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + dtype = type_info.extract_dtype(node.type) + min_value, _ = type_info.arithmetic_bounds(dtype) + init_expr = self._make_literal(str(min_value), dtype) + return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) - # def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - # dtype = type_info.extract_dtype(node.type) - # _, max_value = type_info.arithmetic_bounds(dtype) - # init_expr = self._make_literal(str(max_value), dtype) - # return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) + def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: + dtype = type_info.extract_dtype(node.type) + _, max_value = type_info.arithmetic_bounds(dtype) + init_expr = self._make_literal(str(max_value), dtype) + return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: if isinstance(node.args[0], foast.Constant): @@ -406,14 +406,7 @@ def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: ) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # the following constructs work if they are removed by inlining. - if isinstance(type_, ts.TupleType): - raise NotImplementedError("TODO") - return im.make_tuple( - *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - ) - elif isinstance(type_, ts.ScalarType): + if isinstance(type_, ts.ScalarType): typename = type_.kind.name.lower() return im.literal(str(val), typename) raise ValueError(f"Unsupported literal type '{type_}'.") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index ba6017bae0..19b7743d52 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -24,7 +24,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum +from gt4py.next import float32, float64, int32, int64, neighbor_sum, broadcast, max_over, min_over from gt4py.next.ffront import type_specifications as ts_ffront from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.foast_to_gtir import FieldOperatorLowering @@ -32,7 +32,8 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.type_system import type_specifications as ts, type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation, type_info +import numpy as np IDim = gtx.Dimension("IDim") @@ -42,6 +43,7 @@ V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. +UDim = gtx.Dimension("UDim") # Meaningless dimension, used for tests. def test_return(): @@ -55,6 +57,17 @@ def foo(inp: gtx.Field[[TDim], float64]): assert lowered.expr == im.ref("inp") +def test_return_literal_tuple(): + def foo(): + return (1.0, True) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + assert lowered.id == "foo" + assert lowered.expr == im.make_tuple(im.literal_from_value(1.0), im.literal_from_value(True)) + + def test_scalar_arg(): def foo(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: return alpha * bar @@ -447,7 +460,7 @@ def foo(edge_f: gtx.Field[gtx.Dims[Edge], float64]): assert lowered.expr == reference -def test_reduction_lowering_simple(): +def test_reduction_lowering_neighbor_sum(): def foo(edge_f: gtx.Field[[Edge], float64]): return neighbor_sum(edge_f(V2E), axis=V2EDim) @@ -466,6 +479,44 @@ def foo(edge_f: gtx.Field[[Edge], float64]): assert lowered.expr == reference +def test_reduction_lowering_max_over(): + def foo(edge_f: gtx.Field[[Edge], float64]): + return max_over(edge_f(V2E), axis=V2EDim) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop( + im.call( + im.call("reduce")( + "maximum", + im.literal(value=str(np.finfo(np.float64).min), typename="float64"), + ) + ) + )(im.as_fieldop_neighbors("V2E", "edge_f")) + + assert lowered.expr == reference + + +def test_reduction_lowering_min_over(): + def foo(edge_f: gtx.Field[[Edge], float64]): + return min_over(edge_f(V2E), axis=V2EDim) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop( + im.call( + im.call("reduce")( + "minimum", + im.literal(value=str(np.finfo(np.float64).max), typename="float64"), + ) + ) + )(im.as_fieldop_neighbors("V2E", "edge_f")) + + assert lowered.expr == reference + + def test_reduction_lowering_expr(): def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): e1_nbh = e1(V2E) @@ -561,3 +612,14 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: ) assert lowered.expr == reference + + +def test_broadcast(): + def foo(inp: gtx.Field[[TDim], float64]): + return broadcast(inp, (UDim, TDim)) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + assert lowered.id == "foo" + assert lowered.expr == im.ref("inp") From bca403f43a828400eb25332bd328f659799d4dce Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 5 Aug 2024 09:59:22 +0200 Subject: [PATCH 09/29] math builtins --- src/gt4py/next/ffront/foast_to_gtir.py | 10 +++--- .../ffront_tests/test_foast_to_gtir.py | 31 +++++++++++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index e76b0186ec..a8967ef9ed 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -26,7 +26,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, TYPE_BUILTIN_NAMES +from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts @@ -301,8 +301,8 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: if type_info.type_class(node.func.type) is ts.FieldType: return self._visit_shift(node, **kwargs) - # elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - # return self._visit_math_built_in(node, **kwargs) + elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: + return self._visit_math_built_in(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in ( FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES ): @@ -365,8 +365,8 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self.visit(node.args[0], **kwargs) - # def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - # return self._map(self.visit(node.func, **kwargs), *node.args) + def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + return self._map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 19b7743d52..c617ed589f 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -21,19 +21,20 @@ from types import SimpleNamespace +import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum, broadcast, max_over, min_over +from gt4py.next import broadcast, float32, float64, int32, int64, max_over, min_over, neighbor_sum from gt4py.next.ffront import type_specifications as ts_ffront from gt4py.next.ffront.ast_passes import single_static_assign as ssa +from gt4py.next.ffront.fbuiltins import exp, minimum from gt4py.next.ffront.foast_to_gtir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.type_system import type_specifications as ts, type_translation, type_info -import numpy as np +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation IDim = gtx.Dimension("IDim") @@ -448,6 +449,30 @@ def foo( assert lowered.expr == reference +def test_unary_math_builtin(): + def foo(a: gtx.Field[[TDim], float64]): + return exp(a) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop("exp")("a") + + assert lowered.expr == reference + + +def test_binary_math_builtin(): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + return minimum(a, b) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop("minimum")("a", "b") + + assert lowered.expr == reference + + def test_premap_to_local_field(): def foo(edge_f: gtx.Field[gtx.Dims[Edge], float64]): return edge_f(V2E) From ce8ea5659ce8c598096775e96533b40482f63746 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 5 Aug 2024 10:59:36 +0200 Subject: [PATCH 10/29] cleanup merge conflict --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 86 ++++++------------- 1 file changed, 25 insertions(+), 61 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0f00ac009d..02eead7fb3 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -256,12 +256,26 @@ def lift(expr): return call(call("lift")(expr)) -def as_fieldop(stencil, domain=None): - """Creates a field_operator from a stencil.""" - args = [stencil] - if domain is not None: - args.append(domain) - return call(call("as_fieldop")(*args)) +def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: + """ + Create an `as_fieldop` call. + Examples + -------- + >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) + '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' + """ + return call( + call("as_fieldop")( + *( + ( + expr, + domain, + ) + if domain + else (expr,) + ) + ) + ) class let: @@ -420,61 +434,6 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl -def op_as_fieldop( - op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None -) -> Callable[..., itir.FunCall]: - """ - Promotes a function `op` to a field_operator. - - `op` is a function from values to value. - - Returns: - A function from Fields to Field. - - Examples - -------- - >>> str(op_as_fieldop("op")("a", "b")) - '(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)' - """ - if isinstance(op, (str, itir.SymRef, itir.Lambda)): - op = call(op) - - def _impl(*its: itir.Expr) -> itir.FunCall: - args = [ - f"__arg{i}" for i in range(len(its)) - ] # TODO: `op` must not contain `SymRef(id="__argX")` - return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its) - - return _impl - - -def map_(op): - """Create a `map_` call.""" - return call(call("map_")(op)) - - -def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: - """ - Create an `as_fieldop` call. - Examples - -------- - >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) - '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' - """ - return call( - call("as_fieldop")( - *( - ( - expr, - domain, - ) - if domain - else (expr,) - ) - ) - ) - - def op_as_fieldop( op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None ) -> Callable[..., itir.FunCall]: @@ -502,3 +461,8 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its) return _impl + + +def map_(op): + """Create a `map_` call.""" + return call(call("map_")(op)) From 17f891d11cf329d9fd3e8eefd4a115940a43b825 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 6 Aug 2024 15:36:22 +0200 Subject: [PATCH 11/29] add where lowering --- src/gt4py/eve/utils.py | 2 + src/gt4py/next/ffront/foast_to_gtir.py | 39 ++++++++--- src/gt4py/next/ffront/lowering_utils.py | 31 +++++++++ src/gt4py/next/utils.py | 68 +++++++++++++++++-- .../ffront_tests/test_foast_to_gtir.py | 50 +++++++++++++- 5 files changed, 174 insertions(+), 16 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 2c2d4b6c58..6c5c8d9f2a 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1655,3 +1655,5 @@ def to_set(self) -> Set[T]: xiter = XIterable + +__all__ = ["toolz"] diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index a8967ef9ed..89d8aca795 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -15,13 +15,14 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, utils as eve_utils from gt4py.eve.extended_typing import Never -from gt4py.eve.utils import UIDGenerator +from gt4py.next import utils as next_utils from gt4py.next.ffront import ( dialect_ast_enums, fbuiltins, field_operator_ast as foast, + lowering_utils, stages as ffront_stages, type_specifications as ts_ffront, ) @@ -69,7 +70,9 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): [Sym(id=SymbolName('inp'))] """ - uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) + uid_generator: eve_utils.UIDGenerator = dataclasses.field( + default_factory=eve_utils.UIDGenerator + ) @classmethod def apply(cls, node: foast.LocatedNode) -> itir.Expr: @@ -350,15 +353,29 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: # obj.type, # ) - # def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - # condition, true_value, false_value = node.args + def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + if not isinstance(node.type, ts.TupleType): + return im.op_as_fieldop("if_")(*self.visit(node.args)) - # lowered_condition = self.visit(condition, **kwargs) - # return lowering_utils.process_elements( - # lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv), - # [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], - # node.type, - # ) + cond_ = self.visit(node.args[0]) + cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" + + true_structure = lowering_utils.expand_tuple_expr(self.visit(node.args[1]), node.type) + false_structure = lowering_utils.expand_tuple_expr(self.visit(node.args[2]), node.type) + + tree_zip = next_utils.tree_map(result_collection_type=list)(lambda x, y: (x, y))( + true_structure, false_structure + ) + + def create_if(true_false_: tuple[itir.Expr, itir.Expr]) -> itir.FunCall: + true_, false_ = true_false_ + return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_) + + result = next_utils.tree_map( + collection_type=list, result_collection_type=lambda x: im.make_tuple(*x) + )(create_if)(tree_zip) + + return im.let(cond_symref_name, cond_)(result) # _visit_concat_where = _visit_where diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index cde34f315a..510b79d8c1 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -14,6 +14,7 @@ from typing import Any, Callable, TypeVar from gt4py.eve import utils as eve_utils +from gt4py.next import utils as next_utils from gt4py.next.ffront import type_info as ti_ffront from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -153,3 +154,33 @@ def _process_elements_impl( result = process_func(*_current_el_exprs) return result + + +def expand_tuple_expr(tup: itir.Expr, tup_type: ts.TypeSpec) -> tuple[itir.Expr | tuple, ...]: + """ + Create a tuple of `tuple_get` calls on `tup` by using the structure provided by `tup_type`. + + Examples: + >>> expand_tuple_expr( + ... itir.SymRef(id="tup"), + ... ts.TupleType( + ... types=[ + ... ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), + ... ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), + ... ] + ... ), + ... ) + (FunCall(fun=SymRef(id=SymbolRef('tuple_get')), args=[Literal(value='0', type=ScalarType(kind=, shape=None)), SymRef(id=SymbolRef('tup'))]), FunCall(fun=SymRef(id=SymbolRef('tuple_get')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), SymRef(id=SymbolRef('tup'))])) + """ + + def _tup_get(index_and_type: tuple[int, ts.TypeSpec]) -> itir.Expr: + i, _ = index_and_type + return im.tuple_get(i, tup) + + res = next_utils.tree_map(collection_type=list, result_collection_type=tuple)(_tup_get)( + next_utils.tree_enumerate( + tup_type, collection_type=ts.TupleType, result_collection_type=list + ) + ) + assert isinstance(res, tuple) # for mypy + return res diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 6aa70dab5b..e48ea9733b 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -15,6 +15,8 @@ import functools from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload +from gt4py.eve.utils import toolz + class RecursionGuard: """ @@ -62,7 +64,11 @@ def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: # TODO(havogt): remove flatten duplications in the whole codebase -def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]: +def flatten_nested_tuple( + value: tuple[ + _T | tuple, ... + ], # `_T` omitted on purpose as type of `value`, to properly deduce `_T` on the user-side +) -> tuple[_T, ...]: if isinstance(value, tuple): return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting else: @@ -75,14 +81,18 @@ def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, . @overload def tree_map( - *, collection_type: type | tuple[type, ...], result_collection_type: Optional[type] = None + *, + collection_type: type | tuple[type, ...] = tuple, + result_collection_type: Optional[type | Callable] = None, ) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ... def tree_map( *args: Callable[_P, _R], collection_type: type | tuple[type, ...] = tuple, - result_collection_type: Optional[type] = None, + result_collection_type: Optional[ + type | Callable + ] = None, # TODO consider renaming to `result_collection_constructor` ) -> ( Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]] @@ -112,7 +122,7 @@ def tree_map( if result_collection_type is None: if isinstance(collection_type, tuple): raise TypeError( - "tree_map() requires `result_collection_type` when `collection_type` is a tuple." + "tree_map() requires `result_collection_type` when `collection_type` is a tuple of types." ) result_collection_type = collection_type @@ -142,3 +152,53 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: raise TypeError( "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_type`." ) + + +_Type = TypeVar("_Type", bound=type) +_RType = TypeVar("_RType", bound=type) + + +@overload +def tree_enumerate( + collection: _Type | _T, + collection_type: _Type | tuple[_Type, ...], + result_collection_type: _RType, +) -> _RType: ... + + +@overload +def tree_enumerate( + collection: _Type | _T, + collection_type: _Type | tuple[_Type, ...], + result_collection_type: Callable[[_Type | _T], _R] = toolz.identity, +) -> _R: ... + + +def tree_enumerate( + collection: _Type | _T, + collection_type: _Type | tuple[_Type, ...] = tuple, + result_collection_type: _RType | Callable[[_Type | _T], _R] = toolz.identity, +) -> _R | _RType: + """ + Recursively `enumerate`s elements in a nested collection. + + Examples: + >>> tree_enumerate("a") + 'a' + + >>> for elem in tree_enumerate(("a",)): + ... elem + (0, 'a') + + >>> for elem in tree_enumerate(("a", "b")): + ... elem + (0, 'a') + (1, 'b') + + >>> tree_enumerate(("a", ("b", "c")), result_collection_type=list) + [(0, 'a'), (1, [(0, 'b'), (1, 'c')])] + """ + return tree_map( + collection_type=collection_type, + result_collection_type=toolz.compose(result_collection_type, enumerate), + )(toolz.identity)(collection) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index c617ed589f..f16cd8c780 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -25,7 +25,17 @@ import pytest import gt4py.next as gtx -from gt4py.next import broadcast, float32, float64, int32, int64, max_over, min_over, neighbor_sum +from gt4py.next import ( + broadcast, + float32, + float64, + int32, + int64, + max_over, + min_over, + neighbor_sum, + where, +) from gt4py.next.ffront import type_specifications as ts_ffront from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.fbuiltins import exp, minimum @@ -33,6 +43,7 @@ from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_lambdas from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -133,6 +144,43 @@ def foo(inp: gtx.Field[[TDim], float64]): assert lowered.expr == reference +def test_where(): + def foo( + a: gtx.Field[[TDim], bool], b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64] + ): + return where(a, b, c) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop("if_")("a", "b", "c") + + assert lowered.expr == reference + + +def test_where_tuple(): + def foo( + a: gtx.Field[[TDim], bool], + b: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]], + c: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]], + ): + return where(a, b, c) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + lowered_inlined = inline_lambdas.InlineLambdas.apply( + lowered + ) # we generate a let for the condition which is removed by inlining for easier testing + + reference = im.make_tuple( + im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), + im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), + ) + + assert lowered_inlined.expr == reference + + # TODO (introduce neg/pos) # def test_unary_ops(): # def unary(inp: gtx.Field[[TDim], float64]): From c77380a705a59ef14899cdaa3ba72cd701bed54a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 7 Aug 2024 18:11:50 +0200 Subject: [PATCH 12/29] cond cases --- src/gt4py/next/ffront/foast_to_gtir.py | 109 ++++++++---------- src/gt4py/next/iterator/ir_utils/ir_makers.py | 7 +- .../ffront_tests/test_foast_to_gtir.py | 68 ++++++++++- 3 files changed, 124 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 89d8aca795..f2436b2aac 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -22,6 +22,7 @@ dialect_ast_enums, fbuiltins, field_operator_ast as foast, + foast_introspection, lowering_utils, stages as ffront_stages, type_specifications as ts_ffront, @@ -168,63 +169,61 @@ def visit_BlockStmt( assert inner_expr return inner_expr - # def visit_IfStmt( - # self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - # ) -> itir.Expr: - # # the lowered if call doesn't need to be lifted as the condition can only originate - # # from a scalar value (and not a field) - # assert ( - # isinstance(node.condition.type, ts.ScalarType) - # and node.condition.type.kind == ts.ScalarKind.BOOL - # ) + def visit_IfStmt( + self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any + ) -> itir.Expr: + assert ( + isinstance(node.condition.type, ts.ScalarType) + and node.condition.type.kind == ts.ScalarKind.BOOL + ) - # cond = self.visit(node.condition, **kwargs) + cond = self.visit(node.condition, **kwargs) - # return_kind: StmtReturnKind = deduce_stmt_return_kind(node) + return_kind = foast_introspection.deduce_stmt_return_kind(node) - # common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols + common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols - # if return_kind is StmtReturnKind.NO_RETURN: - # # pack the common symbols into a tuple - # common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) + if return_kind is foast_introspection.StmtReturnKind.NO_RETURN: + # TODO document why this case should be handled in this way, not by the more general CONDITIONAL_RETURN - # # apply both branches and extract the common symbols through the prepared tuple - # true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) - # false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) + # pack the common symbols into a tuple + common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) - # # unpack the common symbols' tuple for `inner_expr` - # for i, sym in enumerate(common_symbols.keys()): - # inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) + # apply both branches and extract the common symbols through the prepared tuple + true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) + false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) - # # here we assume neither branch returns - # return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - # inner_expr - # ) - # elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: - # common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) - # common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) + # unpack the common symbols' tuple for `inner_expr` + for i, sym in enumerate(common_symbols.keys()): + inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) - # # wrap the inner expression in a lambda function. note that this increases the - # # operation count if both branches are evaluated. - # inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") - # inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) - # inner_expr = im.call(inner_expr_name)(*common_symrefs) + # here we assume neither branch returns + return im.let("__if_stmt_result", im.cond(cond, true_branch, false_branch))(inner_expr) + elif return_kind is foast_introspection.StmtReturnKind.CONDITIONAL_RETURN: + common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) + common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) - # true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - # false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) + # wrap the inner expression in a lambda function. note that this increases the + # operation count if both branches are evaluated. + inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") + inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) + inner_expr = im.call(inner_expr_name)(*common_symrefs) - # return im.let(inner_expr_name, inner_expr_evaluator)( - # im.if_(im.deref(cond), true_branch, false_branch) - # ) + true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) + false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - # assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN + return im.let(inner_expr_name, inner_expr_evaluator)( + im.cond(cond, true_branch, false_branch) + ) + + assert return_kind is foast_introspection.StmtReturnKind.UNCONDITIONAL_RETURN - # # note that we do not duplicate `inner_expr` here since if both branches - # # return, `inner_expr` is ignored. - # true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - # false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) + # note that we do not duplicate `inner_expr` here since if both branches + # return, `inner_expr` is ignored. + true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) + false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - # return im.if_(im.deref(cond), true_branch, false_branch) + return im.cond(cond, true_branch, false_branch) def visit_Assign( self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any @@ -264,20 +263,14 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: return self._map(node.op.value, node.left, node.right) - # def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: - # op = "if_" - # args = (node.condition, node.true_expr, node.false_expr) - # lowered_args: list[itir.Expr] = [ - # lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) - # for arg in args - # ] - # if any(type_info.contains_local_field(arg.type) for arg in args): - # lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - # op = im.call("map_")(op) - - # return lowering_utils.to_tuples_of_iterator( - # im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type - # ) + def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: + assert ( + isinstance(node.condition.type, ts.ScalarType) + and node.condition.type.kind == ts.ScalarKind.BOOL + ) + return im.cond( + self.visit(node.condition), self.visit(node.true_expr), self.visit(node.false_expr) + ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: return self._map(node.op.value, node.left, node.right) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 02eead7fb3..5c54913bee 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -247,10 +247,15 @@ def tuple_get(index: str | int, tuple_expr): def if_(cond, true_val, false_val): - """Create a not_ FunCall, shorthand for ``call("if_")(expr)``.""" + """Create a if_ FunCall, shorthand for ``call("if_")(expr)``.""" return call("if_")(cond, true_val, false_val) +def cond(cond, true_val, false_val): + """Create a cond FunCall, shorthand for ``call("cond")(expr)``.""" + return call("cond")(cond, true_val, false_val) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index f16cd8c780..998b9192e0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -43,7 +43,7 @@ from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.iterator.transforms import collapse_tuple, inline_lambdas from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -181,6 +181,72 @@ def foo( assert lowered_inlined.expr == reference +def test_ternary(): + def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): + return b if a else c + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.cond("a", "b", "c") + + assert lowered.expr == reference + + +def test_if_unconditional_return(): + def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): + if a: + return b + else: + return c + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.cond("a", "b", "c") + + assert lowered.expr == reference + + +def test_if_no_return(): + def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): + if a: + res = b + else: + res = c + return res + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined) + print(lowered_inlined) + + reference = im.tuple_get(0, im.cond("a", im.make_tuple("b"), im.make_tuple("c"))) + + assert lowered_inlined.expr == reference + + +def test_if_conditional_return(): + def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): + if a: + res = b + else: + if a: + return c + res = b + return res + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined) + + reference = im.cond("a", "b", im.cond("a", "c", "b")) + + assert lowered_inlined.expr == reference + + # TODO (introduce neg/pos) # def test_unary_ops(): # def unary(inp: gtx.Field[[TDim], float64]): From 3b4c4a21d359732f9bdef6df73553b81f1a6daef Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Aug 2024 08:20:20 +0200 Subject: [PATCH 13/29] implement astype lowering --- src/gt4py/next/ffront/foast_to_gtir.py | 46 +++++------ src/gt4py/next/ffront/lowering_utils.py | 78 +++++++------------ .../ffront_tests/test_foast_to_gtir.py | 34 ++++++++ 3 files changed, 84 insertions(+), 74 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index f2436b2aac..5e97b9237c 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -17,7 +17,6 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, utils as eve_utils from gt4py.eve.extended_typing import Never -from gt4py.next import utils as next_utils from gt4py.next.ffront import ( dialect_ast_enums, fbuiltins, @@ -335,16 +334,19 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: f"Call to object of type '{type(node.func.type).__name__}' not understood." ) - # def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - # assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - # obj, new_type = node.args[0], node.args[1].id - # return lowering_utils.process_elements( - # lambda x: im.promote_to_lifted_stencil( - # im.lambda_("it")(im.call("cast_")("it", str(new_type))) - # )(x), - # self.visit(obj, **kwargs), - # obj.type, - # ) + def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) + obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id + + def create_cast(expr: itir.Expr) -> itir.FunCall: + return im.as_fieldop( + im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) + )(expr) + + if not isinstance(node.type, ts.TupleType): + return create_cast(obj) + + return lowering_utils.process_elements(create_cast, obj, node.type) def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: if not isinstance(node.type, ts.TupleType): @@ -353,20 +355,18 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - true_structure = lowering_utils.expand_tuple_expr(self.visit(node.args[1]), node.type) - false_structure = lowering_utils.expand_tuple_expr(self.visit(node.args[2]), node.type) - - tree_zip = next_utils.tree_map(result_collection_type=list)(lambda x, y: (x, y))( - true_structure, false_structure - ) + def create_if(cond: itir.Expr): + def create_if_impl(true_false_: tuple[itir.Expr, itir.Expr]) -> itir.FunCall: + true_, false_ = true_false_ + return im.op_as_fieldop("if_")(cond, true_, false_) - def create_if(true_false_: tuple[itir.Expr, itir.Expr]) -> itir.FunCall: - true_, false_ = true_false_ - return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_) + return create_if_impl - result = next_utils.tree_map( - collection_type=list, result_collection_type=lambda x: im.make_tuple(*x) - )(create_if)(tree_zip) + result = lowering_utils.process_elements( + create_if(im.ref(cond_symref_name)), + (self.visit(node.args[1]), self.visit(node.args[2])), + node.type, + ) return im.let(cond_symref_name, cond_)(result) diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 510b79d8c1..2232f251fe 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +from collections.abc import Iterable from typing import Any, Callable, TypeVar from gt4py.eve import utils as eve_utils @@ -102,60 +103,9 @@ def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall: return im.let(param, expr)(im.lift(im.lambda_(*lift_params)(stencil_expr))(*lift_args)) -# TODO(tehrengruber): The code quality of this function is poor. We should rewrite it. -def process_elements( - process_func: Callable[..., itir.Expr], - objs: itir.Expr | list[itir.Expr], - current_el_type: ts.TypeSpec, -) -> itir.FunCall: - """ - Recursively applies a processing function to all primitive constituents of a tuple. - - Arguments: - process_func: A callable that takes an itir.Expr representing a leaf-element of `objs`. - If multiple `objs` are given the callable takes equally many arguments. - objs: The object whose elements are to be transformed. - current_el_type: A type with the same structure as the elements of `objs`. The leaf-types - are not used and thus not relevant. - """ - if isinstance(objs, itir.Expr): - objs = [objs] - - _current_el_exprs = [ - im.ref(f"__val_{eve_utils.content_hash(obj)}") for i, obj in enumerate(objs) - ] - body = _process_elements_impl(process_func, _current_el_exprs, current_el_type) - - return im.let(*((f"__val_{eve_utils.content_hash(obj)}", obj) for i, obj in enumerate(objs)))( # type: ignore[arg-type] # mypy not smart enough - body - ) - - T = TypeVar("T", bound=itir.Expr, covariant=True) -def _process_elements_impl( - process_func: Callable[..., itir.Expr], _current_el_exprs: list[T], current_el_type: ts.TypeSpec -) -> itir.Expr: - if isinstance(current_el_type, ts.TupleType): - result = im.make_tuple( - *[ - _process_elements_impl( - process_func, - [im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs], - current_el_type.types[i], - ) - for i in range(len(current_el_type.types)) - ] - ) - elif type_info.contains_local_field(current_el_type): - raise NotImplementedError("Processing fields with local dimension is not implemented.") - else: - result = process_func(*_current_el_exprs) - - return result - - def expand_tuple_expr(tup: itir.Expr, tup_type: ts.TypeSpec) -> tuple[itir.Expr | tuple, ...]: """ Create a tuple of `tuple_get` calls on `tup` by using the structure provided by `tup_type`. @@ -184,3 +134,29 @@ def _tup_get(index_and_type: tuple[int, ts.TypeSpec]) -> itir.Expr: ) assert isinstance(res, tuple) # for mypy return res + + +def process_elements( + process_func: Callable[..., itir.Expr], + objs: itir.Expr | Iterable[itir.Expr], + current_el_type: ts.TypeSpec, +) -> itir.FunCall: + """ + Arguments: + process_func: A callable that takes an itir.Expr representing a leaf-element of `objs`. + If multiple `objs` are given the callable takes equally many arguments. + objs: The object whose elements are to be transformed. + current_el_type: A type with the same structure as the elements of `objs`. The leaf-types + are not used and thus not relevant. + """ + if isinstance(objs, itir.Expr): + objs = (objs,) + zipper = lambda x: x + else: + zipper = lambda *x: x + expanded = [expand_tuple_expr(arg, current_el_type) for arg in objs] + tree_zip = next_utils.tree_map(result_collection_type=list)(zipper)(*expanded) + result = next_utils.tree_map( + collection_type=list, result_collection_type=lambda x: im.make_tuple(*x) + )(process_func)(tree_zip) + return result diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 998b9192e0..21b8b8afeb 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -26,6 +26,7 @@ import gt4py.next as gtx from gt4py.next import ( + astype, broadcast, float32, float64, @@ -247,6 +248,39 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): assert lowered_inlined.expr == reference +def test_astype(): + def foo(a: gtx.Field[[TDim], float64]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( + "a" + ) + + assert lowered.expr == reference + + +def test_astype_tuple(): + def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.make_tuple( + im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( + im.tuple_get(0, "a") + ), + im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( + im.tuple_get(1, "a") + ), + ) + + assert lowered.expr == reference + + # TODO (introduce neg/pos) # def test_unary_ops(): # def unary(inp: gtx.Field[[TDim], float64]): From ab42da235c9b5826cd4906e96c418cd647554dd2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Aug 2024 08:48:36 +0200 Subject: [PATCH 14/29] fix process_elements for itir lowering --- src/gt4py/next/ffront/lowering_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 2232f251fe..a0a33c5ad0 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -154,6 +154,10 @@ def process_elements( zipper = lambda x: x else: zipper = lambda *x: x + + if not isinstance(current_el_type, ts.TupleType): + return process_func(*objs) + expanded = [expand_tuple_expr(arg, current_el_type) for arg in objs] tree_zip = next_utils.tree_map(result_collection_type=list)(zipper)(*expanded) result = next_utils.tree_map( From ef40997b79639ffcbe1e9a2ca432fef0c5828247 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Aug 2024 14:43:24 +0200 Subject: [PATCH 15/29] fix pure scalar op lowering --- src/gt4py/next/ffront/foast_to_gtir.py | 6 ++++++ .../ffront_tests/test_foast_to_gtir.py | 18 +++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 5e97b9237c..4a98aa0436 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -426,6 +426,12 @@ def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: lowered_args = [self.visit(arg, **kwargs) for arg in args] + if all( + isinstance(t, ts.ScalarType) + for arg in args + for t in type_info.primitive_constituents(arg.type) + ): + return im.call(op)(*lowered_args) # scalar operation if any(type_info.contains_local_field(arg.type) for arg in args): lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] op = im.call("map_")(op) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 21b8b8afeb..46ef143c26 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -81,7 +81,7 @@ def foo(): assert lowered.expr == im.make_tuple(im.literal_from_value(1.0), im.literal_from_value(True)) -def test_scalar_arg(): +def test_field_and_scalar_arg(): def foo(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: return alpha * bar @@ -94,6 +94,18 @@ def foo(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64] assert lowered.expr == reference +def test_scalar_arg_only(): + def foo(bar: int64, alpha: int64) -> int64: + return alpha * bar + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.call("multiplies")("alpha", "bar") + + assert lowered.expr == reference + + def test_multicopy(): def foo(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): return inp1, inp2 @@ -448,7 +460,7 @@ def foo(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: reference = im.let( ssa.unique_name("tmp", 0), - im.op_as_fieldop("plus")( + im.call("plus")( im.literal("1", "int32"), im.literal("1", "int32"), ), @@ -536,7 +548,7 @@ def foo() -> bool: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("greater")( + reference = im.call("greater")( im.literal("3", "int32"), im.literal("4", "int32"), ) From 1bd0797b3e0a92bd199cddb70f0e56ab0ea56099 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Aug 2024 16:00:34 +0200 Subject: [PATCH 16/29] improve typing, unary --- src/gt4py/next/ffront/foast_to_gtir.py | 87 ++++--------------- src/gt4py/next/ffront/foast_to_itir.py | 4 +- src/gt4py/next/ffront/lowering_utils.py | 7 +- src/gt4py/next/utils.py | 11 ++- .../ffront_tests/test_foast_to_gtir.py | 54 +++++++----- .../ffront_tests/test_past_to_gtir.py | 13 +-- 6 files changed, 59 insertions(+), 117 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4a98aa0436..7c885d0934 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -1,16 +1,11 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import dataclasses from typing import Any, Callable, Optional @@ -97,60 +92,10 @@ def visit_FieldOperator( id=func_definition.id, params=func_definition.params, expr=new_body ) - # def visit_ScanOperator( - # self, node: foast.ScanOperator, **kwargs: Any - # ) -> itir.FunctionDefinition: - # # note: we don't need the axis here as this is handled by the program - # # decorator - # assert isinstance(node.type, ts_ffront.ScanOperatorType) - - # # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. - # # In iterator IR we didn't properly specify if this is legal, - # # however after lift-inlining the expressions are transformed back to literals. - # forward = im.deref(self.visit(node.forward, **kwargs)) - # init = lowering_utils.process_elements( - # im.deref, self.visit(node.init, **kwargs), node.init.type - # ) - - # # lower definition function - # func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - # new_body = im.let( - # func_definition.params[0].id, - # # promote carry to iterator of tuples - # # (this is the only place in the lowering were a variable is captured in a lifted lambda) - # lowering_utils.to_tuples_of_iterator( - # im.promote_to_const_iterator(func_definition.params[0].id), - # [*node.type.definition.pos_or_kw_args.values()][0], # [unnecessary-iterable-allocation-for-first-element] - # ), - # )( - # # the function itself returns a tuple of iterators, deref element-wise - # lowering_utils.process_elements( - # im.deref, func_definition.expr, node.type.definition.returns - # ) - # ) - - # stencil_args: list[itir.Expr] = [] - # assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args - # for param, arg_type in zip( - # func_definition.params[1:], - # [*node.type.definition.pos_or_kw_args.values()][1:], - # strict=True, - # ): - # if isinstance(arg_type, ts.TupleType): - # # convert into iterator of tuples - # stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) - - # new_body = im.let( - # param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) - # )(new_body) - # else: - # stencil_args.append(im.ref(param.id)) - - # definition = itir.Lambda(params=func_definition.params, expr=new_body) - - # body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - - # return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) + def visit_ScanOperator( + self, node: foast.ScanOperator, **kwargs: Any + ) -> itir.FunctionDefinition: + raise NotImplementedError("TODO") def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: raise AssertionError("Statements must always be visited in the context of a function.") @@ -251,13 +196,11 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._map("not_", node.operand) - raise NotImplementedError("TODO neg/pos") - - # return self._map( - # node.op.value, - # foast.Constant(value="0", type=dtype, location=node.location), - # node.operand, - # ) + return self._map( + node.op.value, + foast.Constant(value="0", type=dtype, location=node.location), + node.operand, + ) def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: return self._map(node.op.value, node.left, node.right) @@ -334,7 +277,7 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: f"Call to object of type '{type(node.func.type).__name__}' not understood." ) - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id @@ -355,7 +298,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - def create_if(cond: itir.Expr): + def create_if(cond: itir.Expr) -> Callable[[tuple[itir.Expr, itir.Expr]], itir.FunCall]: def create_if_impl(true_false_: tuple[itir.Expr, itir.Expr]) -> itir.FunCall: true_, false_ = true_false_ return im.op_as_fieldop("if_")(cond, true_, false_) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index b5d84d4bd6..766ed05837 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -332,7 +332,7 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: f"Call to object of type '{type(node.func.type).__name__}' not understood." ) - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = node.args[0], node.args[1].id return lowering_utils.process_elements( @@ -343,7 +343,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: obj.type, ) - def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: condition, true_value, false_value = node.args lowered_condition = self.visit(condition, **kwargs) diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index b79bc14054..8d317f0fda 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Iterable from typing import Any, Callable, TypeVar from gt4py.eve import utils as eve_utils @@ -134,7 +135,7 @@ def process_elements( process_func: Callable[..., itir.Expr], objs: itir.Expr | Iterable[itir.Expr], current_el_type: ts.TypeSpec, -) -> itir.FunCall: +) -> itir.Expr: """ Arguments: process_func: A callable that takes an itir.Expr representing a leaf-element of `objs`. @@ -143,11 +144,9 @@ def process_elements( current_el_type: A type with the same structure as the elements of `objs`. The leaf-types are not used and thus not relevant. """ + zipper = (lambda x: x) if isinstance(objs, itir.Expr) else (lambda *x: x) if isinstance(objs, itir.Expr): objs = (objs,) - zipper = lambda x: x - else: - zipper = lambda *x: x if not isinstance(current_el_type, ts.TupleType): return process_func(*objs) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 7bfa068244..440e2b41f9 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -78,7 +78,9 @@ def tree_map( *, collection_type: type | tuple[type, ...] = tuple, result_collection_type: Optional[type | Callable] = None, -) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ... +) -> Callable[ + [Callable[_P, _R]], Callable[..., Any] +]: ... # TODO(havogt): if result_collection_type is Callable, improve typing def tree_map( @@ -87,10 +89,7 @@ def tree_map( result_collection_type: Optional[ type | Callable ] = None, # TODO consider renaming to `result_collection_constructor` -) -> ( - Callable[..., _R | tuple[_R | tuple, ...]] - | Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]] -): +) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: """ Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s). @@ -170,7 +169,7 @@ def tree_enumerate( def tree_enumerate( collection: _Type | _T, - collection_type: _Type | tuple[_Type, ...] = tuple, + collection_type: _Type | tuple[_Type, ...] = tuple, # type: ignore[assignment] # don't understand why mypy complains result_collection_type: _RType | Callable[[_Type | _T], _R] = toolz.identity, ) -> _R | _RType: """ diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 46ef143c26..63a5d7530f 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -1,3 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + # GT4Py - GridTools Framework # # Copyright (c) 2014-2023, ETH Zurich @@ -294,30 +302,28 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): # TODO (introduce neg/pos) -# def test_unary_ops(): -# def unary(inp: gtx.Field[[TDim], float64]): -# tmp = +inp -# tmp = -tmp -# return tmp - -# parsed = FieldOperatorParser.apply_to_function(unary) -# lowered = FieldOperatorLowering.apply(parsed) - -# reference = im.let( -# ssa.unique_name("tmp", 0), -# im.promote_to_lifted_stencil("plus")( -# im.promote_to_const_iterator(im.literal("0", "float64")), "inp" -# ), -# )( -# im.let( -# ssa.unique_name("tmp", 1), -# im.promote_to_lifted_stencil("minus")( -# im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) -# ), -# )(ssa.unique_name("tmp", 1)) -# ) - -# assert lowered.expr == reference +def test_unary_minus(): + def foo(inp: gtx.Field[[TDim], float64]): + return -inp + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop("minus")(im.literal("0", "float64"), "inp") + + assert lowered.expr == reference + + +def test_unary_plus(): + def foo(inp: gtx.Field[[TDim], float64]): + return +inp + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop("plus")(im.literal("0", "float64"), "inp") + + assert lowered.expr == reference def test_unpacking(): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index 4dacba9546..8543c904f8 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -1,16 +1,11 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import re From f7792ff454d791e6ea3a6c8656b7986ee4db28aa Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Aug 2024 16:52:23 +0200 Subject: [PATCH 17/29] as_offset --- src/gt4py/next/ffront/foast_to_gtir.py | 27 +++++++++---------- .../ffront_tests/test_foast_to_gtir.py | 18 ++++++++++++- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 7c885d0934..ce1e1adb92 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -221,20 +221,22 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: match node.args[0]: case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): shift_offset = im.shift(offset_name, offset_index) + return im.as_fieldop(im.lambda_("__it")(im.deref(shift_offset("__it"))))( + self.visit(node.func, **kwargs) + ) case foast.Name(id=offset_name): return im.as_fieldop_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # case foast.Call(func=foast.Name(id="as_offset")): - # func_args = node.args[0] - # offset_dim = func_args.args[0] - # assert isinstance(offset_dim, foast.Name) - # shift_offset = im.shift( - # offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs)) - # ) + case foast.Call(func=foast.Name(id="as_offset")): + # TODO(havogt): discuss this representation + func_args = node.args[0] + offset_dim = func_args.args[0] + assert isinstance(offset_dim, foast.Name) + shift_offset = im.shift(offset_dim.id, im.deref("__offset")) + return im.as_fieldop( + im.lambda_("__it", "__offset")(im.deref(shift_offset("__it"))) + )(self.visit(node.func, **kwargs), self.visit(func_args.args[1], **kwargs)) case _: raise FieldOperatorLoweringError("Unexpected shift arguments!") - return im.as_fieldop(im.lambda_("it")(im.deref(shift_offset("it"))))( - self.visit(node.func, **kwargs) - ) def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: if type_info.type_class(node.func.type) is ts.FieldType: @@ -267,9 +269,6 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: # scan operators return an iterator of tuples, transform into tuples of iterator again if isinstance(node.func.type, ts_ffront.ScanOperatorType): raise NotImplementedError("TODO") - # result = lowering_utils.to_tuples_of_iterator( - # result, node.func.type.definition.returns - # ) return result @@ -313,7 +312,7 @@ def create_if_impl(true_false_: tuple[itir.Expr, itir.Expr]) -> itir.FunCall: return im.let(cond_symref_name, cond_)(result) - # _visit_concat_where = _visit_where + _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self.visit(node.args[0], **kwargs) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 63a5d7530f..48a17e16e0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -47,6 +47,7 @@ ) from gt4py.next.ffront import type_specifications as ts_ffront from gt4py.next.ffront.ast_passes import single_static_assign as ssa +from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.fbuiltins import exp, minimum from gt4py.next.ffront.foast_to_gtir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -64,6 +65,7 @@ V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. +TOff = gtx.FieldOffset("TDim", source=TDim, target=(TDim,)) UDim = gtx.Dimension("UDim") # Meaningless dimension, used for tests. @@ -135,7 +137,21 @@ def foo(inp: gtx.Field[[IDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") + reference = im.as_fieldop(im.lambda_("__it")(im.deref(im.shift("Ioff", 1)("__it"))))("inp") + + assert lowered.expr == reference + + +def test_as_offset(): + def foo(inp: gtx.Field[[TDim], float64], offset: gtx.Field[[TDim], int]): + return inp(as_offset(TOff, offset)) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.as_fieldop( + im.lambda_("__it", "__offset")(im.deref(im.shift("TOff", im.deref("__offset"))("__it"))) + )("inp", "offset") assert lowered.expr == reference From 0ea0f7984e195de4ad41ec8246218f5154b6387d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Aug 2024 16:55:33 +0200 Subject: [PATCH 18/29] remove old license --- .../ffront_tests/test_foast_to_gtir.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 48a17e16e0..975ef714ac 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -6,24 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later -# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the -# lowering can (and often does) make all of them fail. Once we have embedded field view we want to -# switch to executing the different cases here; once with a regular backend (i.e. including -# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering -# should be correct. from __future__ import annotations From 70dd170f1168eead50de0d89e18d97284438e29d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Aug 2024 16:57:36 +0200 Subject: [PATCH 19/29] disable gtir embedded --- .../feature_tests/ffront_tests/ffront_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index ed6fb7ab8c..b465ba84be 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -38,7 +38,7 @@ def __call__(self, program, *args, **kwargs) -> None: @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, + # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # TODO(havogt): enable once all incredients for GTIR are available # noqa: ERA001 next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, From 99f86c6b6f076d68337f6217b674199fa36a33c2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Aug 2024 17:02:15 +0200 Subject: [PATCH 20/29] cleanup use of dimension --- .../ffront_tests/test_foast_to_gtir.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 975ef714ac..753608aaa0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -40,15 +40,14 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -IDim = gtx.Dimension("IDim") Edge = gtx.Dimension("Edge") Vertex = gtx.Dimension("Vertex") -Cell = gtx.Dimension("Cell") V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. + +TDim = gtx.Dimension("TDim") TOff = gtx.FieldOffset("TDim", source=TDim, target=(TDim,)) -UDim = gtx.Dimension("UDim") # Meaningless dimension, used for tests. +UDim = gtx.Dimension("UDim") def test_return(): @@ -74,7 +73,7 @@ def foo(): def test_field_and_scalar_arg(): - def foo(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: + def foo(bar: gtx.Field[[TDim], int64], alpha: int64) -> gtx.Field[[TDim], int64]: return alpha * bar # TODO document that scalar arguments of `as_fieldop(stencil)` are promoted to 0-d fields @@ -99,7 +98,7 @@ def foo(bar: int64, alpha: int64) -> int64: def test_multicopy(): - def foo(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): + def foo(inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64]): return inp1, inp2 parsed = FieldOperatorParser.apply_to_function(foo) @@ -111,15 +110,13 @@ def foo(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): def test_premap(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def foo(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[1]) + def foo(inp: gtx.Field[[TDim], float64]): + return inp(TOff[1]) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.as_fieldop(im.lambda_("__it")(im.deref(im.shift("Ioff", 1)("__it"))))("inp") + reference = im.as_fieldop(im.lambda_("__it")(im.deref(im.shift("TOff", 1)("__it"))))("inp") assert lowered.expr == reference @@ -443,7 +440,7 @@ def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_add_scalar_literal_to_field(): - def foo(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: + def foo(a: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: return 2.0 + a parsed = FieldOperatorParser.apply_to_function(foo) @@ -455,7 +452,7 @@ def foo(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: def test_add_scalar_literals(): - def foo(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: + def foo(a: gtx.Field[[TDim], "int32"]) -> gtx.Field[[TDim], "int32"]: tmp = int32(1) + int32("1") return a + tmp @@ -522,7 +519,7 @@ def foo(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): def test_scalar_and(): - def foo(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: + def foo(a: gtx.Field[[TDim], "bool"]) -> gtx.Field[[TDim], "bool"]: return a & False parsed = FieldOperatorParser.apply_to_function(foo) @@ -598,8 +595,8 @@ def foo(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): def test_compare_chain(): def foo( - a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] - ) -> gtx.Field[[IDim], bool]: + a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64] + ) -> gtx.Field[[TDim], bool]: return a > b > c parsed = FieldOperatorParser.apply_to_function(foo) From 3458aee3d9ff65dcb62a45a4deb7716d0269956a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 15 Aug 2024 12:21:59 +0200 Subject: [PATCH 21/29] cleanup --- src/gt4py/next/ffront/foast_to_gtir.py | 5 +- src/gt4py/next/ffront/lowering_utils.py | 80 +++++++++---------- src/gt4py/next/field_utils.py | 2 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 4 +- .../iterator/transforms/collapse_tuple.py | 4 +- .../inline_center_deref_lift_vars.py | 2 +- .../iterator/transforms/propagate_deref.py | 2 +- src/gt4py/next/type_system/type_info.py | 1 - src/gt4py/next/utils.py | 80 ++++--------------- .../ffront_tests/test_foast_to_gtir.py | 34 +++++++- tests/next_tests/unit_tests/test_utils.py | 4 +- 11 files changed, 94 insertions(+), 124 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index ce1e1adb92..abd07e4747 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -297,9 +297,8 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - def create_if(cond: itir.Expr) -> Callable[[tuple[itir.Expr, itir.Expr]], itir.FunCall]: - def create_if_impl(true_false_: tuple[itir.Expr, itir.Expr]) -> itir.FunCall: - true_, false_ = true_false_ + def create_if(cond: itir.Expr) -> Callable[[itir.Expr, itir.Expr], itir.FunCall]: + def create_if_impl(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: return im.op_as_fieldop("if_")(cond, true_, false_) return create_if_impl diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 8d317f0fda..49389f1704 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -10,7 +10,6 @@ from typing import Any, Callable, TypeVar from gt4py.eve import utils as eve_utils -from gt4py.next import utils as next_utils from gt4py.next.ffront import type_info as ti_ffront from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -98,45 +97,15 @@ def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall: return im.let(param, expr)(im.lift(im.lambda_(*lift_params)(stencil_expr))(*lift_args)) -T = TypeVar("T", bound=itir.Expr, covariant=True) - - -def expand_tuple_expr(tup: itir.Expr, tup_type: ts.TypeSpec) -> tuple[itir.Expr | tuple, ...]: - """ - Create a tuple of `tuple_get` calls on `tup` by using the structure provided by `tup_type`. - - Examples: - >>> expand_tuple_expr( - ... itir.SymRef(id="tup"), - ... ts.TupleType( - ... types=[ - ... ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), - ... ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), - ... ] - ... ), - ... ) - (FunCall(fun=SymRef(id=SymbolRef('tuple_get')), args=[Literal(value='0', type=ScalarType(kind=, shape=None)), SymRef(id=SymbolRef('tup'))]), FunCall(fun=SymRef(id=SymbolRef('tuple_get')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), SymRef(id=SymbolRef('tup'))])) - """ - - def _tup_get(index_and_type: tuple[int, ts.TypeSpec]) -> itir.Expr: - i, _ = index_and_type - return im.tuple_get(i, tup) - - res = next_utils.tree_map(collection_type=list, result_collection_type=tuple)(_tup_get)( - next_utils.tree_enumerate( - tup_type, collection_type=ts.TupleType, result_collection_type=list - ) - ) - assert isinstance(res, tuple) # for mypy - return res - - +# TODO(tehrengruber): The code quality of this function is poor. We should rewrite it. def process_elements( process_func: Callable[..., itir.Expr], objs: itir.Expr | Iterable[itir.Expr], current_el_type: ts.TypeSpec, -) -> itir.Expr: +) -> itir.FunCall: """ + Recursively applies a processing function to all primitive constituents of a tuple. + Arguments: process_func: A callable that takes an itir.Expr representing a leaf-element of `objs`. If multiple `objs` are given the callable takes equally many arguments. @@ -144,16 +113,41 @@ def process_elements( current_el_type: A type with the same structure as the elements of `objs`. The leaf-types are not used and thus not relevant. """ - zipper = (lambda x: x) if isinstance(objs, itir.Expr) else (lambda *x: x) if isinstance(objs, itir.Expr): objs = (objs,) - if not isinstance(current_el_type, ts.TupleType): - return process_func(*objs) + let_ids = tuple(f"__val_{eve_utils.content_hash(obj)}" for obj in objs) + body = _process_elements_impl( + process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type + ) + + return im.let(*(zip(let_ids, objs)))(body) + + +T = TypeVar("T", bound=itir.Expr, covariant=True) + + +def _process_elements_impl( + process_func: Callable[..., itir.Expr], + _current_el_exprs: Iterable[T], + current_el_type: ts.TypeSpec, +) -> itir.Expr: + if isinstance(current_el_type, ts.TupleType): + result = im.make_tuple( + *[ + _process_elements_impl( + process_func, + tuple( + im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs + ), + current_el_type.types[i], + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") + else: + result = process_func(*_current_el_exprs) - expanded = [expand_tuple_expr(arg, current_el_type) for arg in objs] - tree_zip = next_utils.tree_map(result_collection_type=list)(zipper)(*expanded) - result = next_utils.tree_map( - collection_type=list, result_collection_type=lambda x: im.make_tuple(*x) - )(process_func)(tree_zip) return result diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py index 42004b2c4b..65865709ba 100644 --- a/src/gt4py/next/field_utils.py +++ b/src/gt4py/next/field_utils.py @@ -46,7 +46,7 @@ def field_from_typespec( (NumPyArrayField(... dtype=int32...), NumPyArrayField(... dtype=float32...)) """ - @utils.tree_map(collection_type=ts.TupleType, result_collection_type=tuple) + @utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=tuple) def impl(type_: ts.ScalarType) -> common.MutableField: res = common._field( xp.empty(domain.shape, dtype=xp.dtype(type_translation.as_dtype(type_).scalar_type)), diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index c2b9e934f9..cb6e9912d3 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import typing -from typing import Callable, Iterable, Optional, Union +from typing import Callable, Optional, Union from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -293,7 +293,7 @@ class let: def __init__(self, var: str | itir.Sym, init_form: itir.Expr | str): ... @typing.overload - def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr | str]]): ... + def __init__(self, *args: tuple[str | itir.Sym, itir.Expr | str]): ... def __init__(self, *args): if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7a54718a28..40d98208dd 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -253,7 +253,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. new_args.append(arg) if bound_vars: - return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough + return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) return None def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: @@ -298,7 +298,7 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: inner_vars[arg_sym] = arg if outer_vars: return self.fp_transform( - im.let(*outer_vars.items())( # type: ignore[arg-type] # mypy not smart enough + im.let(*outer_vars.items())( self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) ) ) diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 8a13d8d1b1..cdea1a7a48 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -89,6 +89,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): im.call(node.fun)(*new_args), eligible_params=eligible_params ) # TODO(tehrengruber): propagate let outwards - return im.let(*bound_scalars.items())(new_node) # type: ignore[arg-type] # mypy not smart enough + return im.let(*bound_scalars.items())(new_node) return node diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 8c78bc3171..a892d89c13 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -40,7 +40,7 @@ def visit_FunCall(self, node: ir.FunCall): if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]): fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let args: list[ir.Expr] = node.args[0].args - node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough + node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index e1bce10c05..372d21613e 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -136,7 +136,6 @@ class TupleConstructorType(Protocol, Generic[_R]): def __call__(self, *args: Any) -> _R: ... -# TODO(havogt): the complicated typing is a hint that this function needs refactoring def apply_to_primitive_constituents( fun: Callable[..., _T], *symbol_types: ts.TypeSpec, diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 440e2b41f9..44fa929e56 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -9,8 +9,6 @@ import functools from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload -from gt4py.eve.utils import toolz - class RecursionGuard: """ @@ -77,18 +75,16 @@ def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, . def tree_map( *, collection_type: type | tuple[type, ...] = tuple, - result_collection_type: Optional[type | Callable] = None, + result_collection_constructor: Optional[type | Callable] = None, ) -> Callable[ [Callable[_P, _R]], Callable[..., Any] -]: ... # TODO(havogt): if result_collection_type is Callable, improve typing +]: ... # TODO(havogt): if result_collection_constructor is Callable, improve typing def tree_map( *args: Callable[_P, _R], collection_type: type | tuple[type, ...] = tuple, - result_collection_type: Optional[ - type | Callable - ] = None, # TODO consider renaming to `result_collection_constructor` + result_collection_constructor: Optional[type | Callable] = None, ) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: """ Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s). @@ -96,7 +92,7 @@ def tree_map( Args: fun: Function to apply to each entry of the collection. collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types. - result_collection_type: Type of the collection to be returned. If `None` the same type as `collection_type` is used. + result_collection_constructor: Type of the collection to be returned. If `None` the same type as `collection_type` is used. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) @@ -108,16 +104,18 @@ def tree_map( >>> tree_map(collection_type=list)(lambda x: x + 1)([[1, 2], 3]) [[2, 3], 4] - >>> tree_map(collection_type=list, result_collection_type=tuple)(lambda x: x + 1)([[1, 2], 3]) + >>> tree_map(collection_type=list, result_collection_constructor=tuple)(lambda x: x + 1)( + ... [[1, 2], 3] + ... ) ((2, 3), 4) """ - if result_collection_type is None: + if result_collection_constructor is None: if isinstance(collection_type, tuple): raise TypeError( - "tree_map() requires `result_collection_type` when `collection_type` is a tuple of types." + "tree_map() requires `result_collection_constructor` when `collection_type` is a tuple of types." ) - result_collection_type = collection_type + result_collection_constructor = collection_type if len(args) == 1: fun = args[0] @@ -128,8 +126,8 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert all( isinstance(arg, collection_type) and len(args[0]) == len(arg) for arg in args ) - assert result_collection_type is not None - return result_collection_type(impl(*arg) for arg in zip(*args)) + assert result_collection_constructor is not None + return result_collection_constructor(impl(*arg) for arg in zip(*args)) return fun( *cast(_P.args, args) @@ -140,58 +138,8 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: return functools.partial( tree_map, collection_type=collection_type, - result_collection_type=result_collection_type, + result_collection_constructor=result_collection_constructor, ) raise TypeError( - "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_type`." + "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`." ) - - -_Type = TypeVar("_Type", bound=type) -_RType = TypeVar("_RType", bound=type) - - -@overload -def tree_enumerate( - collection: _Type | _T, - collection_type: _Type | tuple[_Type, ...], - result_collection_type: _RType, -) -> _RType: ... - - -@overload -def tree_enumerate( - collection: _Type | _T, - collection_type: _Type | tuple[_Type, ...], - result_collection_type: Callable[[_Type | _T], _R] = toolz.identity, -) -> _R: ... - - -def tree_enumerate( - collection: _Type | _T, - collection_type: _Type | tuple[_Type, ...] = tuple, # type: ignore[assignment] # don't understand why mypy complains - result_collection_type: _RType | Callable[[_Type | _T], _R] = toolz.identity, -) -> _R | _RType: - """ - Recursively `enumerate`s elements in a nested collection. - - Examples: - >>> tree_enumerate("a") - 'a' - - >>> for elem in tree_enumerate(("a",)): - ... elem - (0, 'a') - - >>> for elem in tree_enumerate(("a", "b")): - ... elem - (0, 'a') - (1, 'b') - - >>> tree_enumerate(("a", ("b", "c")), result_collection_type=list) - [(0, 'a'), (1, [(0, 'b'), (1, 'c')])] - """ - return tree_map( - collection_type=collection_type, - result_collection_type=toolz.compose(result_collection_type, enumerate), - )(toolz.identity)(collection) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 753608aaa0..6cdf4d261d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -236,7 +236,6 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): lowered = FieldOperatorLowering.apply(parsed) lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined) - print(lowered_inlined) reference = im.tuple_get(0, im.cond("a", im.make_tuple("b"), im.make_tuple("c"))) @@ -283,6 +282,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( @@ -293,7 +293,37 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): ), ) - assert lowered.expr == reference + assert lowered_inlined.expr == reference + + +def test_astype_nested_tuple(): + def foo( + a: tuple[ + tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]], + gtx.Field[[TDim], float64], + ], + ): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) + + reference = im.make_tuple( + im.make_tuple( + im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( + im.tuple_get(0, im.tuple_get(0, "a")) + ), + im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( + im.tuple_get(1, im.tuple_get(0, "a")) + ), + ), + im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( + im.tuple_get(1, "a") + ), + ) + + assert lowered_inlined.expr == reference # TODO (introduce neg/pos) diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index f031d0f99c..a04ffecb9b 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -39,7 +39,7 @@ def testee(x): def test_tree_map_custom_output_type(): - @utils.tree_map(result_collection_type=list) + @utils.tree_map(result_collection_constructor=list) def testee(x): return x + 1 @@ -47,7 +47,7 @@ def testee(x): def test_tree_map_multiple_input_types(): - @utils.tree_map(collection_type=(list, tuple), result_collection_type=tuple) + @utils.tree_map(collection_type=(list, tuple), result_collection_constructor=tuple) def testee(x): return x + 1 From d46f3089911c4c7deb1d5051d8c197dc018f9ddf Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Aug 2024 08:55:30 +0200 Subject: [PATCH 22/29] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- src/gt4py/next/ffront/foast_to_gtir.py | 4 +--- src/gt4py/next/ffront/past_to_itir.py | 6 +++--- src/gt4py/next/iterator/transforms/propagate_deref.py | 2 +- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 7 +++---- .../unit_tests/ffront_tests/test_past_to_gtir.py | 4 ++-- 5 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index abd07e4747..3c68a203e0 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -86,10 +86,8 @@ def visit_FieldOperator( ) -> itir.FunctionDefinition: func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = func_definition.expr - return itir.FunctionDefinition( - id=func_definition.id, params=func_definition.params, expr=new_body + id=func_definition.id, params=func_definition.params, expr=func_definition.expr ) def visit_ScanOperator( diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 775a50b5e7..1b94bbbca1 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -33,7 +33,7 @@ @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): - to_gtir: bool = False + to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) @@ -151,7 +151,7 @@ class ProgramLowering( """ grid_type: common.GridType - to_gtir: bool = False # TODO(havogt): remove after refactoring to GTIR + to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. @@ -162,7 +162,7 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, - to_gtir: bool = False, + to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR ) -> itir.FencilDefinition: return cls(grid_type=grid_type, to_gtir=to_gtir).visit( node, function_definitions=function_definitions diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index a892d89c13..dd6edb4f4a 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -40,7 +40,7 @@ def visit_FunCall(self, node: ir.FunCall): if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]): fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let args: list[ir.Expr] = node.args[0].args - node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) + node = im.let(*zip(fun.params, args, strict=true)(im.deref(fun.expr)) elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 6cdf4d261d..849788c11b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -97,7 +97,7 @@ def foo(bar: int64, alpha: int64) -> int64: assert lowered.expr == reference -def test_multicopy(): +def test_multivalue_identity(): def foo(inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64]): return inp1, inp2 @@ -326,7 +326,6 @@ def foo( assert lowered_inlined.expr == reference -# TODO (introduce neg/pos) def test_unary_minus(): def foo(inp: gtx.Field[[TDim], float64]): return -inp @@ -600,10 +599,10 @@ def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): def test_compare_lt(): - def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): + def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a < b - parsed = FieldOperatorParser.apply_to_function(comp_lt) + parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop("less")("a", "b") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index 8543c904f8..a6231c22a7 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -125,7 +125,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) ], ), ) - fencil_pattern = P( + program_pattern = P( itir.Program, id=eve.SymbolName("copy_restrict_program"), params=[ @@ -137,7 +137,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) body=[set_at_pattern], ) - fencil_pattern.match(itir_node, raise_exception=True) + program_pattern.match(itir_node, raise_exception=True) def test_tuple_constructed_in_out_with_slicing(make_tuple_op): From cdc8963db40c1d7cf2b68b3264a5ef57a8204e27 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Aug 2024 09:11:29 +0200 Subject: [PATCH 23/29] fix import --- src/gt4py/next/ffront/foast_to_gtir.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3c68a203e0..301f0de695 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -10,10 +10,12 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator, PreserveLocationVisitor, utils as eve_utils +from gt4py import eve +from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Never from gt4py.next.ffront import ( dialect_ast_enums, + experimental as experimental_builtins, fbuiltins, field_operator_ast as foast, foast_introspection, @@ -21,8 +23,6 @@ stages as ffront_stages, type_specifications as ts_ffront, ) -from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts @@ -39,7 +39,7 @@ def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], it @dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): +class FieldOperatorLowering(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -239,14 +239,14 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: if type_info.type_class(node.func.type) is ts.FieldType: return self._visit_shift(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: + elif isinstance(node.func, foast.Name) and node.func.id in fbuiltins.MATH_BUILTIN_NAMES: return self._visit_math_built_in(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in ( - FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES + fbuiltins.FUN_BUILTIN_NAMES + experimental_builtins.EXPERIMENTAL_FUN_BUILTIN_NAMES ): visitor = getattr(self, f"_visit_{node.func.id}") return visitor(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: + elif isinstance(node.func, foast.Name) and node.func.id in fbuiltins.TYPE_BUILTIN_NAMES: return self._visit_type_constr(node, **kwargs) elif isinstance( node.func.type, From e17ef1c479d47201edfdac5d93d704cc31228078 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Aug 2024 11:00:22 +0200 Subject: [PATCH 24/29] address review comments --- src/gt4py/next/ffront/foast_to_gtir.py | 9 ++++++--- src/gt4py/next/ffront/lowering_utils.py | 6 +++--- src/gt4py/next/otf/stages.py | 2 +- .../feature_tests/ffront_tests/ffront_test_utils.py | 2 +- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 7 ++----- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 301f0de695..790556d425 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -41,10 +41,13 @@ def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], it @dataclasses.dataclass class FieldOperatorLowering(eve.PreserveLocationVisitor, eve.NodeTranslator): """ - Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). + Lower FieldOperator AST (FOAST) to GTIR. - The strategy is to lower every expression to lifted stencils, - i.e. taking iterators and returning iterator. + Most expressions are lowered to `as_fieldop`ed stencils. + Pure scalar expressions are kept as scalar operations as they might appear outside of the stencil context, + e.g. in `cond`. + In arithemtic operations that involve a field and a scalar, the scalar is implicitly broadcasted to a field + in the `as_fieldop` call. Examples -------- diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 49389f1704..2f63be5687 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -121,7 +121,7 @@ def process_elements( process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type ) - return im.let(*(zip(let_ids, objs)))(body) + return im.let(*(zip(let_ids, objs, strict=True)))(body) T = TypeVar("T", bound=itir.Expr, covariant=True) @@ -134,7 +134,7 @@ def _process_elements_impl( ) -> itir.Expr: if isinstance(current_el_type, ts.TupleType): result = im.make_tuple( - *[ + *( _process_elements_impl( process_func, tuple( @@ -143,7 +143,7 @@ def _process_elements_impl( current_el_type.types[i], ) for i in range(len(current_el_type.types)) - ] + ) ) elif type_info.contains_local_field(current_el_type): raise NotImplementedError("Processing fields with local dimension is not implemented.") diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index e47fecad42..78ee8fede9 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -26,7 +26,7 @@ @dataclasses.dataclass(frozen=True) class ProgramCall: - """Iterator IR representaion of a program together with arguments to be passed to it.""" + """ITIR/GTIR representation of a program together with arguments to be passed to it.""" program: itir.FencilDefinition | itir.Program args: tuple[Any, ...] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index b465ba84be..87379acac3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -38,7 +38,7 @@ def __call__(self, program, *args, **kwargs) -> None: @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # TODO(havogt): enable once all incredients for GTIR are available # noqa: ERA001 + # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all incredients for GTIR are available next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 849788c11b..09aa0b8176 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -33,11 +33,9 @@ from gt4py.next.ffront.fbuiltins import exp, minimum from gt4py.next.ffront.foast_to_gtir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import collapse_tuple, inline_lambdas -from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.type_system import type_specifications as ts, type_translation Edge = gtx.Dimension("Edge") @@ -76,7 +74,6 @@ def test_field_and_scalar_arg(): def foo(bar: gtx.Field[[TDim], int64], alpha: int64) -> gtx.Field[[TDim], int64]: return alpha * bar - # TODO document that scalar arguments of `as_fieldop(stencil)` are promoted to 0-d fields parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) From 789a21563b40a9ec5baf16842e11a40dae94d73c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Aug 2024 11:03:19 +0200 Subject: [PATCH 25/29] undo unneeded change --- src/gt4py/eve/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 6186146cd1..8cb68845d7 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1651,5 +1651,3 @@ def to_set(self) -> Set[T]: xiter = XIterable - -__all__ = ["toolz"] From 156b884e45c93115b602df5911d306d57d77d6f9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Aug 2024 11:23:57 +0200 Subject: [PATCH 26/29] domain in itir maker --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 40a2d575f2..a7dc201db9 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -363,7 +363,9 @@ def lifted_neighbors(offset, it) -> itir.Expr: return lift(lambda_("it")(neighbors(offset, "it")))(it) -def as_fieldop_neighbors(offset, it) -> itir.Expr: +def as_fieldop_neighbors( + offset: str | itir.OffsetLiteral, it: str | itir.Expr, domain: Optional[itir.FunCall] = None +) -> itir.Expr: """ Create a fieldop for neighbors call. @@ -372,7 +374,7 @@ def as_fieldop_neighbors(offset, it) -> itir.Expr: >>> str(as_fieldop_neighbors("off", "a")) '(⇑(λ(it) → neighbors(offₒ, it)))(a)' """ - return as_fieldop(lambda_("it")(neighbors(offset, "it")))(it) + return as_fieldop(lambda_("it")(neighbors(offset, "it")), domain)(it) def promote_to_const_iterator(expr: str | itir.Expr) -> itir.Expr: From 9d2c0342c49b916bc3b6126d03a7cdc4d6d86954 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Aug 2024 11:25:50 +0200 Subject: [PATCH 27/29] TODO to FIXME --- src/gt4py/next/ffront/foast_to_gtir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 790556d425..649bd089de 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -129,7 +129,7 @@ def visit_IfStmt( common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols if return_kind is foast_introspection.StmtReturnKind.NO_RETURN: - # TODO document why this case should be handled in this way, not by the more general CONDITIONAL_RETURN + # FIXME[#1582](havogt): document why this case should be handled in this way, not by the more general CONDITIONAL_RETURN # pack the common symbols into a tuple common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) From efd35870c82ba6f0e623a546d0760f51f992d257 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 26 Aug 2024 09:46:04 +0200 Subject: [PATCH 28/29] address review comments --- src/gt4py/next/ffront/foast_to_gtir.py | 13 +++++-------- src/gt4py/next/ffront/foast_to_itir.py | 2 ++ src/gt4py/next/ffront/past_to_itir.py | 1 + 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 649bd089de..f7072925aa 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -286,26 +286,23 @@ def create_cast(expr: itir.Expr) -> itir.FunCall: im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) )(expr) - if not isinstance(node.type, ts.TupleType): + if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj) return lowering_utils.process_elements(create_cast, obj, node.type) def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - if not isinstance(node.type, ts.TupleType): + if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return im.op_as_fieldop("if_")(*self.visit(node.args)) cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - def create_if(cond: itir.Expr) -> Callable[[itir.Expr, itir.Expr], itir.FunCall]: - def create_if_impl(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: - return im.op_as_fieldop("if_")(cond, true_, false_) - - return create_if_impl + def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: + return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_) result = lowering_utils.process_elements( - create_if(im.ref(cond_symref_name)), + create_if, (self.visit(node.args[1]), self.visit(node.args[2])), node.type, ) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 766ed05837..dbcd6c8d47 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +# FIXME[#1582](havogt): remove after refactoring to GTIR + import dataclasses from typing import Any, Callable, Optional diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 1b94bbbca1..61a96694ee 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -248,6 +248,7 @@ def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir. target=output, ) + # FIXME[#1582](havogt): remove after refactoring to GTIR def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) From c94e1e940f2ff53e19f887de63136b4802522bf7 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 26 Aug 2024 09:52:00 +0200 Subject: [PATCH 29/29] add testcase for tuple in frozen namespace --- src/gt4py/next/ffront/foast_to_gtir.py | 4 ++++ .../ffront_tests/test_foast_to_gtir.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index f7072925aa..31168159dc 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -355,6 +355,10 @@ def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: ) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: + if isinstance(type_, ts.TupleType): + return im.make_tuple( + *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) + ) if isinstance(type_, ts.ScalarType): typename = type_.kind.name.lower() return im.literal(str(val), typename) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 09aa0b8176..367d536322 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -15,6 +15,7 @@ import pytest import gt4py.next as gtx +from gt4py.eve import utils as eve_utils from gt4py.next import ( astype, broadcast, @@ -496,6 +497,23 @@ def foo(a: gtx.Field[[TDim], "int32"]) -> gtx.Field[[TDim], "int32"]: assert lowered.expr == reference +def test_literal_tuple(): + tup = eve_utils.FrozenNamespace(a=(4, 2)) + + def foo(): + return tup.a + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.make_tuple( + im.literal("4", "int32"), + im.literal("2", "int32"), + ) + + assert lowered.expr == reference + + def test_binary_mult(): def foo(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a * b