From b36fcb3a3ce3cc8c82ceec2df068aabe7c503ae3 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Apr 2024 11:28:47 +0200 Subject: [PATCH 001/235] Skeleton for ITIR translation --- .../runners/dace_fieldview/__init__.py | 13 + .../runners/dace_fieldview/itir_taskgen.py | 54 ++++ .../runners/dace_fieldview/itir_to_sdfg.py | 233 ++++++++++++++++++ .../runners/dace_fieldview/itir_to_tasklet.py | 116 +++++++++ .../runners_tests/test_dace_fieldview.py | 119 +++++++++ 5 files changed, 535 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py new file mode 100644 index 0000000000..9b1371964b --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py @@ -0,0 +1,54 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import dace + + +class ItirTaskgenContext: + sdfg: dace.SDFG + state: dace.SDFGState + node_mapping: dict[str, dace.nodes.AccessNode] + symrefs: list[str] + + def __init__( + self, + current_sdfg: dace.SDFG, + current_state: dace.SDFGState, + ): + self.sdfg = current_sdfg + self.state = current_state + self.node_mapping = {} + self.symrefs = [] + + def add_node(self, data: str) -> dace.nodes.AccessNode: + assert data in self.sdfg.arrays + self.symrefs.append(data) + if data in self.node_mapping: + node = self.node_mapping[data] + else: + node = self.state.add_access(data) + self.node_mapping[data] = node + return node + + def clone(self) -> "ItirTaskgenContext": + ctx = ItirTaskgenContext(self.sdfg, self.state) + ctx.node_mapping = self.node_mapping + return ctx + + def tasklet_name(self) -> str: + return f"{self.state.label}_tasklet" + + def var_name(self) -> str: + return f"{self.state.label}_var" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py new file mode 100644 index 0000000000..af405e6150 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py @@ -0,0 +1,233 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Class to lower ITIR to SDFG. + +Note: this module covers the fieldview flavour of ITIR. +""" + +from collections import deque +from typing import Dict, List, Sequence, Tuple + +import dace + +from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts + +from .itir_taskgen import ItirTaskgenContext as TaskgenContext +from .itir_to_tasklet import ItirToTasklet + + +class ItirToSDFG(eve.NodeVisitor): + """Provides translation capability from an ITIR program to a DaCe SDFG. + + This class is responsible for translation of `ir.Program`, that is the top level representation + of a GT4Py program as a sequence of `it.Stmt` statements. + Each statement is translated to a taskgraph inside a separate state. The parent SDFG and + the translation state define the translation context, implemented by `TaskgenContext`. + Statement states are chained one after the other: potential concurrency between states should be + extracted by the DaCe SDFG transformations. + The program translation keeps track of entry and exit states: each statement is translated as + a new state inserted just before the exit state. Note that statements with branch execution might + result in more than one state. + """ + + _ctx_stack: deque[TaskgenContext] + _param_types: list[ts.TypeSpec] + + def __init__( + self, + param_types: list[ts.TypeSpec], + ): + self._ctx_stack = deque() + self._param_types = param_types + + def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: + # TODO define shape based on domain and dtype based on type inference + shape = [10] + dtype = dace.float64 + sdfg.add_array(name, shape, dtype, transient=False) + return + + def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Dict[str, str]: + raise NotImplementedError() + return {} + + def visit_Program(self, node: itir.Program) -> dace.SDFG: + """Translates `ir.Program` to `dace.SDFG`. + + First, it will allocate array and scalar storage for external (aka non-transient) + and local (aka transient) data. The local data, at this stage, is used + for temporary declarations, which should be available everywhere in the SDFG + but not outside. + Then, all statements are translated, one after the other in separate states. + """ + if node.function_definitions: + raise NotImplementedError("Functions expected to be inlined as lambda calls.") + + sdfg = dace.SDFG(node.id) + + # we use entry/exit state to keep track of entry/exit point of graph execution + entry_state = sdfg.add_state("program_entry", is_start_block=True) + + # declarations of temporaries result in local (aka transient) array definitions in the SDFG + if node.declarations: + temp_symbols: dict[str, str] = {} + for decl in node.declarations: + temp_symbols |= self._add_storage_for_temporary(decl) + + # define symbols for shape and offsets of temporary arrays as interstate edge symbols + # TODO(edopao): use new `add_state_after` function in next dace release + temp_state = sdfg.add_state("init_symbols_for_temporaries") + sdfg.add_edge(entry_state, temp_state, dace.InterstateEdge(assignments=temp_symbols)) + + exit_state = sdfg.add_state_after(temp_state, "program_exit") + else: + exit_state = sdfg.add_state_after(entry_state, "program_exit") + + # add global arrays (aka non-transient) to the SDFG + for param, type_ in zip(node.params, self._param_types): + self._add_storage(sdfg, str(param.id), type_) + + # create root context with exit state + root_ctx = TaskgenContext(sdfg, exit_state) + self._ctx_stack.append(root_ctx) + + # visit one statement at a time and put it into separate state + for i, stmt in enumerate(node.body): + stmt_state = sdfg.add_state_before(exit_state, f"stmt_{i}") + stmt_ctx = TaskgenContext(sdfg, stmt_state) + self._ctx_stack.append(stmt_ctx) + self.visit(stmt) + self._ctx_stack.pop() + + assert len(self._ctx_stack) == 1 + assert self._ctx_stack[-1] == root_ctx + + sdfg.validate() + return sdfg + + def visit_SetAt(self, stmt: itir.SetAt) -> None: + """Visits a statement expression and writes the local result to some external storage. + + Each statement expression results in some sort of taskgraph writing to local (aka transient) storage. + The translation of `SetAt` ensures that the result is written to the external storage. + """ + + assert len(self._ctx_stack) > 0 + ctx = self._ctx_stack[-1] + + # the statement expression will result in a tasklet writing to one or more local data nodes + self.visit(stmt.expr) + + # sanity check on stack status + assert ctx == self._ctx_stack[-1] + + # reset the list of visited symrefs to only discover output symrefs + tasklet_symrefs = ctx.symrefs.copy() + ctx.symrefs.clear() + + # the statement target will result in one or more access nodes to external data + self.visit(stmt.target) + target_symrefs = ctx.symrefs.copy() + + # sanity check on stack status + assert ctx == self._ctx_stack[-1] + + assert len(tasklet_symrefs) == len(target_symrefs) + for tasklet_sym, target_sym in zip(tasklet_symrefs, target_symrefs): + target_array = ctx.sdfg.arrays[target_sym] + assert not target_array.transient + + # TODO: visit statement domain to define the memlet subset + ctx.state.add_nedge( + ctx.node_mapping[tasklet_sym], + ctx.node_mapping[target_sym], + dace.Memlet.from_array(target_sym, target_array), + ) + + def _make_fieldop( + self, fun_node: itir.FunCall, fun_args: List[itir.Expr] + ) -> Sequence[Tuple[str, dace.nodes.AccessNode]]: + assert len(self._ctx_stack) != 0 + prev_ctx = self._ctx_stack[-1] + ctx = prev_ctx.clone() + self._ctx_stack.append(ctx) + + self.visit(fun_args) + + # create ordered list of input nodes + input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.symrefs] + + # TODO: define shape based on domain and dtype based on type inference + shape = [10] + dtype = dace.float64 + output_name, output_array = ctx.sdfg.add_array( + ctx.var_name(), shape, dtype, transient=True, find_new_name=True + ) + output_arrays = [(output_name, output_array)] + + assert len(fun_node.args) == 1 + tletgen = ItirToTasklet() + tlet_code, tlet_inputs, tlet_outputs = tletgen.visit(fun_node.args[0]) + + # TODO: define map range based on domain + map_ranges = dict(i="0:10") + + input_memlets: dict[str, dace.Memlet] = {} + for connector, (dname, _) in zip(tlet_inputs, input_arrays): + # TODO: define memlet subset based on domain + input_memlets[connector] = dace.Memlet(data=dname, subset="i") + + output_memlets: dict[str, dace.Memlet] = {} + output_nodes: list[Tuple[str, dace.nodes.AccessNode]] = [] + for connector, (dname, _) in zip(tlet_outputs, output_arrays): + # TODO: define memlet subset based on domain + output_memlets[connector] = dace.Memlet(data=dname, subset="i") + output_nodes.append((dname, ctx.add_node(dname))) + + ctx.state.add_mapped_tasklet( + ctx.tasklet_name(), + map_ranges, + input_memlets, + tlet_code, + output_memlets, + input_nodes=ctx.node_mapping, + output_nodes=ctx.node_mapping, + external_edges=True, + ) + + self._ctx_stack.pop() + assert prev_ctx == self._ctx_stack[-1] + + return output_nodes + + def visit_FunCall(self, node: itir.FunCall) -> None: + assert len(self._ctx_stack) > 0 + ctx = self._ctx_stack[-1] + + if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): + if node.fun.fun.id == "as_fieldop": + arg_nodes = self._make_fieldop(node.fun, node.args) + ctx.symrefs.extend([dname for dname, _ in arg_nodes]) + else: + raise NotImplementedError(f"Unexpected 'FunCall' with function {node.fun.fun.id}.") + else: + raise NotImplementedError(f"Unexpected 'FunCall' with type {type(node.fun)}.") + + def visit_SymRef(self, node: itir.SymRef) -> None: + dname = str(node.id) + ctx = self._ctx_stack[-1] + ctx.add_node(dname) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py new file mode 100644 index 0000000000..793127a926 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py @@ -0,0 +1,116 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Sequence, Tuple + +import numpy as np + +import gt4py.eve as eve +from gt4py.next.iterator import ir as itir + + +_MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + +class ItirToTasklet(eve.NodeVisitor): + """Translates ITIR to Python code to be used as tasklet body. + + This class is dace agnostic: it receives ITIR as input and produces Python code. + TODO: Use `TemplatedGenerator` to implement this functionality, see `EmbeddedDSL` implementation. + """ + + def _visit_deref(self, node: itir.FunCall) -> str: + if not isinstance(node.args[0], itir.SymRef): + raise NotImplementedError( + f"Unexpected 'deref' argument with type '{type(node.args[0])}'." + ) + return self.visit(node.args[0]) + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = [self.visit(arg_node) for arg_node in node.args] + return fmt.format(*args) + + def visit_Lambda(self, node: itir.Lambda) -> Tuple[str, Sequence[str], Sequence[str]]: + params = [str(p.id) for p in node.params] + tlet_code = "_out = " + self.visit(node.expr) + + return tlet_code, params, ["_out"] + + def visit_FunCall(self, node: itir.FunCall) -> str: + if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": + return self._visit_deref(node) + if isinstance(node.fun, itir.SymRef): + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: + return self._visit_numeric_builtin(node) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + raise NotImplementedError(f"Unexpected 'FunCall' with type '{type(node.fun)}'.") + + def visit_SymRef(self, node: itir.SymRef) -> str: + return str(node.id) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py new file mode 100644 index 0000000000..60b1c92d37 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -0,0 +1,119 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Test that ITIR can be lowered to SDFG. + +Note: this test module covers the fieldview flavour of ITIR. +""" + +from gt4py.next.common import Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.program_processors.runners.dace_fieldview.itir_to_sdfg import ( + ItirToSDFG as FieldviewItirToSDFG, +) +from gt4py.next.type_system import type_specifications as ts + +import numpy as np + +import pytest + +dace = pytest.importorskip("dace") + + +N = 10 +DIM = Dimension("D") +FTYPE = ts.FieldType(dims=[DIM], dtype=ts.ScalarKind.FLOAT64) + + +def test_itir_sum2(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, 10) + ) + testee = itir.Program( + id="sum_2fields", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="z")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))) + ) + )("x", "y"), + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + c = np.empty_like(a) + + sdfg_genenerator = FieldviewItirToSDFG( + param_types=([FTYPE] * 3), + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + sdfg(x=a, y=b, z=c) + assert np.allclose(c, (a + b)) + + +def test_itir_sum3(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, 10) + ) + testee = itir.Program( + id="sum_3fields", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="w"), itir.Sym(id="z")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))) + ) + )( + "x", + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))) + ) + )("y", "w"), + ), + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + c = np.random.rand(N) + d = np.empty_like(a) + + sdfg_genenerator = FieldviewItirToSDFG( + param_types=([FTYPE] * 4), + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + sdfg(x=a, y=b, w=c, z=d) + assert np.allclose(d, (a + b + c)) From 2020182e9c700b63406f27b00ecf4a17137ee182 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Apr 2024 11:49:59 +0200 Subject: [PATCH 002/235] Minor edit --- .../runners/dace_fieldview/itir_to_sdfg.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py index af405e6150..2850e7ffcd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py @@ -36,12 +36,13 @@ class ItirToSDFG(eve.NodeVisitor): This class is responsible for translation of `ir.Program`, that is the top level representation of a GT4Py program as a sequence of `it.Stmt` statements. Each statement is translated to a taskgraph inside a separate state. The parent SDFG and - the translation state define the translation context, implemented by `TaskgenContext`. + the translation state define the translation context, implemented by `ItirTaskgenContext`. Statement states are chained one after the other: potential concurrency between states should be extracted by the DaCe SDFG transformations. The program translation keeps track of entry and exit states: each statement is translated as a new state inserted just before the exit state. Note that statements with branch execution might - result in more than one state. + result in more than one state. However, each statement should provide a single termination state + (e.g. a join state for an if/else branch execution) on the exit state of the program SDFG. """ _ctx_stack: deque[TaskgenContext] @@ -114,7 +115,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: self._ctx_stack.pop() assert len(self._ctx_stack) == 1 - assert self._ctx_stack[-1] == root_ctx + self._ctx_stack.pop() sdfg.validate() return sdfg @@ -135,19 +136,17 @@ def visit_SetAt(self, stmt: itir.SetAt) -> None: # sanity check on stack status assert ctx == self._ctx_stack[-1] - # reset the list of visited symrefs to only discover output symrefs - tasklet_symrefs = ctx.symrefs.copy() - ctx.symrefs.clear() - # the statement target will result in one or more access nodes to external data + target_ctx = ctx.clone() + self._ctx_stack.append(target_ctx) self.visit(stmt.target) - target_symrefs = ctx.symrefs.copy() + self._ctx_stack.pop() # sanity check on stack status assert ctx == self._ctx_stack[-1] - assert len(tasklet_symrefs) == len(target_symrefs) - for tasklet_sym, target_sym in zip(tasklet_symrefs, target_symrefs): + assert len(ctx.symrefs) == len(target_ctx.symrefs) + for tasklet_sym, target_sym in zip(ctx.symrefs, target_ctx.symrefs): target_array = ctx.sdfg.arrays[target_sym] assert not target_array.transient From 5c6b6bafc8a7c9e80acdead91a36706e310a7c1f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Apr 2024 08:08:07 +0200 Subject: [PATCH 003/235] Use Python callstack as a context stack for the ITIR visitor --- .../dace_fieldview/fieldview_dataflow.py | 79 +++++++++++++ .../runners/dace_fieldview/itir_taskgen.py | 54 --------- .../runners/dace_fieldview/itir_to_sdfg.py | 110 ++++++++---------- .../runners/dace_fieldview/itir_to_tasklet.py | 4 +- 4 files changed, 130 insertions(+), 117 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py new file mode 100644 index 0000000000..c7d4e0a4ad --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py @@ -0,0 +1,79 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import dace + + +class FieldviewRegion: + """Defines the dataflow scope of a fieldview expression. + + This class defines a region of the dataflow which represents a fieldview expression. + It usually consists of a map scope, with a set of input nodes that traverse the entry map; + a set of transient data nodes (aka temporaries) where the output memlets traversing the + exit map will write to; and the compute nodes (tasklets) implementing the expression + within the map scope. + More than one fieldfiew region can exist within a state. In this case, the temporaies + which are written to by one fieldview region will be inputs to the next region. Also, + the set of access nodes `node_mapping` is shared among all fieldview regions within a state. + + We use this class as return type when we visit a fieldview expression. It can be extended + with all informatiion needed to construct the dataflow graph. + """ + sdfg: dace.SDFG + state: dace.SDFGState + node_mapping: dict[str, dace.nodes.AccessNode] + + # ordered list of input/output data nodes used by the field operator being built in this dataflow region + input_nodes: list[str] + output_nodes: list[str] + + def __init__( + self, + current_sdfg: dace.SDFG, + current_state: dace.SDFGState, + ): + self.sdfg = current_sdfg + self.state = current_state + self.node_mapping = {} + self.input_nodes = [] + self.output_nodes = [] + + def _add_node(self, data: str) -> dace.nodes.AccessNode: + assert data in self.sdfg.arrays + if data in self.node_mapping: + node = self.node_mapping[data] + else: + node = self.state.add_access(data) + self.node_mapping[data] = node + return node + + def add_input_node(self, data: str) -> dace.nodes.AccessNode: + self.input_nodes.append(data) + return self._add_node(data) + + def add_output_node(self, data: str) -> dace.nodes.AccessNode: + self.output_nodes.append(data) + return self._add_node(data) + + def clone(self) -> "FieldviewRegion": + ctx = FieldviewRegion(self.sdfg, self.state) + ctx.node_mapping = self.node_mapping + return ctx + + def tasklet_name(self) -> str: + return f"{self.state.label}_tasklet" + + def var_name(self) -> str: + return f"{self.state.label}_var" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py deleted file mode 100644 index 9b1371964b..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_taskgen.py +++ /dev/null @@ -1,54 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import dace - - -class ItirTaskgenContext: - sdfg: dace.SDFG - state: dace.SDFGState - node_mapping: dict[str, dace.nodes.AccessNode] - symrefs: list[str] - - def __init__( - self, - current_sdfg: dace.SDFG, - current_state: dace.SDFGState, - ): - self.sdfg = current_sdfg - self.state = current_state - self.node_mapping = {} - self.symrefs = [] - - def add_node(self, data: str) -> dace.nodes.AccessNode: - assert data in self.sdfg.arrays - self.symrefs.append(data) - if data in self.node_mapping: - node = self.node_mapping[data] - else: - node = self.state.add_access(data) - self.node_mapping[data] = node - return node - - def clone(self) -> "ItirTaskgenContext": - ctx = ItirTaskgenContext(self.sdfg, self.state) - ctx.node_mapping = self.node_mapping - return ctx - - def tasklet_name(self) -> str: - return f"{self.state.label}_tasklet" - - def var_name(self) -> str: - return f"{self.state.label}_var" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py index 2850e7ffcd..16ba459427 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py @@ -17,8 +17,7 @@ Note: this module covers the fieldview flavour of ITIR. """ -from collections import deque -from typing import Dict, List, Sequence, Tuple +from typing import Dict, List, Optional import dace @@ -26,7 +25,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts -from .itir_taskgen import ItirTaskgenContext as TaskgenContext +from .fieldview_dataflow import FieldviewRegion from .itir_to_tasklet import ItirToTasklet @@ -45,14 +44,13 @@ class ItirToSDFG(eve.NodeVisitor): (e.g. a join state for an if/else branch execution) on the exit state of the program SDFG. """ - _ctx_stack: deque[TaskgenContext] + _ctx: Optional[FieldviewRegion] _param_types: list[ts.TypeSpec] def __init__( self, param_types: list[ts.TypeSpec], ): - self._ctx_stack = deque() self._param_types = param_types def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: @@ -82,6 +80,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # we use entry/exit state to keep track of entry/exit point of graph execution entry_state = sdfg.add_state("program_entry", is_start_block=True) + exit_state = sdfg.add_state_after(entry_state, "program_exit") # declarations of temporaries result in local (aka transient) array definitions in the SDFG if node.declarations: @@ -91,31 +90,22 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # define symbols for shape and offsets of temporary arrays as interstate edge symbols # TODO(edopao): use new `add_state_after` function in next dace release - temp_state = sdfg.add_state("init_symbols_for_temporaries") - sdfg.add_edge(entry_state, temp_state, dace.InterstateEdge(assignments=temp_symbols)) - - exit_state = sdfg.add_state_after(temp_state, "program_exit") + head_state = sdfg.add_state_before(exit_state, "init_symbols_for_temporaries") + (sdfg.edges_between(entry_state, head_state))[0].assignments = temp_symbols else: - exit_state = sdfg.add_state_after(entry_state, "program_exit") + head_state = entry_state # add global arrays (aka non-transient) to the SDFG for param, type_ in zip(node.params, self._param_types): self._add_storage(sdfg, str(param.id), type_) - # create root context with exit state - root_ctx = TaskgenContext(sdfg, exit_state) - self._ctx_stack.append(root_ctx) - + self._ctx = FieldviewRegion(sdfg, head_state) # visit one statement at a time and put it into separate state - for i, stmt in enumerate(node.body): - stmt_state = sdfg.add_state_before(exit_state, f"stmt_{i}") - stmt_ctx = TaskgenContext(sdfg, stmt_state) - self._ctx_stack.append(stmt_ctx) + for stmt in node.body: self.visit(stmt) - self._ctx_stack.pop() - assert len(self._ctx_stack) == 1 - self._ctx_stack.pop() + assert self._ctx.state == head_state + self._ctx = None sdfg.validate() return sdfg @@ -127,48 +117,48 @@ def visit_SetAt(self, stmt: itir.SetAt) -> None: The translation of `SetAt` ensures that the result is written to the external storage. """ - assert len(self._ctx_stack) > 0 - ctx = self._ctx_stack[-1] + prev_ctx = self._ctx + assert prev_ctx is not None + + stmt_ctx = prev_ctx.clone() + stmt_ctx.state = prev_ctx.sdfg.add_state_after(prev_ctx.state, "set_at") # the statement expression will result in a tasklet writing to one or more local data nodes + self._ctx = stmt_ctx self.visit(stmt.expr) - # sanity check on stack status - assert ctx == self._ctx_stack[-1] - - # the statement target will result in one or more access nodes to external data - target_ctx = ctx.clone() - self._ctx_stack.append(target_ctx) + # the target expression could be a `SymRef` to an output node or a `make_tuple` expression + # in case the statement returns more than one field + self._ctx = stmt_ctx.clone() self.visit(stmt.target) - self._ctx_stack.pop() - - # sanity check on stack status - assert ctx == self._ctx_stack[-1] + # the visit of a target expression should only produce a set of access nodes (no tasklets, no output nodes) + assert len(self._ctx.output_nodes) == 0 + stmt_ctx.output_nodes.extend(self._ctx.input_nodes) - assert len(ctx.symrefs) == len(target_ctx.symrefs) - for tasklet_sym, target_sym in zip(ctx.symrefs, target_ctx.symrefs): - target_array = ctx.sdfg.arrays[target_sym] - assert not target_array.transient + assert len(stmt_ctx.input_nodes) == len(stmt_ctx.output_nodes) + for tasklet_node, target_node in zip(stmt_ctx.input_nodes, stmt_ctx.output_nodes): + target_array = stmt_ctx.sdfg.arrays[target_node] + target_array.transient = False # TODO: visit statement domain to define the memlet subset - ctx.state.add_nedge( - ctx.node_mapping[tasklet_sym], - ctx.node_mapping[target_sym], - dace.Memlet.from_array(target_sym, target_array), + stmt_ctx.state.add_nedge( + stmt_ctx.node_mapping[tasklet_node], + stmt_ctx.node_mapping[target_node], + dace.Memlet.from_array(target_node, target_array), ) - def _make_fieldop( - self, fun_node: itir.FunCall, fun_args: List[itir.Expr] - ) -> Sequence[Tuple[str, dace.nodes.AccessNode]]: - assert len(self._ctx_stack) != 0 - prev_ctx = self._ctx_stack[-1] + self._ctx = prev_ctx + + def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: + prev_ctx = self._ctx + assert prev_ctx is not None ctx = prev_ctx.clone() - self._ctx_stack.append(ctx) + self._ctx = ctx self.visit(fun_args) # create ordered list of input nodes - input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.symrefs] + input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] # TODO: define shape based on domain and dtype based on type inference shape = [10] @@ -179,6 +169,8 @@ def _make_fieldop( output_arrays = [(output_name, output_array)] assert len(fun_node.args) == 1 + assert isinstance(fun_node.args[0], itir.Lambda) + tletgen = ItirToTasklet() tlet_code, tlet_inputs, tlet_outputs = tletgen.visit(fun_node.args[0]) @@ -191,11 +183,10 @@ def _make_fieldop( input_memlets[connector] = dace.Memlet(data=dname, subset="i") output_memlets: dict[str, dace.Memlet] = {} - output_nodes: list[Tuple[str, dace.nodes.AccessNode]] = [] for connector, (dname, _) in zip(tlet_outputs, output_arrays): # TODO: define memlet subset based on domain output_memlets[connector] = dace.Memlet(data=dname, subset="i") - output_nodes.append((dname, ctx.add_node(dname))) + ctx.add_output_node(dname) ctx.state.add_mapped_tasklet( ctx.tasklet_name(), @@ -208,19 +199,16 @@ def _make_fieldop( external_edges=True, ) - self._ctx_stack.pop() - assert prev_ctx == self._ctx_stack[-1] - - return output_nodes + self._ctx = prev_ctx + return ctx def visit_FunCall(self, node: itir.FunCall) -> None: - assert len(self._ctx_stack) > 0 - ctx = self._ctx_stack[-1] - + assert self._ctx is not None if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): if node.fun.fun.id == "as_fieldop": - arg_nodes = self._make_fieldop(node.fun, node.args) - ctx.symrefs.extend([dname for dname, _ in arg_nodes]) + child_ctx = self._make_fieldop(node.fun, node.args) + assert child_ctx.state == self._ctx.state + self._ctx.input_nodes.extend(child_ctx.output_nodes) else: raise NotImplementedError(f"Unexpected 'FunCall' with function {node.fun.fun.id}.") else: @@ -228,5 +216,5 @@ def visit_FunCall(self, node: itir.FunCall) -> None: def visit_SymRef(self, node: itir.SymRef) -> None: dname = str(node.id) - ctx = self._ctx_stack[-1] - ctx.add_node(dname) + assert self._ctx is not None + self._ctx.add_input_node(dname) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py index 793127a926..64f047e4e2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py @@ -78,11 +78,11 @@ class ItirToTasklet(eve.NodeVisitor): """Translates ITIR to Python code to be used as tasklet body. - This class is dace agnostic: it receives ITIR as input and produces Python code. - TODO: Use `TemplatedGenerator` to implement this functionality, see `EmbeddedDSL` implementation. + TODO: this class needs to be revisited in next commit. """ def _visit_deref(self, node: itir.FunCall) -> str: + # TODO: build memlet subset / shift pattern for each tasklet connector if not isinstance(node.args[0], itir.SymRef): raise NotImplementedError( f"Unexpected 'deref' argument with type '{type(node.args[0])}'." From 60e1c69ee859a9f6089c5d84a6c8a0f606d42bb7 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Apr 2024 08:10:16 +0200 Subject: [PATCH 004/235] Format error --- .../runners/dace_fieldview/fieldview_dataflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py index c7d4e0a4ad..265add27f6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py @@ -31,6 +31,7 @@ class FieldviewRegion: We use this class as return type when we visit a fieldview expression. It can be extended with all informatiion needed to construct the dataflow graph. """ + sdfg: dace.SDFG state: dace.SDFGState node_mapping: dict[str, dace.nodes.AccessNode] From 073a0a4e88f7bc76c8e1f0297a4ea7bee9d8e8e7 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Apr 2024 15:28:11 +0200 Subject: [PATCH 005/235] Refactor tasklet codegen --- ..._to_tasklet.py => gtir_tasklet_codegen.py} | 37 +++++------ .../{itir_to_sdfg.py => gtir_to_sdfg.py} | 66 +++++++++++++------ .../runners/dace_fieldview/gtir_to_tasklet.py | 56 ++++++++++++++++ .../runners_tests/test_dace_fieldview.py | 12 ++-- 4 files changed, 122 insertions(+), 49 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{itir_to_tasklet.py => gtir_tasklet_codegen.py} (73%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{itir_to_sdfg.py => gtir_to_sdfg.py} (84%) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py similarity index 73% rename from src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 64f047e4e2..9b85e13650 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Sequence, Tuple +from typing import Sequence import numpy as np -import gt4py.eve as eve +from gt4py.eve import codegen from gt4py.next.iterator import ir as itir @@ -75,33 +75,26 @@ } -class ItirToTasklet(eve.NodeVisitor): - """Translates ITIR to Python code to be used as tasklet body. +class GtirTaskletCodegen(codegen.TemplatedGenerator): + """Translates GTIR to Python code to be used as tasklet body. - TODO: this class needs to be revisited in next commit. + This class is dace agnostic: it receives GTIR as input and produces Python code. """ - def _visit_deref(self, node: itir.FunCall) -> str: - # TODO: build memlet subset / shift pattern for each tasklet connector - if not isinstance(node.args[0], itir.SymRef): - raise NotImplementedError( - f"Unexpected 'deref' argument with type '{type(node.args[0])}'." - ) - return self.visit(node.args[0]) + def _visit_deref(self, node: itir.FunCall) -> list[str]: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + def _visit_numeric_builtin(self, node: itir.FunCall) -> Sequence[str]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = [self.visit(arg_node) for arg_node in node.args] - return fmt.format(*args) + args = self.visit(node.args) + expr = fmt.format(*args) + return [expr] - def visit_Lambda(self, node: itir.Lambda) -> Tuple[str, Sequence[str], Sequence[str]]: - params = [str(p.id) for p in node.params] - tlet_code = "_out = " + self.visit(node.expr) - - return tlet_code, params, ["_out"] - - def visit_FunCall(self, node: itir.FunCall) -> str: + def visit_FunCall(self, node: itir.FunCall) -> Sequence[str]: if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": return self._visit_deref(node) if isinstance(node.fun, itir.SymRef): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py similarity index 84% rename from src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 16ba459427..eda4072930 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -12,12 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later """ -Class to lower ITIR to SDFG. +Class to lower GTIR to SDFG. -Note: this module covers the fieldview flavour of ITIR. +Note: this module covers the fieldview flavour of GTIR. """ -from typing import Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import dace @@ -26,11 +26,41 @@ from gt4py.next.type_system import type_specifications as ts from .fieldview_dataflow import FieldviewRegion -from .itir_to_tasklet import ItirToTasklet +from .gtir_to_tasklet import GtirToTasklet -class ItirToSDFG(eve.NodeVisitor): - """Provides translation capability from an ITIR program to a DaCe SDFG. +def create_ctx_in_new_state(new_state_name: Optional[str] = None) -> Callable: + """Decorator to execute the visit function in a separate context, in a new state.""" + + def decorator(func: Callable) -> Callable: + def newf(self: "GtirToSDFG", *args: Any, **kwargs: Optional[Any]) -> FieldviewRegion: + prev_ctx = self._ctx + assert prev_ctx is not None + new_ctx = prev_ctx.clone() + if new_state_name: + new_ctx.state = prev_ctx.sdfg.add_state_after(prev_ctx.state, new_state_name) + self._ctx = new_ctx + + child_ctx = func(self, *args, **kwargs) + + assert self._ctx == new_ctx + self._ctx = prev_ctx + + return child_ctx + + return newf + + return decorator + + +def create_ctx(func: Callable) -> Callable: + """Decorator to execute the visit function in a separate context, in current state.""" + + return create_ctx_in_new_state()(func) + + +class GtirToSDFG(eve.NodeVisitor): + """Provides translation capability from an GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation of a GT4Py program as a sequence of `it.Stmt` statements. @@ -110,6 +140,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: sdfg.validate() return sdfg + @create_ctx_in_new_state(new_state_name="set_at") def visit_SetAt(self, stmt: itir.SetAt) -> None: """Visits a statement expression and writes the local result to some external storage. @@ -117,23 +148,20 @@ def visit_SetAt(self, stmt: itir.SetAt) -> None: The translation of `SetAt` ensures that the result is written to the external storage. """ - prev_ctx = self._ctx - assert prev_ctx is not None - - stmt_ctx = prev_ctx.clone() - stmt_ctx.state = prev_ctx.sdfg.add_state_after(prev_ctx.state, "set_at") + stmt_ctx = self._ctx + assert stmt_ctx is not None - # the statement expression will result in a tasklet writing to one or more local data nodes - self._ctx = stmt_ctx self.visit(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field + # TODO: Use GtirToTasklet with new context without updating self._ctx self._ctx = stmt_ctx.clone() self.visit(stmt.target) # the visit of a target expression should only produce a set of access nodes (no tasklets, no output nodes) assert len(self._ctx.output_nodes) == 0 stmt_ctx.output_nodes.extend(self._ctx.input_nodes) + self._ctx = stmt_ctx assert len(stmt_ctx.input_nodes) == len(stmt_ctx.output_nodes) for tasklet_node, target_node in zip(stmt_ctx.input_nodes, stmt_ctx.output_nodes): @@ -147,13 +175,10 @@ def visit_SetAt(self, stmt: itir.SetAt) -> None: dace.Memlet.from_array(target_node, target_array), ) - self._ctx = prev_ctx - + @create_ctx def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: - prev_ctx = self._ctx - assert prev_ctx is not None - ctx = prev_ctx.clone() - self._ctx = ctx + ctx = self._ctx + assert ctx is not None self.visit(fun_args) @@ -171,7 +196,7 @@ def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> Fi assert len(fun_node.args) == 1 assert isinstance(fun_node.args[0], itir.Lambda) - tletgen = ItirToTasklet() + tletgen = GtirToTasklet(ctx) tlet_code, tlet_inputs, tlet_outputs = tletgen.visit(fun_node.args[0]) # TODO: define map range based on domain @@ -199,7 +224,6 @@ def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> Fi external_edges=True, ) - self._ctx = prev_ctx return ctx def visit_FunCall(self, node: itir.FunCall) -> None: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py new file mode 100644 index 0000000000..e9c043369a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -0,0 +1,56 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from typing import Sequence, Tuple + +import gt4py.eve as eve +from gt4py.next.iterator import ir as itir + +from .fieldview_dataflow import FieldviewRegion +from .gtir_tasklet_codegen import GtirTaskletCodegen + + +class GtirToTasklet(eve.NodeVisitor): + """Translates GTIR to Python code to be used as tasklet body. + + TODO: this class needs to be revisited in next commit. + """ + + _ctx: FieldviewRegion + + def __init__(self, ctx: FieldviewRegion): + self._ctx = ctx + + def _visit_deref(self, node: itir.FunCall) -> str: + # TODO: build memlet subset / shift pattern for each tasklet connector + if not isinstance(node.args[0], itir.SymRef): + raise NotImplementedError( + f"Unexpected 'deref' argument with type '{type(node.args[0])}'." + ) + return self.visit(node.args[0]) + + def visit_Lambda(self, node: itir.Lambda) -> Tuple[str, Sequence[str], Sequence[str]]: + params = [str(p.id) for p in node.params] + results = [] + + tlet_code_lines = [] + expr_list = GtirTaskletCodegen.apply(node.expr) + for i, expr in enumerate(expr_list): + outvar = f"__out_{i}" + results.append(outvar) + tlet_code_lines.append(outvar + " = " + expr) + tlet_code = "\n".join(tlet_code_lines) + + return tlet_code, params, results diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 60b1c92d37..31a482d6b3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -20,8 +20,8 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners.dace_fieldview.itir_to_sdfg import ( - ItirToSDFG as FieldviewItirToSDFG, +from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( + GtirToSDFG as FieldviewGtirToSDFG, ) from gt4py.next.type_system import type_specifications as ts @@ -37,7 +37,7 @@ FTYPE = ts.FieldType(dims=[DIM], dtype=ts.ScalarKind.FLOAT64) -def test_itir_sum2(): +def test_gtir_sum2(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, 10) ) @@ -63,7 +63,7 @@ def test_itir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - sdfg_genenerator = FieldviewItirToSDFG( + sdfg_genenerator = FieldviewGtirToSDFG( param_types=([FTYPE] * 3), ) sdfg = sdfg_genenerator.visit(testee) @@ -74,7 +74,7 @@ def test_itir_sum2(): assert np.allclose(c, (a + b)) -def test_itir_sum3(): +def test_gtir_sum3(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, 10) ) @@ -108,7 +108,7 @@ def test_itir_sum3(): c = np.random.rand(N) d = np.empty_like(a) - sdfg_genenerator = FieldviewItirToSDFG( + sdfg_genenerator = FieldviewGtirToSDFG( param_types=([FTYPE] * 4), ) sdfg = sdfg_genenerator.visit(testee) From 50be68fcabcbc18c87927e1ac929f183c5f1933a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 19 Apr 2024 16:00:05 +0200 Subject: [PATCH 006/235] Code refactoring --- .../dace_fieldview/fieldview_dataflow.py | 9 ++ .../dace_fieldview/gtir_fieldview_builder.py | 152 ++++++++++++++++++ .../runners/dace_fieldview/gtir_to_sdfg.py | 144 ++--------------- .../runners/dace_fieldview/gtir_to_tasklet.py | 56 ------- 4 files changed, 171 insertions(+), 190 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py index 265add27f6..77f84108c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py @@ -13,6 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Optional, Tuple, TypeAlias + import dace @@ -32,11 +34,16 @@ class FieldviewRegion: with all informatiion needed to construct the dataflow graph. """ + Connection: TypeAlias = Tuple[dace.nodes.Node, Optional[str]] + sdfg: dace.SDFG state: dace.SDFGState node_mapping: dict[str, dace.nodes.AccessNode] # ordered list of input/output data nodes used by the field operator being built in this dataflow region + input_connections: list[Connection] + output_connections: list[Connection] + input_nodes: list[str] output_nodes: list[str] @@ -50,6 +57,8 @@ def __init__( self.node_mapping = {} self.input_nodes = [] self.output_nodes = [] + self.input_connections = [] + self.output_connections = [] def _add_node(self, data: str) -> dace.nodes.AccessNode: assert data in self.sdfg.arrays diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py new file mode 100644 index 0000000000..a0d62f523d --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -0,0 +1,152 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from typing import Any, Callable, List, Optional + +import dace + +import gt4py.eve as eve +from gt4py.next.iterator import ir as itir + +from .fieldview_dataflow import FieldviewRegion +from .gtir_tasklet_codegen import GtirTaskletCodegen as TaskletCodegen + + +class GtirFieldviewBuilder(eve.NodeVisitor): + """Translates GTIR to Python code to be used as tasklet body. + + TODO: this class needs to be revisited in next commit. + """ + + _ctx: FieldviewRegion + + def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState): + self._ctx = FieldviewRegion(sdfg, state) + + @staticmethod + def create_ctx(func: Callable) -> Callable: + def newf( + self: "GtirFieldviewBuilder", *args: Any, **kwargs: Optional[Any] + ) -> FieldviewRegion: + prev_ctx = self._ctx + new_ctx = prev_ctx.clone() + self._ctx = new_ctx + + child_ctx = func(self, *args, **kwargs) + + assert self._ctx == new_ctx + self._ctx = prev_ctx + + return child_ctx + + return newf + + def visit_FunCall(self, node: itir.FunCall) -> None: + if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): + if node.fun.fun.id == "as_fieldop": + child_ctx = self._make_fieldop(node.fun, node.args) + assert child_ctx.state == self._ctx.state + self._ctx.input_nodes.extend(child_ctx.output_nodes) + else: + raise NotImplementedError(f"Unexpected 'FunCall' with function {node.fun.fun.id}.") + else: + raise NotImplementedError(f"Unexpected 'FunCall' with type {type(node.fun)}.") + + def visit_Lambda(self, node: itir.Lambda) -> None: + params = [str(p.id) for p in node.params] + results = [] + + tlet_code_lines = [] + expr_list = TaskletCodegen.apply(node.expr) + + for i, expr in enumerate(expr_list): + outvar = f"__out_{i}" + tlet_code_lines.append(outvar + " = " + expr) + results.append(outvar) + tlet_code = "\n".join(tlet_code_lines) + + tlet_node: dace.tasklet = self._ctx.state.add_tasklet( + f"{self._ctx.state.label}_lambda", set(params), set(results), tlet_code + ) + + # TODO: distinguish between external and local connections (now assume all external) + for inpvar in params: + self._ctx.input_connections.append((tlet_node, inpvar)) + for outvar in results: + self._ctx.output_connections.append((tlet_node, outvar)) + + def visit_SymRef(self, node: itir.SymRef) -> None: + dname = str(node.id) + self._ctx.add_input_node(dname) + + def write_to(self, node: itir.Expr) -> None: + result_nodes = self._ctx.input_nodes.copy() + self._ctx = self._ctx.clone() + self.visit(node) + # the target expression should only produce a set of access nodes (no tasklets, no output nodes) + assert len(self._ctx.output_nodes) == 0 + output_nodes = self._ctx.input_nodes + + assert len(result_nodes) == len(output_nodes) + for tasklet_node, target_node in zip(result_nodes, output_nodes): + target_array = self._ctx.sdfg.arrays[target_node] + target_array.transient = False + + # TODO: visit statement domain to define the memlet subset + self._ctx.state.add_nedge( + self._ctx.node_mapping[tasklet_node], + self._ctx.node_mapping[target_node], + dace.Memlet.from_array(target_node, target_array), + ) + + @create_ctx + def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: + ctx = self._ctx + + self.visit(fun_args) + + # create ordered list of input nodes + input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] + + # TODO: define shape based on domain and dtype based on type inference + shape = [10] + dtype = dace.float64 + output_name, output_array = ctx.sdfg.add_array( + ctx.var_name(), shape, dtype, transient=True, find_new_name=True + ) + output_arrays = [(output_name, output_array)] + + assert len(fun_node.args) == 1 + self.visit(fun_node.args[0]) + + # TODO: define map range based on domain + map_ranges = dict(i="0:10") + me, mx = ctx.state.add_map("fieldop", map_ranges) + + for (node, connector), (dname, _) in zip(self._ctx.input_connections, input_arrays): + # TODO: define memlet subset based on domain + src_node = self._ctx.node_mapping[dname] + self._ctx.state.add_memlet_path( + src_node, me, node, dst_conn=connector, memlet=dace.Memlet(data=dname, subset="i") + ) + + for (node, connector), (dname, _) in zip(self._ctx.output_connections, output_arrays): + # TODO: define memlet subset based on domain + dst_node = ctx.add_output_node(dname) + self._ctx.state.add_memlet_path( + node, mx, dst_node, src_conn=connector, memlet=dace.Memlet(data=dname, subset="i") + ) + + return ctx diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index eda4072930..19949f78f4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -17,7 +17,7 @@ Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Callable, Dict, List, Optional +from typing import Dict import dace @@ -25,38 +25,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts -from .fieldview_dataflow import FieldviewRegion -from .gtir_to_tasklet import GtirToTasklet - - -def create_ctx_in_new_state(new_state_name: Optional[str] = None) -> Callable: - """Decorator to execute the visit function in a separate context, in a new state.""" - - def decorator(func: Callable) -> Callable: - def newf(self: "GtirToSDFG", *args: Any, **kwargs: Optional[Any]) -> FieldviewRegion: - prev_ctx = self._ctx - assert prev_ctx is not None - new_ctx = prev_ctx.clone() - if new_state_name: - new_ctx.state = prev_ctx.sdfg.add_state_after(prev_ctx.state, new_state_name) - self._ctx = new_ctx - - child_ctx = func(self, *args, **kwargs) - - assert self._ctx == new_ctx - self._ctx = prev_ctx - - return child_ctx - - return newf - - return decorator - - -def create_ctx(func: Callable) -> Callable: - """Decorator to execute the visit function in a separate context, in current state.""" - - return create_ctx_in_new_state()(func) +from .gtir_fieldview_builder import GtirFieldviewBuilder as FieldviewBuilder class GtirToSDFG(eve.NodeVisitor): @@ -74,7 +43,6 @@ class GtirToSDFG(eve.NodeVisitor): (e.g. a join state for an if/else branch execution) on the exit state of the program SDFG. """ - _ctx: Optional[FieldviewRegion] _param_types: list[ts.TypeSpec] def __init__( @@ -120,7 +88,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # define symbols for shape and offsets of temporary arrays as interstate edge symbols # TODO(edopao): use new `add_state_after` function in next dace release - head_state = sdfg.add_state_before(exit_state, "init_symbols_for_temporaries") + head_state = sdfg.add_state_before(exit_state, "init_temps") (sdfg.edges_between(entry_state, head_state))[0].assignments = temp_symbols else: head_state = entry_state @@ -129,116 +97,24 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: for param, type_ in zip(node.params, self._param_types): self._add_storage(sdfg, str(param.id), type_) - self._ctx = FieldviewRegion(sdfg, head_state) # visit one statement at a time and put it into separate state - for stmt in node.body: - self.visit(stmt) - - assert self._ctx.state == head_state - self._ctx = None + for i, stmt in enumerate(node.body): + head_state = sdfg.add_state_before(exit_state, f"stmt_{i}") + self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() return sdfg - @create_ctx_in_new_state(new_state_name="set_at") - def visit_SetAt(self, stmt: itir.SetAt) -> None: + def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: """Visits a statement expression and writes the local result to some external storage. Each statement expression results in some sort of taskgraph writing to local (aka transient) storage. The translation of `SetAt` ensures that the result is written to the external storage. """ - stmt_ctx = self._ctx - assert stmt_ctx is not None - - self.visit(stmt.expr) + fieldview_builder = FieldviewBuilder(sdfg, state) + fieldview_builder.visit(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - # TODO: Use GtirToTasklet with new context without updating self._ctx - self._ctx = stmt_ctx.clone() - self.visit(stmt.target) - # the visit of a target expression should only produce a set of access nodes (no tasklets, no output nodes) - assert len(self._ctx.output_nodes) == 0 - stmt_ctx.output_nodes.extend(self._ctx.input_nodes) - self._ctx = stmt_ctx - - assert len(stmt_ctx.input_nodes) == len(stmt_ctx.output_nodes) - for tasklet_node, target_node in zip(stmt_ctx.input_nodes, stmt_ctx.output_nodes): - target_array = stmt_ctx.sdfg.arrays[target_node] - target_array.transient = False - - # TODO: visit statement domain to define the memlet subset - stmt_ctx.state.add_nedge( - stmt_ctx.node_mapping[tasklet_node], - stmt_ctx.node_mapping[target_node], - dace.Memlet.from_array(target_node, target_array), - ) - - @create_ctx - def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: - ctx = self._ctx - assert ctx is not None - - self.visit(fun_args) - - # create ordered list of input nodes - input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] - - # TODO: define shape based on domain and dtype based on type inference - shape = [10] - dtype = dace.float64 - output_name, output_array = ctx.sdfg.add_array( - ctx.var_name(), shape, dtype, transient=True, find_new_name=True - ) - output_arrays = [(output_name, output_array)] - - assert len(fun_node.args) == 1 - assert isinstance(fun_node.args[0], itir.Lambda) - - tletgen = GtirToTasklet(ctx) - tlet_code, tlet_inputs, tlet_outputs = tletgen.visit(fun_node.args[0]) - - # TODO: define map range based on domain - map_ranges = dict(i="0:10") - - input_memlets: dict[str, dace.Memlet] = {} - for connector, (dname, _) in zip(tlet_inputs, input_arrays): - # TODO: define memlet subset based on domain - input_memlets[connector] = dace.Memlet(data=dname, subset="i") - - output_memlets: dict[str, dace.Memlet] = {} - for connector, (dname, _) in zip(tlet_outputs, output_arrays): - # TODO: define memlet subset based on domain - output_memlets[connector] = dace.Memlet(data=dname, subset="i") - ctx.add_output_node(dname) - - ctx.state.add_mapped_tasklet( - ctx.tasklet_name(), - map_ranges, - input_memlets, - tlet_code, - output_memlets, - input_nodes=ctx.node_mapping, - output_nodes=ctx.node_mapping, - external_edges=True, - ) - - return ctx - - def visit_FunCall(self, node: itir.FunCall) -> None: - assert self._ctx is not None - if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): - if node.fun.fun.id == "as_fieldop": - child_ctx = self._make_fieldop(node.fun, node.args) - assert child_ctx.state == self._ctx.state - self._ctx.input_nodes.extend(child_ctx.output_nodes) - else: - raise NotImplementedError(f"Unexpected 'FunCall' with function {node.fun.fun.id}.") - else: - raise NotImplementedError(f"Unexpected 'FunCall' with type {type(node.fun)}.") - - def visit_SymRef(self, node: itir.SymRef) -> None: - dname = str(node.id) - assert self._ctx is not None - self._ctx.add_input_node(dname) + fieldview_builder.write_to(stmt.target) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py deleted file mode 100644 index e9c043369a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ /dev/null @@ -1,56 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -from typing import Sequence, Tuple - -import gt4py.eve as eve -from gt4py.next.iterator import ir as itir - -from .fieldview_dataflow import FieldviewRegion -from .gtir_tasklet_codegen import GtirTaskletCodegen - - -class GtirToTasklet(eve.NodeVisitor): - """Translates GTIR to Python code to be used as tasklet body. - - TODO: this class needs to be revisited in next commit. - """ - - _ctx: FieldviewRegion - - def __init__(self, ctx: FieldviewRegion): - self._ctx = ctx - - def _visit_deref(self, node: itir.FunCall) -> str: - # TODO: build memlet subset / shift pattern for each tasklet connector - if not isinstance(node.args[0], itir.SymRef): - raise NotImplementedError( - f"Unexpected 'deref' argument with type '{type(node.args[0])}'." - ) - return self.visit(node.args[0]) - - def visit_Lambda(self, node: itir.Lambda) -> Tuple[str, Sequence[str], Sequence[str]]: - params = [str(p.id) for p in node.params] - results = [] - - tlet_code_lines = [] - expr_list = GtirTaskletCodegen.apply(node.expr) - for i, expr in enumerate(expr_list): - outvar = f"__out_{i}" - results.append(outvar) - tlet_code_lines.append(outvar + " = " + expr) - tlet_code = "\n".join(tlet_code_lines) - - return tlet_code, params, results From 4e2dc15dd896cee6035acf0f102cdfd50e0c0197 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 22 Apr 2024 17:08:59 +0200 Subject: [PATCH 007/235] Add domain to field operator --- .../dace_fieldview/gtir_fieldview_builder.py | 74 +++++++++++++++---- .../runners/dace_fieldview/gtir_to_sdfg.py | 29 +++++--- .../runners/dace_fieldview/utility.py | 40 ++++++++++ .../runners_tests/test_dace_fieldview.py | 31 +++++--- 4 files changed, 138 insertions(+), 36 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/utility.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py index a0d62f523d..b2cfc3aa9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -13,12 +13,15 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Sequence, Tuple import dace import gt4py.eve as eve +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type +from gt4py.next.type_system import type_specifications as ts from .fieldview_dataflow import FieldviewRegion from .gtir_tasklet_codegen import GtirTaskletCodegen as TaskletCodegen @@ -31,9 +34,13 @@ class GtirFieldviewBuilder(eve.NodeVisitor): """ _ctx: FieldviewRegion + _field_types: dict[str, ts.FieldType] - def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState): + def __init__( + self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] + ): self._ctx = FieldviewRegion(sdfg, state) + self._field_types = field_types.copy() @staticmethod def create_ctx(func: Callable) -> Callable: @@ -111,6 +118,27 @@ def write_to(self, node: itir.Expr) -> None: dace.Memlet.from_array(target_node, target_array), ) + def _visit_domain(self, node: itir.FunCall) -> Sequence[Tuple[str, str, str]]: + assert isinstance(node.fun, itir.SymRef) + assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" + + domain = [] + translator = TaskletCodegen() + + for named_range in node.args: + assert isinstance(named_range, itir.FunCall) + assert isinstance(named_range.fun, itir.SymRef) + assert len(named_range.args) == 3 + dimension = named_range.args[0] + assert isinstance(dimension, itir.AxisLiteral) + lower_bound = named_range.args[1] + upper_bound = named_range.args[2] + lb = translator.visit(lower_bound) + ub = translator.visit(upper_bound) + domain.append((dimension.value, lb, ub)) + + return domain + @create_ctx def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: ctx = self._ctx @@ -120,33 +148,47 @@ def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> Fi # create ordered list of input nodes input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] - # TODO: define shape based on domain and dtype based on type inference - shape = [10] - dtype = dace.float64 + assert len(fun_node.args) == 2 + # expect stencil (represented as a lambda function) as first argument + self.visit(fun_node.args[0]) + # the domain of the field operator is passed as second argument + assert isinstance(fun_node.args[1], itir.FunCall) + domain = self._visit_domain(fun_node.args[1]) + map_ranges = {f"i_{d}": f"{lb}:{ub}" for d, lb, ub in domain} + me, mx = ctx.state.add_map("fieldop", map_ranges) + + # TODO: use type inference to determine the result type + type_ = ts.ScalarKind.FLOAT64 + dtype = as_dace_type(type_) + shape = [f"{ub} - {lb}" for _, lb, ub in domain] output_name, output_array = ctx.sdfg.add_array( ctx.var_name(), shape, dtype, transient=True, find_new_name=True ) output_arrays = [(output_name, output_array)] - - assert len(fun_node.args) == 1 - self.visit(fun_node.args[0]) - - # TODO: define map range based on domain - map_ranges = dict(i="0:10") - me, mx = ctx.state.add_map("fieldop", map_ranges) + self._field_types[output_name] = ts.FieldType( + dims=[Dimension(d) for d, _, _ in domain], dtype=ts.ScalarType(type_) + ) for (node, connector), (dname, _) in zip(self._ctx.input_connections, input_arrays): - # TODO: define memlet subset based on domain src_node = self._ctx.node_mapping[dname] + subset = ",".join([f"i_{d.value}" for d in self._field_types[dname].dims]) self._ctx.state.add_memlet_path( - src_node, me, node, dst_conn=connector, memlet=dace.Memlet(data=dname, subset="i") + src_node, + me, + node, + dst_conn=connector, + memlet=dace.Memlet(data=dname, subset=subset), ) for (node, connector), (dname, _) in zip(self._ctx.output_connections, output_arrays): - # TODO: define memlet subset based on domain dst_node = ctx.add_output_node(dname) + subset = ",".join([f"i_{d}" for d, _, _ in domain]) self._ctx.state.add_memlet_path( - node, mx, dst_node, src_conn=connector, memlet=dace.Memlet(data=dname, subset="i") + node, + mx, + dst_node, + src_conn=connector, + memlet=dace.Memlet(data=dname, subset=subset), ) return ctx diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 19949f78f4..3d8501aaff 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -26,6 +26,7 @@ from gt4py.next.type_system import type_specifications as ts from .gtir_fieldview_builder import GtirFieldviewBuilder as FieldviewBuilder +from .utility import as_dace_type class GtirToSDFG(eve.NodeVisitor): @@ -44,19 +45,26 @@ class GtirToSDFG(eve.NodeVisitor): """ _param_types: list[ts.TypeSpec] + _field_types: dict[str, ts.FieldType] def __init__( self, param_types: list[ts.TypeSpec], ): self._param_types = param_types + self._field_types = {} def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: - # TODO define shape based on domain and dtype based on type inference - shape = [10] - dtype = dace.float64 - sdfg.add_array(name, shape, dtype, transient=False) - return + if isinstance(type_, ts.FieldType): + # TODO define shape based on domain and dtype based on type inference + shape = [10] + dtype = dace.float64 + sdfg.add_array(name, shape, dtype, transient=False) + self._field_types[name] = type_ + else: + assert isinstance(type_, ts.ScalarType) + dtype = as_dace_type(type_.kind) + sdfg.add_symbol(name, dtype) def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Dict[str, str]: raise NotImplementedError() @@ -78,7 +86,6 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # we use entry/exit state to keep track of entry/exit point of graph execution entry_state = sdfg.add_state("program_entry", is_start_block=True) - exit_state = sdfg.add_state_after(entry_state, "program_exit") # declarations of temporaries result in local (aka transient) array definitions in the SDFG if node.declarations: @@ -88,19 +95,23 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # define symbols for shape and offsets of temporary arrays as interstate edge symbols # TODO(edopao): use new `add_state_after` function in next dace release - head_state = sdfg.add_state_before(exit_state, "init_temps") + head_state = sdfg.add_state_after(entry_state, "init_temps") (sdfg.edges_between(entry_state, head_state))[0].assignments = temp_symbols else: head_state = entry_state # add global arrays (aka non-transient) to the SDFG + assert len(node.params) == len(self._param_types) for param, type_ in zip(node.params, self._param_types): self._add_storage(sdfg, str(param.id), type_) # visit one statement at a time and put it into separate state for i, stmt in enumerate(node.body): - head_state = sdfg.add_state_before(exit_state, f"stmt_{i}") + head_state = sdfg.add_state_after(head_state, f"stmt_{i}") self.visit(stmt, sdfg=sdfg, state=head_state) + sink_states = sdfg.sink_nodes() + assert len(sink_states) == 1 + head_state = sink_states[0] sdfg.validate() return sdfg @@ -112,7 +123,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) The translation of `SetAt` ensures that the result is written to the external storage. """ - fieldview_builder = FieldviewBuilder(sdfg, state) + fieldview_builder = FieldviewBuilder(sdfg, state, self._field_types) fieldview_builder.visit(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py new file mode 100644 index 0000000000..9c8f67e604 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -0,0 +1,40 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +import itertools +from typing import Any, Mapping, Optional, Sequence + +import dace +import numpy as np + +import gt4py.eve as eve +from gt4py.next import Dimension, DimensionKind, type_inference as next_typing +from gt4py.next.common import Connectivity +from gt4py.next.iterator import ir as itir, type_inference as itir_typing +from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef +from gt4py.next.type_system import type_specifications as ts, type_translation as tt + +import dace + +def as_dace_type(type_: ts.ScalarKind) -> dace.dtypes.typeclass: + if type_ == ts.ScalarKind.BOOL: + return dace.bool_ + elif type_ == ts.ScalarKind.INT32: + return dace.int32 + elif type_ == ts.ScalarKind.INT64: + return dace.int64 + elif type_ == ts.ScalarKind.FLOAT32: + return dace.float32 + elif type_ == ts.ScalarKind.FLOAT64: + return dace.float64 + raise ValueError(f"Data type '{type_}' not supported.") \ No newline at end of file diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 31a482d6b3..c2f7b63e9c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -39,18 +39,19 @@ def test_gtir_sum2(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, 10) + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") ) testee = itir.Program( id="sum_2fields", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="z")], + params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="z"), itir.Sym(id="size")], declarations=[], body=[ itir.SetAt( expr=im.call( im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))) + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, ) )("x", "y"), domain=domain, @@ -64,36 +65,44 @@ def test_gtir_sum2(): c = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - param_types=([FTYPE] * 3), + [FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)] ) sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) - sdfg(x=a, y=b, z=c) + sdfg(x=a, y=b, z=c, size=N) assert np.allclose(c, (a + b)) def test_gtir_sum3(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, 10) + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") ) testee = itir.Program( id="sum_3fields", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="w"), itir.Sym(id="z")], + params=[ + itir.Sym(id="x"), + itir.Sym(id="y"), + itir.Sym(id="w"), + itir.Sym(id="z"), + itir.Sym(id="size"), + ], declarations=[], body=[ itir.SetAt( expr=im.call( im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))) + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, ) )( "x", im.call( im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))) + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, ) )("y", "w"), ), @@ -109,11 +118,11 @@ def test_gtir_sum3(): d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - param_types=([FTYPE] * 4), + [FTYPE, FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)] ) sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) - sdfg(x=a, y=b, w=c, z=d) + sdfg(x=a, y=b, w=c, z=d, size=N) assert np.allclose(d, (a + b + c)) From ea9da354472e408e3cb140052a3845eb66c54fa8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 22 Apr 2024 18:12:20 +0200 Subject: [PATCH 008/235] Minor edit --- .../dace_fieldview/gtir_fieldview_builder.py | 6 ++++-- .../runners/dace_fieldview/gtir_to_sdfg.py | 3 ++- .../runners/dace_fieldview/utility.py | 13 ++----------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py index b2cfc3aa9a..87481d8aac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -118,7 +118,7 @@ def write_to(self, node: itir.Expr) -> None: dace.Memlet.from_array(target_node, target_array), ) - def _visit_domain(self, node: itir.FunCall) -> Sequence[Tuple[str, str, str]]: + def _make_fieldop_domain(self, node: itir.FunCall) -> Sequence[Tuple[str, str, str]]: assert isinstance(node.fun, itir.SymRef) assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" @@ -143,6 +143,8 @@ def _visit_domain(self, node: itir.FunCall) -> Sequence[Tuple[str, str, str]]: def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: ctx = self._ctx + # TODO: add early inspection of compute pattern and call specialized builder + self.visit(fun_args) # create ordered list of input nodes @@ -153,7 +155,7 @@ def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> Fi self.visit(fun_node.args[0]) # the domain of the field operator is passed as second argument assert isinstance(fun_node.args[1], itir.FunCall) - domain = self._visit_domain(fun_node.args[1]) + domain = self._make_fieldop_domain(fun_node.args[1]) map_ranges = {f"i_{d}": f"{lb}:{ub}" for d, lb, ub in domain} me, mx = ctx.state.add_map("fieldop", map_ranges) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 3d8501aaff..c14878205a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -96,7 +96,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # define symbols for shape and offsets of temporary arrays as interstate edge symbols # TODO(edopao): use new `add_state_after` function in next dace release head_state = sdfg.add_state_after(entry_state, "init_temps") - (sdfg.edges_between(entry_state, head_state))[0].assignments = temp_symbols + sdfg.edges_between(entry_state, head_state)[0].assignments = temp_symbols else: head_state = entry_state @@ -109,6 +109,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: for i, stmt in enumerate(node.body): head_state = sdfg.add_state_after(head_state, f"stmt_{i}") self.visit(stmt, sdfg=sdfg, state=head_state) + # sanity check below: each statement should have a single exit state -- aka no branches sink_states = sdfg.sink_nodes() assert len(sink_states) == 1 head_state = sink_states[0] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 9c8f67e604..96b6d5a0b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -11,20 +11,11 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import itertools -from typing import Any, Mapping, Optional, Sequence import dace -import numpy as np -import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind, type_inference as next_typing -from gt4py.next.common import Connectivity -from gt4py.next.iterator import ir as itir, type_inference as itir_typing -from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.type_system import type_specifications as ts, type_translation as tt +from gt4py.next.type_system import type_specifications as ts -import dace def as_dace_type(type_: ts.ScalarKind) -> dace.dtypes.typeclass: if type_ == ts.ScalarKind.BOOL: @@ -37,4 +28,4 @@ def as_dace_type(type_: ts.ScalarKind) -> dace.dtypes.typeclass: return dace.float32 elif type_ == ts.ScalarKind.FLOAT64: return dace.float64 - raise ValueError(f"Data type '{type_}' not supported.") \ No newline at end of file + raise ValueError(f"Data type '{type_}' not supported.") From daf7827dae90afabf4c032d168e8ef9c90d007d8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 23 Apr 2024 09:53:30 +0200 Subject: [PATCH 009/235] Remove hard-coded field shape --- .../dace_fieldview/gtir_fieldview_builder.py | 20 +++--- .../runners/dace_fieldview/gtir_to_sdfg.py | 69 ++++++++++++++----- .../runners/dace_fieldview/utility.py | 25 +++++-- .../runners_tests/test_dace_fieldview.py | 25 +++++-- 4 files changed, 102 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py index 87481d8aac..6548f0ddd1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -42,6 +42,14 @@ def __init__( self._ctx = FieldviewRegion(sdfg, state) self._field_types = field_types.copy() + def _add_local_storage( + self, name: str, type_: ts.FieldType, shape: list[str] + ) -> Tuple[str, dace.data.Array]: + self._field_types[name] = type_ + dtype = as_dace_type(type_.dtype) + # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used + return self._ctx.sdfg.add_array(name, shape, dtype, transient=True, find_new_name=True) + @staticmethod def create_ctx(func: Callable) -> Callable: def newf( @@ -161,15 +169,11 @@ def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> Fi # TODO: use type inference to determine the result type type_ = ts.ScalarKind.FLOAT64 - dtype = as_dace_type(type_) - shape = [f"{ub} - {lb}" for _, lb, ub in domain] - output_name, output_array = ctx.sdfg.add_array( - ctx.var_name(), shape, dtype, transient=True, find_new_name=True - ) + field_dims = [Dimension(d) for d, _, _ in domain] + field_type = ts.FieldType(field_dims, ts.ScalarType(type_)) + field_shape = [f"{ub} - {lb}" for _, lb, ub in domain] + output_name, output_array = self._add_local_storage(ctx.var_name(), field_type, field_shape) output_arrays = [(output_name, output_array)] - self._field_types[output_name] = ts.FieldType( - dims=[Dimension(d) for d, _, _ in domain], dtype=ts.ScalarType(type_) - ) for (node, connector), (dname, _) in zip(self._ctx.input_connections, input_arrays): src_node = self._ctx.node_mapping[dname] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index c14878205a..6e08d7ae28 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -12,62 +12,99 @@ # # SPDX-License-Identifier: GPL-3.0-or-later """ -Class to lower GTIR to SDFG. +Class to lower GTIR to a DaCe SDFG. Note: this module covers the fieldview flavour of GTIR. """ -from typing import Dict +from typing import Any, Dict, Mapping, Sequence, Tuple import dace from gt4py import eve +from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts from .gtir_fieldview_builder import GtirFieldviewBuilder as FieldviewBuilder -from .utility import as_dace_type +from .utility import as_dace_type, filter_connectivities class GtirToSDFG(eve.NodeVisitor): - """Provides translation capability from an GTIR program to a DaCe SDFG. + """Provides translation capability from a GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation - of a GT4Py program as a sequence of `it.Stmt` statements. + of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. Each statement is translated to a taskgraph inside a separate state. The parent SDFG and - the translation state define the translation context, implemented by `ItirTaskgenContext`. + the translation state define the statement context, implemented by `FieldviewRegion`. Statement states are chained one after the other: potential concurrency between states should be extracted by the DaCe SDFG transformations. - The program translation keeps track of entry and exit states: each statement is translated as - a new state inserted just before the exit state. Note that statements with branch execution might - result in more than one state. However, each statement should provide a single termination state - (e.g. a join state for an if/else branch execution) on the exit state of the program SDFG. + The program translation keeps track of entry and exit states: each statement is supposed to extend + the SDFG but maintain the property of single exit state (that is no branching on leaf nodes). + Branching is allowed within the context of one statement, but in that case the statement should + terminate with a join state; the join state will represent the head state for next statement, + that is where to continue building the SDFG. """ _param_types: list[ts.TypeSpec] _field_types: dict[str, ts.FieldType] + _offset_providers: Mapping[str, Any] def __init__( self, param_types: list[ts.TypeSpec], + offset_providers: dict[str, Connectivity | Dimension], ): self._param_types = param_types self._field_types = {} + self._offset_providers = offset_providers + + def _make_array_shape_and_strides( + self, name: str, dims: Sequence[Dimension] + ) -> Tuple[Sequence[dace.symbol], Sequence[dace.symbol]]: + """ + Parse field dimensions and allocate symbols for array shape and strides. + + For local dimensions, the size is known at compile-time and therefore + the corresponding array shape dimension is set to an integer literal value. + + Returns + ------- + tuple(shape, strides) + The output tuple fields are arrays of dace symbolic expressions. + """ + dtype = dace.int32 + neighbor_tables = filter_connectivities(self._offset_providers) + shape = [ + ( + # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain + neighbor_tables[dim.value].max_neighbors + if dim.kind == DimensionKind.LOCAL + # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain + else dace.symbol(f"__{name}_size_{i}", dtype) + ) + for i, dim in enumerate(dims) + ] + strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i, _ in enumerate(dims)] + return shape, strides def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: if isinstance(type_, ts.FieldType): - # TODO define shape based on domain and dtype based on type inference - shape = [10] - dtype = dace.float64 - sdfg.add_array(name, shape, dtype, transient=False) + dtype = as_dace_type(type_.dtype) + # use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. + sym_shape, sym_strides = self._make_array_shape_and_strides(name, type_.dims) + sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False) self._field_types[name] = type_ + else: assert isinstance(type_, ts.ScalarType) - dtype = as_dace_type(type_.kind) + dtype = as_dace_type(type_) + # scalar arguments passed to the program are represented as symbols in DaCe SDFG sdfg.add_symbol(name, dtype) def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Dict[str, str]: - raise NotImplementedError() + raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") return {} def visit_Program(self, node: itir.Program) -> dace.SDFG: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 96b6d5a0b0..b25521f5c3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -12,20 +12,31 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Any, Mapping + import dace +from gt4py.next.common import Connectivity from gt4py.next.type_system import type_specifications as ts -def as_dace_type(type_: ts.ScalarKind) -> dace.dtypes.typeclass: - if type_ == ts.ScalarKind.BOOL: +def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: + if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ - elif type_ == ts.ScalarKind.INT32: + elif type_.kind == ts.ScalarKind.INT32: return dace.int32 - elif type_ == ts.ScalarKind.INT64: + elif type_.kind == ts.ScalarKind.INT64: return dace.int64 - elif type_ == ts.ScalarKind.FLOAT32: + elif type_.kind == ts.ScalarKind.FLOAT32: return dace.float32 - elif type_ == ts.ScalarKind.FLOAT64: + elif type_.kind == ts.ScalarKind.FLOAT64: return dace.float64 - raise ValueError(f"Data type '{type_}' not supported.") + raise ValueError(f"Scalar type '{type_}' not supported.") + + +def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: + return { + offset: table + for offset, table in offset_provider.items() + if isinstance(table, Connectivity) + } diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index c2f7b63e9c..03968857f5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -17,7 +17,8 @@ Note: this test module covers the fieldview flavour of ITIR. """ -from gt4py.next.common import Dimension +from typing import Union +from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( @@ -34,7 +35,19 @@ N = 10 DIM = Dimension("D") -FTYPE = ts.FieldType(dims=[DIM], dtype=ts.ScalarKind.FLOAT64) +FTYPE = ts.FieldType(dims=[DIM], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +FSYMBOLS = dict( + __w_size_0=N, + __w_stride_0=1, + __x_size_0=N, + __x_stride_0=1, + __y_size_0=N, + __y_stride_0=1, + __z_size_0=N, + __z_stride_0=1, + size=N, +) +OFFSET_PROVIDERS: dict[str, Connectivity | Dimension] = {} def test_gtir_sum2(): @@ -65,13 +78,13 @@ def test_gtir_sum2(): c = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + [FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS ) sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) - sdfg(x=a, y=b, z=c, size=N) + sdfg(x=a, y=b, z=c, **FSYMBOLS) assert np.allclose(c, (a + b)) @@ -118,11 +131,11 @@ def test_gtir_sum3(): d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + [FTYPE, FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS ) sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) - sdfg(x=a, y=b, w=c, z=d, size=N) + sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + c)) From 9672b3b2d125efc164af58b2fd6b124de36c56bf Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 23 Apr 2024 10:17:05 +0200 Subject: [PATCH 010/235] Remove hard-coded target domain --- .../runners/dace_fieldview/gtir_fieldview_builder.py | 10 ++++++---- .../runners/dace_fieldview/gtir_to_sdfg.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py index 6548f0ddd1..0e430a60ac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -106,10 +106,13 @@ def visit_SymRef(self, node: itir.SymRef) -> None: dname = str(node.id) self._ctx.add_input_node(dname) - def write_to(self, node: itir.Expr) -> None: + def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: result_nodes = self._ctx.input_nodes.copy() self._ctx = self._ctx.clone() - self.visit(node) + self.visit(target_expr) + assert isinstance(domain_expr, itir.FunCall) + domain = self._make_fieldop_domain(domain_expr) + write_subset = ",".join(f"{lb}:{ub}" for _, lb, ub in domain) # the target expression should only produce a set of access nodes (no tasklets, no output nodes) assert len(self._ctx.output_nodes) == 0 output_nodes = self._ctx.input_nodes @@ -119,11 +122,10 @@ def write_to(self, node: itir.Expr) -> None: target_array = self._ctx.sdfg.arrays[target_node] target_array.transient = False - # TODO: visit statement domain to define the memlet subset self._ctx.state.add_nedge( self._ctx.node_mapping[tasklet_node], self._ctx.node_mapping[target_node], - dace.Memlet.from_array(target_node, target_array), + dace.Memlet(data=target_node, subset=write_subset), ) def _make_fieldop_domain(self, node: itir.FunCall) -> Sequence[Tuple[str, str, str]]: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 6e08d7ae28..fec68745f6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -166,4 +166,4 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - fieldview_builder.write_to(stmt.target) + fieldview_builder.write_to(stmt.target, stmt.domain) From 26f3790d4ba07a2e2a309030ac9b831e6335f093 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 24 Apr 2024 07:46:06 +0200 Subject: [PATCH 011/235] Refactoring --- .../dace_fieldview/gtir_fieldview_builder.py | 272 ++++++++---------- ..._dataflow.py => gtir_fieldview_context.py} | 31 +- .../dace_fieldview/gtir_tasklet_arithmetic.py | 115 ++++++++ .../dace_fieldview/gtir_tasklet_codegen.py | 137 ++++----- 4 files changed, 312 insertions(+), 243 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{fieldview_dataflow.py => gtir_fieldview_context.py} (77%) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py index 0e430a60ac..c41d178e2a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -13,112 +13,163 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import dace import gt4py.eve as eve from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts -from .fieldview_dataflow import FieldviewRegion -from .gtir_tasklet_codegen import GtirTaskletCodegen as TaskletCodegen +from .gtir_fieldview_context import GtirFieldviewContext as FieldviewContext +from .gtir_tasklet_arithmetic import GtirTaskletArithmetic +from .gtir_tasklet_codegen import ( + GtirTaskletCodegen as TaskletCodegen, + GtirTaskletSubgraph as TaskletSubgraph, +) + + +def _make_fieldop( + ctx: FieldviewContext, + tasklet_subgraph: TaskletSubgraph, + domain: Sequence[Tuple[str, str, str]], +) -> None: + # create ordered list of input nodes + input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] + assert len(tasklet_subgraph.input_connections) == len(input_arrays) + + # TODO: use type inference to determine the result type + type_ = ts.ScalarKind.FLOAT64 + output_arrays: list[Tuple[str, dace.data.Array]] = [] + for _, field_type in tasklet_subgraph.output_connections: + if field_type is None: + field_dims = [Dimension(d) for d, _, _ in domain] + field_type = ts.FieldType(field_dims, ts.ScalarType(type_)) + field_shape = [f"{ub} - {lb}" for d, lb, ub in domain if Dimension(d) in field_type.dims] + output_name, output_array = ctx.add_local_storage(ctx.var_name(), field_type, field_shape) + output_arrays.append((output_name, output_array)) + + map_ranges = {f"i_{d}": f"{lb}:{ub}" for d, lb, ub in domain} + me, mx = ctx.state.add_map("fieldop", map_ranges) + + for (connector, field_type), (dname, _) in zip( + tasklet_subgraph.input_connections, input_arrays + ): + src_node = ctx.node_mapping[dname] + if field_type is None: + subset = ",".join([f"i_{d.value}" for d in ctx.field_types[dname].dims]) + else: + raise NotImplementedError("Array subset on tasklet connector not supported.") + ctx.state.add_memlet_path( + src_node, + me, + tasklet_subgraph.node, + dst_conn=connector, + memlet=dace.Memlet(data=dname, subset=subset), + ) + for (connector, field_type), (dname, _) in zip( + tasklet_subgraph.output_connections, output_arrays + ): + dst_node = ctx.add_output_node(dname) + if field_type is None: + subset = ",".join([f"i_{d}" for d, _, _ in domain]) + else: + raise NotImplementedError("Array subset on tasklet connector not supported.") + ctx.state.add_memlet_path( + tasklet_subgraph.node, + mx, + dst_node, + src_conn=connector, + memlet=dace.Memlet(data=dname, subset=subset), + ) -class GtirFieldviewBuilder(eve.NodeVisitor): - """Translates GTIR to Python code to be used as tasklet body. - TODO: this class needs to be revisited in next commit. - """ +def _make_fieldop_domain( + ctx: FieldviewContext, node: itir.FunCall +) -> Sequence[Tuple[str, str, str]]: + assert cpm.is_call_to(node, ["cartesian_domain", "unstructured_domain"]) - _ctx: FieldviewRegion - _field_types: dict[str, ts.FieldType] + domain = [] + translator = TaskletCodegen(ctx) - def __init__( - self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] - ): - self._ctx = FieldviewRegion(sdfg, state) - self._field_types = field_types.copy() + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + dimension = named_range.args[0] + assert isinstance(dimension, itir.AxisLiteral) + lower_bound = named_range.args[1] + upper_bound = named_range.args[2] + lb = translator.visit(lower_bound) + ub = translator.visit(upper_bound) + domain.append((dimension.value, lb, ub)) - def _add_local_storage( - self, name: str, type_: ts.FieldType, shape: list[str] - ) -> Tuple[str, dace.data.Array]: - self._field_types[name] = type_ - dtype = as_dace_type(type_.dtype) - # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used - return self._ctx.sdfg.add_array(name, shape, dtype, transient=True, find_new_name=True) + return domain - @staticmethod - def create_ctx(func: Callable) -> Callable: - def newf( - self: "GtirFieldviewBuilder", *args: Any, **kwargs: Optional[Any] - ) -> FieldviewRegion: - prev_ctx = self._ctx - new_ctx = prev_ctx.clone() - self._ctx = new_ctx - child_ctx = func(self, *args, **kwargs) +class GtirFieldviewBuilder(eve.NodeVisitor): + """Translates GTIR fieldview operator to some kind of map scope in DaCe SDFG.""" - assert self._ctx == new_ctx - self._ctx = prev_ctx + _ctx: FieldviewContext + _registered_taskgens: list[TaskletCodegen] = [GtirTaskletArithmetic] - return child_ctx + def __init__( + self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] + ): + self._ctx = FieldviewContext(sdfg, state, field_types.copy()) - return newf + def _get_tasklet_codegen(self, lambda_node: itir.Lambda) -> Optional[TaskletCodegen]: + for taskgen in self._registered_taskgens: + if taskgen.can_handle(lambda_node): + return taskgen(self._ctx) + return None def visit_FunCall(self, node: itir.FunCall) -> None: - if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): - if node.fun.fun.id == "as_fieldop": - child_ctx = self._make_fieldop(node.fun, node.args) - assert child_ctx.state == self._ctx.state - self._ctx.input_nodes.extend(child_ctx.output_nodes) - else: - raise NotImplementedError(f"Unexpected 'FunCall' with function {node.fun.fun.id}.") + parent_ctx = self._ctx + child_ctx = parent_ctx.clone() + self._ctx = child_ctx + + if cpm.is_call_to(node.fun, "as_fieldop"): + fun_node = node.fun + assert len(fun_node.args) == 2 + # expect stencil (represented as a lambda function) as first argument + assert isinstance(fun_node.args[0], itir.Lambda) + taskgen = self._get_tasklet_codegen(fun_node.args[0]) + if not taskgen: + raise NotImplementedError(f"Failed to lower 'as_fieldop' node to SDFG ({node}).") + # the domain of the field operator is passed as second argument + assert isinstance(fun_node.args[1], itir.FunCall) + domain = _make_fieldop_domain(self._ctx, fun_node.args[1]) + + self.visit(node.args) + tasklet_subgraph = taskgen.visit(fun_node.args[0]) + _make_fieldop(self._ctx, tasklet_subgraph, domain) else: - raise NotImplementedError(f"Unexpected 'FunCall' with type {type(node.fun)}.") + raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - def visit_Lambda(self, node: itir.Lambda) -> None: - params = [str(p.id) for p in node.params] - results = [] - - tlet_code_lines = [] - expr_list = TaskletCodegen.apply(node.expr) - - for i, expr in enumerate(expr_list): - outvar = f"__out_{i}" - tlet_code_lines.append(outvar + " = " + expr) - results.append(outvar) - tlet_code = "\n".join(tlet_code_lines) - - tlet_node: dace.tasklet = self._ctx.state.add_tasklet( - f"{self._ctx.state.label}_lambda", set(params), set(results), tlet_code - ) - - # TODO: distinguish between external and local connections (now assume all external) - for inpvar in params: - self._ctx.input_connections.append((tlet_node, inpvar)) - for outvar in results: - self._ctx.output_connections.append((tlet_node, outvar)) + assert self._ctx == child_ctx + parent_ctx.input_nodes.extend(child_ctx.output_nodes) + self._ctx = parent_ctx def visit_SymRef(self, node: itir.SymRef) -> None: dname = str(node.id) self._ctx.add_input_node(dname) def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: - result_nodes = self._ctx.input_nodes.copy() - self._ctx = self._ctx.clone() - self.visit(target_expr) + assert len(self._ctx.output_nodes) == 0 + + # TODO: add support for tuple return + assert len(self._ctx.input_nodes) == 1 + assert isinstance(target_expr, itir.SymRef) + self._ctx.add_output_node(target_expr.id) + assert isinstance(domain_expr, itir.FunCall) - domain = self._make_fieldop_domain(domain_expr) + domain = _make_fieldop_domain(self._ctx, domain_expr) write_subset = ",".join(f"{lb}:{ub}" for _, lb, ub in domain) - # the target expression should only produce a set of access nodes (no tasklets, no output nodes) - assert len(self._ctx.output_nodes) == 0 - output_nodes = self._ctx.input_nodes - assert len(result_nodes) == len(output_nodes) - for tasklet_node, target_node in zip(result_nodes, output_nodes): + for tasklet_node, target_node in zip(self._ctx.input_nodes, self._ctx.output_nodes): target_array = self._ctx.sdfg.arrays[target_node] target_array.transient = False @@ -127,76 +178,3 @@ def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: self._ctx.node_mapping[target_node], dace.Memlet(data=target_node, subset=write_subset), ) - - def _make_fieldop_domain(self, node: itir.FunCall) -> Sequence[Tuple[str, str, str]]: - assert isinstance(node.fun, itir.SymRef) - assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" - - domain = [] - translator = TaskletCodegen() - - for named_range in node.args: - assert isinstance(named_range, itir.FunCall) - assert isinstance(named_range.fun, itir.SymRef) - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - lb = translator.visit(lower_bound) - ub = translator.visit(upper_bound) - domain.append((dimension.value, lb, ub)) - - return domain - - @create_ctx - def _make_fieldop(self, fun_node: itir.FunCall, fun_args: List[itir.Expr]) -> FieldviewRegion: - ctx = self._ctx - - # TODO: add early inspection of compute pattern and call specialized builder - - self.visit(fun_args) - - # create ordered list of input nodes - input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] - - assert len(fun_node.args) == 2 - # expect stencil (represented as a lambda function) as first argument - self.visit(fun_node.args[0]) - # the domain of the field operator is passed as second argument - assert isinstance(fun_node.args[1], itir.FunCall) - domain = self._make_fieldop_domain(fun_node.args[1]) - map_ranges = {f"i_{d}": f"{lb}:{ub}" for d, lb, ub in domain} - me, mx = ctx.state.add_map("fieldop", map_ranges) - - # TODO: use type inference to determine the result type - type_ = ts.ScalarKind.FLOAT64 - field_dims = [Dimension(d) for d, _, _ in domain] - field_type = ts.FieldType(field_dims, ts.ScalarType(type_)) - field_shape = [f"{ub} - {lb}" for _, lb, ub in domain] - output_name, output_array = self._add_local_storage(ctx.var_name(), field_type, field_shape) - output_arrays = [(output_name, output_array)] - - for (node, connector), (dname, _) in zip(self._ctx.input_connections, input_arrays): - src_node = self._ctx.node_mapping[dname] - subset = ",".join([f"i_{d.value}" for d in self._field_types[dname].dims]) - self._ctx.state.add_memlet_path( - src_node, - me, - node, - dst_conn=connector, - memlet=dace.Memlet(data=dname, subset=subset), - ) - - for (node, connector), (dname, _) in zip(self._ctx.output_connections, output_arrays): - dst_node = ctx.add_output_node(dname) - subset = ",".join([f"i_{d}" for d, _, _ in domain]) - self._ctx.state.add_memlet_path( - node, - mx, - dst_node, - src_conn=connector, - memlet=dace.Memlet(data=dname, subset=subset), - ) - - return ctx diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_context.py similarity index 77% rename from src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_context.py index 77f84108c5..c709a8292b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/fieldview_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_context.py @@ -13,12 +13,16 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Optional, Tuple, TypeAlias +from typing import Tuple import dace +from gt4py.next.type_system import type_specifications as ts -class FieldviewRegion: +from .utility import as_dace_type + + +class GtirFieldviewContext: """Defines the dataflow scope of a fieldview expression. This class defines a region of the dataflow which represents a fieldview expression. @@ -34,16 +38,11 @@ class FieldviewRegion: with all informatiion needed to construct the dataflow graph. """ - Connection: TypeAlias = Tuple[dace.nodes.Node, Optional[str]] - sdfg: dace.SDFG state: dace.SDFGState + field_types: dict[str, ts.FieldType] node_mapping: dict[str, dace.nodes.AccessNode] - # ordered list of input/output data nodes used by the field operator being built in this dataflow region - input_connections: list[Connection] - output_connections: list[Connection] - input_nodes: list[str] output_nodes: list[str] @@ -51,14 +50,14 @@ def __init__( self, current_sdfg: dace.SDFG, current_state: dace.SDFGState, + current_field_types: dict[str, ts.FieldType], ): self.sdfg = current_sdfg self.state = current_state + self.field_types = current_field_types self.node_mapping = {} self.input_nodes = [] self.output_nodes = [] - self.input_connections = [] - self.output_connections = [] def _add_node(self, data: str) -> dace.nodes.AccessNode: assert data in self.sdfg.arrays @@ -77,8 +76,16 @@ def add_output_node(self, data: str) -> dace.nodes.AccessNode: self.output_nodes.append(data) return self._add_node(data) - def clone(self) -> "FieldviewRegion": - ctx = FieldviewRegion(self.sdfg, self.state) + def add_local_storage( + self, name: str, type_: ts.FieldType, shape: list[str] + ) -> Tuple[str, dace.data.Array]: + self.field_types[name] = type_ + dtype = as_dace_type(type_.dtype) + # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used + return self.sdfg.add_transient(name, shape, dtype, find_new_name=True) + + def clone(self) -> "GtirFieldviewContext": + ctx = GtirFieldviewContext(self.sdfg, self.state, self.field_types) ctx.node_mapping = self.node_mapping return ctx diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py new file mode 100644 index 0000000000..06ffc9d17f --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py @@ -0,0 +1,115 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import numpy as np + +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm + +from .gtir_fieldview_context import GtirFieldviewContext as FieldviewContext +from .gtir_tasklet_codegen import GtirTaskletCodegen + + +_MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + +class GtirTaskletArithmetic(GtirTaskletCodegen): + """Translates GTIR lambda exprressions with arithmetic builtin.""" + + def __init__(self, ctx: FieldviewContext): + super().__init__(ctx) + + @staticmethod + def can_handle(lambda_node: itir.Lambda) -> bool: + fun_node = lambda_node.expr + assert isinstance(fun_node, itir.FunCall) + if isinstance(fun_node.fun, itir.SymRef): + builtin_name = str(fun_node.fun.id) + return builtin_name in _MATH_BUILTINS_MAPPING + return False + + def _visit_deref(self, node: itir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) + + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + return self._visit_numeric_builtin(node) + + def visit_SymRef(self, node: itir.SymRef) -> str: + name = str(node.id) + self._input_connections.append((name, None)) + return name diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 9b85e13650..8d82e2e621 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -12,98 +12,67 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Sequence +import dataclasses +from typing import Optional, Tuple -import numpy as np +import dace from gt4py.eve import codegen from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts +from .gtir_fieldview_context import GtirFieldviewContext as FieldviewContext -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} + +@dataclasses.dataclass(frozen=True) +class GtirTaskletSubgraph: + """Defines a tasklet subgraph representing a stencil expression. + + The tasklet subgraph will be used by the consumer to build a fieldview expression. + For example, it could be used in a map scope to build a fieldview expression; + or it could become the body of a scan expression. + """ + + # generic DaCe node, most often this will be a tasklet node but it could also be a nested SDFG + node: dace.nodes.Node + + # for each input/output connections, specify the field type or None if scalar + input_connections: list[Tuple[str, Optional[ts.FieldType]]] + output_connections: list[Tuple[str, Optional[ts.FieldType]]] class GtirTaskletCodegen(codegen.TemplatedGenerator): - """Translates GTIR to Python code to be used as tasklet body. + """Base class to translate GTIR to Python code to be used as tasklet body.""" - This class is dace agnostic: it receives GTIR as input and produces Python code. - """ + _ctx: FieldviewContext + # list of input/output connectors and expected field type (None if scalar) + _input_connections: list[Tuple[str, Optional[ts.FieldType]]] + _output_connections: list[Tuple[str, Optional[ts.FieldType]]] + + def __init__(self, ctx: FieldviewContext): + self._ctx = ctx + self._input_connections = [] + self._output_connections = [] + + @staticmethod + def can_handle(lambda_node: itir.Lambda) -> bool: + raise NotImplementedError("") + + def visit_Lambda(self, node: itir.Lambda) -> GtirTaskletSubgraph: + tlet_expr = self.visit(node.expr) + + params = [str(p.id) for p in node.params] + assert len(self._input_connections) == len(params) + + outvar = "__out" + tlet_code = f"{outvar} = {tlet_expr}" + results = [outvar] + self._output_connections.append((outvar, None)) + + tlet_node: dace.tasklet = self._ctx.state.add_tasklet( + f"{self._ctx.tasklet_name()}_lambda", set(params), set(results), tlet_code + ) + + subgraph = GtirTaskletSubgraph(tlet_node, self._input_connections, self._output_connections) - def _visit_deref(self, node: itir.FunCall) -> list[str]: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_numeric_builtin(self, node: itir.FunCall) -> Sequence[str]: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - expr = fmt.format(*args) - return [expr] - - def visit_FunCall(self, node: itir.FunCall) -> Sequence[str]: - if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": - return self._visit_deref(node) - if isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - raise NotImplementedError(f"Unexpected 'FunCall' with type '{type(node.fun)}'.") - - def visit_SymRef(self, node: itir.SymRef) -> str: - return str(node.id) + return subgraph From f99fa84f896771b22f587c87bf3ac1b2451cfd59 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 24 Apr 2024 11:09:21 +0200 Subject: [PATCH 012/235] Fix formatting --- .../runners/dace_fieldview/gtir_fieldview_builder.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py index c41d178e2a..4efc49354d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py @@ -31,6 +31,11 @@ ) +REGISTERED_TASKGENS: list[type[TaskletCodegen]] = [ + GtirTaskletArithmetic, +] + + def _make_fieldop( ctx: FieldviewContext, tasklet_subgraph: TaskletSubgraph, @@ -113,7 +118,6 @@ class GtirFieldviewBuilder(eve.NodeVisitor): """Translates GTIR fieldview operator to some kind of map scope in DaCe SDFG.""" _ctx: FieldviewContext - _registered_taskgens: list[TaskletCodegen] = [GtirTaskletArithmetic] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] @@ -121,7 +125,7 @@ def __init__( self._ctx = FieldviewContext(sdfg, state, field_types.copy()) def _get_tasklet_codegen(self, lambda_node: itir.Lambda) -> Optional[TaskletCodegen]: - for taskgen in self._registered_taskgens: + for taskgen in REGISTERED_TASKGENS: if taskgen.can_handle(lambda_node): return taskgen(self._ctx) return None @@ -138,7 +142,7 @@ def visit_FunCall(self, node: itir.FunCall) -> None: assert isinstance(fun_node.args[0], itir.Lambda) taskgen = self._get_tasklet_codegen(fun_node.args[0]) if not taskgen: - raise NotImplementedError(f"Failed to lower 'as_fieldop' node to SDFG ({node}).") + raise NotImplementedError(f"Unsupported 'as_fieldop' node ({node}).") # the domain of the field operator is passed as second argument assert isinstance(fun_node.args[1], itir.FunCall) domain = _make_fieldop_domain(self._ctx, fun_node.args[1]) From d6e1088e65f825af2d9f0fa950d728df1f67679f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 24 Apr 2024 11:29:11 +0200 Subject: [PATCH 013/235] More refactoring --- ...ew_builder.py => gtir_dataflow_builder.py} | 24 ++++++++--------- ...ew_context.py => gtir_dataflow_context.py} | 12 ++++----- .../dace_fieldview/gtir_tasklet_arithmetic.py | 4 +-- .../dace_fieldview/gtir_tasklet_codegen.py | 26 ++++++++++++------- .../runners/dace_fieldview/gtir_to_sdfg.py | 10 +++---- 5 files changed, 40 insertions(+), 36 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_fieldview_builder.py => gtir_dataflow_builder.py} (91%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_fieldview_context.py => gtir_dataflow_context.py} (92%) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py similarity index 91% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 4efc49354d..e44b8e3bd4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import dace @@ -23,7 +23,7 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts -from .gtir_fieldview_context import GtirFieldviewContext as FieldviewContext +from .gtir_dataflow_context import GtirDataflowContext as DataflowContext from .gtir_tasklet_arithmetic import GtirTaskletArithmetic from .gtir_tasklet_codegen import ( GtirTaskletCodegen as TaskletCodegen, @@ -37,9 +37,9 @@ def _make_fieldop( - ctx: FieldviewContext, + ctx: DataflowContext, tasklet_subgraph: TaskletSubgraph, - domain: Sequence[Tuple[str, str, str]], + domain: Sequence[tuple[str, str, str]], ) -> None: # create ordered list of input nodes input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] @@ -47,7 +47,7 @@ def _make_fieldop( # TODO: use type inference to determine the result type type_ = ts.ScalarKind.FLOAT64 - output_arrays: list[Tuple[str, dace.data.Array]] = [] + output_arrays: list[tuple[str, dace.data.Array]] = [] for _, field_type in tasklet_subgraph.output_connections: if field_type is None: field_dims = [Dimension(d) for d, _, _ in domain] @@ -93,8 +93,8 @@ def _make_fieldop( def _make_fieldop_domain( - ctx: FieldviewContext, node: itir.FunCall -) -> Sequence[Tuple[str, str, str]]: + ctx: DataflowContext, node: itir.FunCall +) -> Sequence[tuple[str, str, str]]: assert cpm.is_call_to(node, ["cartesian_domain", "unstructured_domain"]) domain = [] @@ -114,15 +114,15 @@ def _make_fieldop_domain( return domain -class GtirFieldviewBuilder(eve.NodeVisitor): - """Translates GTIR fieldview operator to some kind of map scope in DaCe SDFG.""" +class GtirDataflowBuilder(eve.NodeVisitor): + """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" - _ctx: FieldviewContext + _ctx: DataflowContext def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] ): - self._ctx = FieldviewContext(sdfg, state, field_types.copy()) + self._ctx = DataflowContext(sdfg, state, field_types.copy()) def _get_tasklet_codegen(self, lambda_node: itir.Lambda) -> Optional[TaskletCodegen]: for taskgen in REGISTERED_TASKGENS: @@ -148,7 +148,7 @@ def visit_FunCall(self, node: itir.FunCall) -> None: domain = _make_fieldop_domain(self._ctx, fun_node.args[1]) self.visit(node.args) - tasklet_subgraph = taskgen.visit(fun_node.args[0]) + tasklet_subgraph = taskgen.build_stencil(fun_node.args[0]) _make_fieldop(self._ctx, tasklet_subgraph, domain) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_context.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py similarity index 92% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_context.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py index c709a8292b..974146f543 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_fieldview_context.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py @@ -13,8 +13,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Tuple - import dace from gt4py.next.type_system import type_specifications as ts @@ -22,8 +20,8 @@ from .utility import as_dace_type -class GtirFieldviewContext: - """Defines the dataflow scope of a fieldview expression. +class GtirDataflowContext: + """Defines the SDFG subgraph scope of a fieldview expression. This class defines a region of the dataflow which represents a fieldview expression. It usually consists of a map scope, with a set of input nodes that traverse the entry map; @@ -78,14 +76,14 @@ def add_output_node(self, data: str) -> dace.nodes.AccessNode: def add_local_storage( self, name: str, type_: ts.FieldType, shape: list[str] - ) -> Tuple[str, dace.data.Array]: + ) -> tuple[str, dace.data.Array]: self.field_types[name] = type_ dtype = as_dace_type(type_.dtype) # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used return self.sdfg.add_transient(name, shape, dtype, find_new_name=True) - def clone(self) -> "GtirFieldviewContext": - ctx = GtirFieldviewContext(self.sdfg, self.state, self.field_types) + def clone(self) -> "GtirDataflowContext": + ctx = GtirDataflowContext(self.sdfg, self.state, self.field_types) ctx.node_mapping = self.node_mapping return ctx diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py index 06ffc9d17f..78afb68ace 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py @@ -18,7 +18,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from .gtir_fieldview_context import GtirFieldviewContext as FieldviewContext +from .gtir_dataflow_context import GtirDataflowContext as DataflowContext from .gtir_tasklet_codegen import GtirTaskletCodegen @@ -80,7 +80,7 @@ class GtirTaskletArithmetic(GtirTaskletCodegen): """Translates GTIR lambda exprressions with arithmetic builtin.""" - def __init__(self, ctx: FieldviewContext): + def __init__(self, ctx: DataflowContext): super().__init__(ctx) @staticmethod diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 8d82e2e621..a14660eb2c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -from typing import Optional, Tuple +from typing import Optional, final import dace @@ -21,7 +21,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts -from .gtir_fieldview_context import GtirFieldviewContext as FieldviewContext +from .gtir_dataflow_context import GtirDataflowContext as DataflowContext @dataclasses.dataclass(frozen=True) @@ -37,28 +37,29 @@ class GtirTaskletSubgraph: node: dace.nodes.Node # for each input/output connections, specify the field type or None if scalar - input_connections: list[Tuple[str, Optional[ts.FieldType]]] - output_connections: list[Tuple[str, Optional[ts.FieldType]]] + input_connections: list[tuple[str, Optional[ts.FieldType]]] + output_connections: list[tuple[str, Optional[ts.FieldType]]] class GtirTaskletCodegen(codegen.TemplatedGenerator): """Base class to translate GTIR to Python code to be used as tasklet body.""" - _ctx: FieldviewContext + _ctx: DataflowContext # list of input/output connectors and expected field type (None if scalar) - _input_connections: list[Tuple[str, Optional[ts.FieldType]]] - _output_connections: list[Tuple[str, Optional[ts.FieldType]]] + _input_connections: list[tuple[str, Optional[ts.FieldType]]] + _output_connections: list[tuple[str, Optional[ts.FieldType]]] - def __init__(self, ctx: FieldviewContext): + def __init__(self, ctx: DataflowContext): self._ctx = ctx self._input_connections = [] self._output_connections = [] @staticmethod def can_handle(lambda_node: itir.Lambda) -> bool: - raise NotImplementedError("") + raise NotImplementedError - def visit_Lambda(self, node: itir.Lambda) -> GtirTaskletSubgraph: + @final + def build_stencil(self, node: itir.Lambda) -> GtirTaskletSubgraph: tlet_expr = self.visit(node.expr) params = [str(p.id) for p in node.params] @@ -76,3 +77,8 @@ def visit_Lambda(self, node: itir.Lambda) -> GtirTaskletSubgraph: subgraph = GtirTaskletSubgraph(tlet_node, self._input_connections, self._output_connections) return subgraph + + @final + def visit_Lambda(self, node: itir.Lambda) -> GtirTaskletSubgraph: + # This visitor class should never encounter `itir.Lambda` expressionsß + raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index fec68745f6..9281864cc0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -17,7 +17,7 @@ Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Dict, Mapping, Sequence, Tuple +from typing import Any, Mapping, Sequence import dace @@ -26,7 +26,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts -from .gtir_fieldview_builder import GtirFieldviewBuilder as FieldviewBuilder +from .gtir_dataflow_builder import GtirDataflowBuilder as DataflowBuilder from .utility import as_dace_type, filter_connectivities @@ -61,7 +61,7 @@ def __init__( def _make_array_shape_and_strides( self, name: str, dims: Sequence[Dimension] - ) -> Tuple[Sequence[dace.symbol], Sequence[dace.symbol]]: + ) -> tuple[Sequence[dace.symbol], Sequence[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -103,7 +103,7 @@ def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: # scalar arguments passed to the program are represented as symbols in DaCe SDFG sdfg.add_symbol(name, dtype) - def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Dict[str, str]: + def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, str]: raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") return {} @@ -161,7 +161,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) The translation of `SetAt` ensures that the result is written to the external storage. """ - fieldview_builder = FieldviewBuilder(sdfg, state, self._field_types) + fieldview_builder = DataflowBuilder(sdfg, state, self._field_types) fieldview_builder.visit(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression From 98544978e9895b54d21c595bddad2c97a380636b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 24 Apr 2024 13:23:44 +0200 Subject: [PATCH 014/235] Minor edit --- .../dace_fieldview/gtir_dataflow_builder.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index e44b8e3bd4..0c3cddb409 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -41,6 +41,12 @@ def _make_fieldop( tasklet_subgraph: TaskletSubgraph, domain: Sequence[tuple[str, str, str]], ) -> None: + """Put a `TaskletSubgraph` object into a map scope. + + This helper method represents a field operator as mapped tasklet. + The map range is given by the domain. + """ + # create ordered list of input nodes input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] assert len(tasklet_subgraph.input_connections) == len(input_arrays) @@ -95,6 +101,12 @@ def _make_fieldop( def _make_fieldop_domain( ctx: DataflowContext, node: itir.FunCall ) -> Sequence[tuple[str, str, str]]: + """Visits the domain of a field operator. + + Returns + ------- + A list of tuples(dimension_name, lower_bound_value, upper_bound_value) + """ assert cpm.is_call_to(node, ["cartesian_domain", "unstructured_domain"]) domain = [] @@ -162,6 +174,10 @@ def visit_SymRef(self, node: itir.SymRef) -> None: self._ctx.add_input_node(dname) def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: + """Write the current set of input nodes to external nodes. + + The target arrays are supposed to be external, therefore non-transient. + """ assert len(self._ctx.output_nodes) == 0 # TODO: add support for tuple return @@ -175,7 +191,7 @@ def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: for tasklet_node, target_node in zip(self._ctx.input_nodes, self._ctx.output_nodes): target_array = self._ctx.sdfg.arrays[target_node] - target_array.transient = False + assert target_array.transient == False self._ctx.state.add_nedge( self._ctx.node_mapping[tasklet_node], From 37d83d76c8a9fd766f08f0de5116b2e83b4997a2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 24 Apr 2024 13:39:52 +0200 Subject: [PATCH 015/235] Fix formatting --- .../runners/dace_fieldview/gtir_dataflow_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 0c3cddb409..7d05347f5a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -191,7 +191,7 @@ def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: for tasklet_node, target_node in zip(self._ctx.input_nodes, self._ctx.output_nodes): target_array = self._ctx.sdfg.arrays[target_node] - assert target_array.transient == False + assert not target_array.transient self._ctx.state.add_nedge( self._ctx.node_mapping[tasklet_node], From 29986ef121601bf0fbf273eb0e9d5de7ededc4b1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 29 Apr 2024 14:48:48 +0200 Subject: [PATCH 016/235] Use callable to build taskgraph --- .../dace_fieldview/gtir_builtin_translator.py | 126 ++++++++ .../dace_fieldview/gtir_dataflow_builder.py | 288 +++++++++--------- .../dace_fieldview/gtir_dataflow_context.py | 94 ------ .../dace_fieldview/gtir_tasklet_arithmetic.py | 115 ------- .../dace_fieldview/gtir_tasklet_codegen.py | 167 +++++++--- .../runners/dace_fieldview/gtir_to_sdfg.py | 25 +- 6 files changed, 405 insertions(+), 410 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py new file mode 100644 index 0000000000..4732472e0f --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py @@ -0,0 +1,126 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from typing import Sequence + +import dace + +from gt4py.next.common import Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( + GtirTaskletCodegen, +) +from gt4py.next.type_system import type_specifications as ts + + +class GtirBuiltinAsFieldOp(GtirTaskletCodegen): + _stencil: itir.Lambda + _domain: dict[Dimension, tuple[str, str]] + _args: Sequence[GtirTaskletCodegen] + _field_type: ts.FieldType + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + stencil: itir.Lambda, + domain: Sequence[tuple[Dimension, str, str]], + args: Sequence[GtirTaskletCodegen], + field_dtype: ts.ScalarType, + ): + super().__init__(sdfg, state) + self._stencil = stencil + self._args = args + self._domain = {dim: (lb, ub) for dim, lb, ub in domain} + self._field_type = ts.FieldType([dim for dim, _, _ in domain], field_dtype) + + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: + # generate the python code for this stencil + output_connector = "__out" + tlet_code = "{var} = {code}".format( + var=output_connector, code=self.visit(self._stencil.expr) + ) + + # allocate local (aka transient) storage for the field + field_shape = [ + # diff between upper and lower bound + f"{self._domain[dim][1]} - {self._domain[dim][0]}" + for dim in self._field_type.dims + ] + field_name, _ = self._add_local_storage(self._field_type, field_shape) + field_node = self._state.add_access(field_name) + + # create map range corresponding to the field operator domain + map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self._domain.items()} + + # visit expressions passed as arguments to this stencil + input_nodes: dict[str, dace.nodes.AccessNode] = {} + input_memlets: dict[str, dace.Memlet] = {} + assert len(self._args) == len(self._stencil.params) + for arg, param in zip(self._args, self._stencil.params): + arg_nodes = arg() + assert len(arg_nodes) == 1 + arg_node, arg_type = arg_nodes[0] + connector = str(param.id) + # require (for now) all input nodes to be data access nodes + assert isinstance(arg_node, dace.nodes.AccessNode) + input_nodes[arg_node.data] = arg_node + # support either single element access (general case) or full array shape + is_scalar = all(dim in self._domain for dim in arg_type.dims) + if is_scalar: + subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) + input_memlets[connector] = dace.Memlet(data=arg_node.data, subset=subset) + else: + input_memlets[connector] = dace.Memlet.from_array( + arg_node.data, arg_node.desc(self._sdfg) + ) + + # assume tasklet with single output + output_index = ",".join(f"i_{dim.value}" for dim in self._field_type.dims) + output_memlets = {output_connector: dace.Memlet(data=field_name, subset=output_index)} + output_nodes = {field_name: field_node} + + # create a tasklet inside a parallel-map scope + self._state.add_mapped_tasklet( + "tasklet", + map_ranges, + input_memlets, + tlet_code, + output_memlets, + input_nodes=input_nodes, + output_nodes=output_nodes, + external_edges=True, + ) + + return [(field_node, self._field_type)] + + def visit_SymRef(self, node: itir.SymRef) -> str: + name = str(node.id) + assert name in set(str(p.id) for p in self._stencil.params) + return name + + +class GtirBuiltinSelect(GtirTaskletCodegen): + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + true_br_args: Sequence[GtirTaskletCodegen], + false_br_args: Sequence[GtirTaskletCodegen], + ): + super().__init__(sdfg, state) + + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: + return [] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 7d05347f5a..0b72ae1879 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Optional, Sequence +from typing import Callable, Sequence import dace @@ -21,180 +21,170 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.type_system import type_specifications as ts - -from .gtir_dataflow_context import GtirDataflowContext as DataflowContext -from .gtir_tasklet_arithmetic import GtirTaskletArithmetic -from .gtir_tasklet_codegen import ( - GtirTaskletCodegen as TaskletCodegen, - GtirTaskletSubgraph as TaskletSubgraph, +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translator import ( + GtirBuiltinAsFieldOp as AsFieldOp, + GtirBuiltinSelect as Select, ) - - -REGISTERED_TASKGENS: list[type[TaskletCodegen]] = [ - GtirTaskletArithmetic, -] - - -def _make_fieldop( - ctx: DataflowContext, - tasklet_subgraph: TaskletSubgraph, - domain: Sequence[tuple[str, str, str]], -) -> None: - """Put a `TaskletSubgraph` object into a map scope. - - This helper method represents a field operator as mapped tasklet. - The map range is given by the domain. - """ - - # create ordered list of input nodes - input_arrays = [(name, ctx.sdfg.arrays[name]) for name in ctx.input_nodes] - assert len(tasklet_subgraph.input_connections) == len(input_arrays) - - # TODO: use type inference to determine the result type - type_ = ts.ScalarKind.FLOAT64 - output_arrays: list[tuple[str, dace.data.Array]] = [] - for _, field_type in tasklet_subgraph.output_connections: - if field_type is None: - field_dims = [Dimension(d) for d, _, _ in domain] - field_type = ts.FieldType(field_dims, ts.ScalarType(type_)) - field_shape = [f"{ub} - {lb}" for d, lb, ub in domain if Dimension(d) in field_type.dims] - output_name, output_array = ctx.add_local_storage(ctx.var_name(), field_type, field_shape) - output_arrays.append((output_name, output_array)) - - map_ranges = {f"i_{d}": f"{lb}:{ub}" for d, lb, ub in domain} - me, mx = ctx.state.add_map("fieldop", map_ranges) - - for (connector, field_type), (dname, _) in zip( - tasklet_subgraph.input_connections, input_arrays - ): - src_node = ctx.node_mapping[dname] - if field_type is None: - subset = ",".join([f"i_{d.value}" for d in ctx.field_types[dname].dims]) - else: - raise NotImplementedError("Array subset on tasklet connector not supported.") - ctx.state.add_memlet_path( - src_node, - me, - tasklet_subgraph.node, - dst_conn=connector, - memlet=dace.Memlet(data=dname, subset=subset), - ) - - for (connector, field_type), (dname, _) in zip( - tasklet_subgraph.output_connections, output_arrays - ): - dst_node = ctx.add_output_node(dname) - if field_type is None: - subset = ",".join([f"i_{d}" for d, _, _ in domain]) - else: - raise NotImplementedError("Array subset on tasklet connector not supported.") - ctx.state.add_memlet_path( - tasklet_subgraph.node, - mx, - dst_node, - src_conn=connector, - memlet=dace.Memlet(data=dname, subset=subset), - ) - - -def _make_fieldop_domain( - ctx: DataflowContext, node: itir.FunCall -) -> Sequence[tuple[str, str, str]]: - """Visits the domain of a field operator. - - Returns - ------- - A list of tuples(dimension_name, lower_bound_value, upper_bound_value) - """ - assert cpm.is_call_to(node, ["cartesian_domain", "unstructured_domain"]) - - domain = [] - translator = TaskletCodegen(ctx) - - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - lb = translator.visit(lower_bound) - ub = translator.visit(upper_bound) - domain.append((dimension.value, lb, ub)) - - return domain +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type +from gt4py.next.type_system import type_specifications as ts class GtirDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" - _ctx: DataflowContext + _sdfg: dace.SDFG + _head_state: dace.SDFGState + _field_types: dict[str, ts.FieldType] + _node_mapping: dict[str, dace.nodes.AccessNode] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] ): - self._ctx = DataflowContext(sdfg, state, field_types.copy()) - - def _get_tasklet_codegen(self, lambda_node: itir.Lambda) -> Optional[TaskletCodegen]: - for taskgen in REGISTERED_TASKGENS: - if taskgen.can_handle(lambda_node): - return taskgen(self._ctx) - return None + self._sdfg = sdfg + self._head_state = state + self._field_types = field_types + self._node_mapping = {} + + def _add_local_storage( + self, type_: ts.DataType, shape: list[str] + ) -> tuple[str, dace.data.Data]: + name = f"{self._head_state.label}_var" + if isinstance(type_, ts.FieldType): + dtype = as_dace_type(type_.dtype) + assert len(type_.dims) == len(shape) + # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used + name, data = self._sdfg.add_array( + name, shape, dtype, find_new_name=True, transient=True + ) + else: + assert isinstance(type_, ts.ScalarType) + dtype = as_dace_type(type_) + assert len(shape) == 0 + name, data = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + return name, data + + def _add_access_node(self, data: str) -> dace.nodes.AccessNode: + assert data in self._sdfg.arrays + if data in self._node_mapping: + node = self._node_mapping[data] + else: + node = self._head_state.add_access(data) + self._node_mapping[data] = node + return node + + def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: + domain = [] + assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, itir.AxisLiteral) + dim = Dimension(axis.value) + bounds = [self.visit_symbolic(arg) for arg in named_range.args[1:3]] + domain.append((dim, bounds[0], bounds[1])) + + return domain + + def visit_expression(self, node: itir.Expr) -> list[dace.nodes.AccessNode]: + expr_builder = self.visit(node) + assert callable(expr_builder) + results = expr_builder() + expressions_nodes = [] + for node, _type in results: + assert isinstance(node, dace.nodes.AccessNode) + self._node_mapping[node.data] = node + expressions_nodes.append(node) + if isinstance(_type, ts.FieldType): + self._field_types[node.data] = _type + else: + assert isinstance(_type, ts.ScalarType) + return expressions_nodes + + def visit_symbolic(self, node: itir.Expr) -> str: + if isinstance(node, itir.Literal): + return node.value + + elif isinstance(node, itir.SymRef): + sym = str(node.id) + assert sym in self._sdfg.symbols + return sym - def visit_FunCall(self, node: itir.FunCall) -> None: - parent_ctx = self._ctx - child_ctx = parent_ctx.clone() - self._ctx = child_ctx + else: + # TODO: add support for symbolic expressions + return "1 > 2" + def visit_FunCall(self, node: itir.FunCall) -> Callable: if cpm.is_call_to(node.fun, "as_fieldop"): fun_node = node.fun assert len(fun_node.args) == 2 # expect stencil (represented as a lambda function) as first argument assert isinstance(fun_node.args[0], itir.Lambda) - taskgen = self._get_tasklet_codegen(fun_node.args[0]) - if not taskgen: - raise NotImplementedError(f"Unsupported 'as_fieldop' node ({node}).") # the domain of the field operator is passed as second argument assert isinstance(fun_node.args[1], itir.FunCall) - domain = _make_fieldop_domain(self._ctx, fun_node.args[1]) + field_domain = self.visit_domain(fun_node.args[1]) - self.visit(node.args) - tasklet_subgraph = taskgen.build_stencil(fun_node.args[0]) - _make_fieldop(self._ctx, tasklet_subgraph, domain) - else: - raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") + stencil_args = [self.visit(arg) for arg in node.args] + + # add local storage to compute the field operator over the given domain + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + + return AsFieldOp( + sdfg=self._sdfg, + state=self._head_state, + stencil=fun_node.args[0], + domain=field_domain, + args=stencil_args, + field_dtype=node_type, + ) + + elif cpm.is_call_to(node.fun, "select"): + fun_node = node.fun + assert len(fun_node.args) == 3 - assert self._ctx == child_ctx - parent_ctx.input_nodes.extend(child_ctx.output_nodes) - self._ctx = parent_ctx + # expect condition as first argument + cond = self.visit_symbolic(fun_node.args[0]) - def visit_SymRef(self, node: itir.SymRef) -> None: - dname = str(node.id) - self._ctx.add_input_node(dname) + # use join state to terminate the dataflow on a single exit node + _join_state = self._sdfg.add_state(self._head_state.label + "_join") - def write_to(self, target_expr: itir.Expr, domain_expr: itir.Expr) -> None: - """Write the current set of input nodes to external nodes. + # expect true branch as second argument + _true_state = self._sdfg.add_state(self._head_state.label + "_true_branch") + self._sdfg.add_edge(self._head_state, _true_state, dace.InterstateEdge(condition=cond)) + self._sdfg.add_edge(_true_state, _join_state, dace.InterstateEdge()) - The target arrays are supposed to be external, therefore non-transient. - """ - assert len(self._ctx.output_nodes) == 0 + # and false branch as third argument + _false_state = self._sdfg.add_state(self._head_state.label + "_false_branch") + self._sdfg.add_edge( + self._head_state, _false_state, dace.InterstateEdge(condition=f"not {cond}") + ) + self._sdfg.add_edge(_false_state, _join_state, dace.InterstateEdge()) - # TODO: add support for tuple return - assert len(self._ctx.input_nodes) == 1 - assert isinstance(target_expr, itir.SymRef) - self._ctx.add_output_node(target_expr.id) + self._head_state = _true_state + self._node_mapping = {} + true_br_args = self.visit(fun_node.args[1]) - assert isinstance(domain_expr, itir.FunCall) - domain = _make_fieldop_domain(self._ctx, domain_expr) - write_subset = ",".join(f"{lb}:{ub}" for _, lb, ub in domain) + self._head_state = _false_state + self._node_mapping = {} + false_br_args = self.visit(fun_node.args[2]) - for tasklet_node, target_node in zip(self._ctx.input_nodes, self._ctx.output_nodes): - target_array = self._ctx.sdfg.arrays[target_node] - assert not target_array.transient + self._head_state = _join_state + self._node_mapping = {} - self._ctx.state.add_nedge( - self._ctx.node_mapping[tasklet_node], - self._ctx.node_mapping[target_node], - dace.Memlet(data=target_node, subset=write_subset), + return Select( + sdfg=self._sdfg, + state=self._head_state, + true_br_args=true_br_args, + false_br_args=false_br_args, ) + + else: + raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") + + def visit_SymRef(self, node: itir.SymRef) -> Callable: + name = str(node.id) + access_node = self._add_access_node(name) + assert name in self._field_types + data_type = self._field_types[name] + return lambda: [(access_node, data_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py deleted file mode 100644 index 974146f543..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_context.py +++ /dev/null @@ -1,94 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import dace - -from gt4py.next.type_system import type_specifications as ts - -from .utility import as_dace_type - - -class GtirDataflowContext: - """Defines the SDFG subgraph scope of a fieldview expression. - - This class defines a region of the dataflow which represents a fieldview expression. - It usually consists of a map scope, with a set of input nodes that traverse the entry map; - a set of transient data nodes (aka temporaries) where the output memlets traversing the - exit map will write to; and the compute nodes (tasklets) implementing the expression - within the map scope. - More than one fieldfiew region can exist within a state. In this case, the temporaies - which are written to by one fieldview region will be inputs to the next region. Also, - the set of access nodes `node_mapping` is shared among all fieldview regions within a state. - - We use this class as return type when we visit a fieldview expression. It can be extended - with all informatiion needed to construct the dataflow graph. - """ - - sdfg: dace.SDFG - state: dace.SDFGState - field_types: dict[str, ts.FieldType] - node_mapping: dict[str, dace.nodes.AccessNode] - - input_nodes: list[str] - output_nodes: list[str] - - def __init__( - self, - current_sdfg: dace.SDFG, - current_state: dace.SDFGState, - current_field_types: dict[str, ts.FieldType], - ): - self.sdfg = current_sdfg - self.state = current_state - self.field_types = current_field_types - self.node_mapping = {} - self.input_nodes = [] - self.output_nodes = [] - - def _add_node(self, data: str) -> dace.nodes.AccessNode: - assert data in self.sdfg.arrays - if data in self.node_mapping: - node = self.node_mapping[data] - else: - node = self.state.add_access(data) - self.node_mapping[data] = node - return node - - def add_input_node(self, data: str) -> dace.nodes.AccessNode: - self.input_nodes.append(data) - return self._add_node(data) - - def add_output_node(self, data: str) -> dace.nodes.AccessNode: - self.output_nodes.append(data) - return self._add_node(data) - - def add_local_storage( - self, name: str, type_: ts.FieldType, shape: list[str] - ) -> tuple[str, dace.data.Array]: - self.field_types[name] = type_ - dtype = as_dace_type(type_.dtype) - # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used - return self.sdfg.add_transient(name, shape, dtype, find_new_name=True) - - def clone(self) -> "GtirDataflowContext": - ctx = GtirDataflowContext(self.sdfg, self.state, self.field_types) - ctx.node_mapping = self.node_mapping - return ctx - - def tasklet_name(self) -> str: - return f"{self.state.label}_tasklet" - - def var_name(self) -> str: - return f"{self.state.label}_var" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py deleted file mode 100644 index 78afb68ace..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_arithmetic.py +++ /dev/null @@ -1,115 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import numpy as np - -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - -from .gtir_dataflow_context import GtirDataflowContext as DataflowContext -from .gtir_tasklet_codegen import GtirTaskletCodegen - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -class GtirTaskletArithmetic(GtirTaskletCodegen): - """Translates GTIR lambda exprressions with arithmetic builtin.""" - - def __init__(self, ctx: DataflowContext): - super().__init__(ctx) - - @staticmethod - def can_handle(lambda_node: itir.Lambda) -> bool: - fun_node = lambda_node.expr - assert isinstance(fun_node, itir.FunCall) - if isinstance(fun_node.fun, itir.SymRef): - builtin_name = str(fun_node.fun.id) - return builtin_name in _MATH_BUILTINS_MAPPING - return False - - def _visit_deref(self, node: itir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - return fmt.format(*args) - - def visit_FunCall(self, node: itir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - return self._visit_numeric_builtin(node) - - def visit_SymRef(self, node: itir.SymRef) -> str: - name = str(node.id) - self._input_connections.append((name, None)) - return name diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index a14660eb2c..d71ca4a0b8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -13,72 +13,141 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -from typing import Optional, final +from typing import Any, final import dace +import numpy as np from gt4py.eve import codegen from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts -from .gtir_dataflow_context import GtirDataflowContext as DataflowContext - -@dataclasses.dataclass(frozen=True) -class GtirTaskletSubgraph: - """Defines a tasklet subgraph representing a stencil expression. - - The tasklet subgraph will be used by the consumer to build a fieldview expression. - For example, it could be used in a map scope to build a fieldview expression; - or it could become the body of a scan expression. - """ - - # generic DaCe node, most often this will be a tasklet node but it could also be a nested SDFG - node: dace.nodes.Node - - # for each input/output connections, specify the field type or None if scalar - input_connections: list[tuple[str, Optional[ts.FieldType]]] - output_connections: list[tuple[str, Optional[ts.FieldType]]] +_MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + +@dataclasses.dataclass +class SymbolExpr: + value: dace.symbolic.SymbolicType + dtype: dace.typeclass + + +@dataclasses.dataclass +class ValueExpr: + value: dace.nodes.AccessNode + dtype: dace.typeclass class GtirTaskletCodegen(codegen.TemplatedGenerator): - """Base class to translate GTIR to Python code to be used as tasklet body.""" + _sdfg: dace.SDFG + _state: dace.SDFGState - _ctx: DataflowContext - # list of input/output connectors and expected field type (None if scalar) - _input_connections: list[tuple[str, Optional[ts.FieldType]]] - _output_connections: list[tuple[str, Optional[ts.FieldType]]] + def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + self._sdfg = sdfg + self._state = state - def __init__(self, ctx: DataflowContext): - self._ctx = ctx - self._input_connections = [] - self._output_connections = [] + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: + """ "Creates the dataflow representing the given GTIR builtin. - @staticmethod - def can_handle(lambda_node: itir.Lambda) -> bool: + Returns a list of connections, where each connectio is defined as: + tuple(node, connector_name) + """ raise NotImplementedError @final - def build_stencil(self, node: itir.Lambda) -> GtirTaskletSubgraph: - tlet_expr = self.visit(node.expr) - - params = [str(p.id) for p in node.params] - assert len(self._input_connections) == len(params) - - outvar = "__out" - tlet_code = f"{outvar} = {tlet_expr}" - results = [outvar] - self._output_connections.append((outvar, None)) - - tlet_node: dace.tasklet = self._ctx.state.add_tasklet( - f"{self._ctx.tasklet_name()}_lambda", set(params), set(results), tlet_code - ) - - subgraph = GtirTaskletSubgraph(tlet_node, self._input_connections, self._output_connections) - - return subgraph + def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.nodes.AccessNode: + name = f"{self._state.label}_var" + if isinstance(data_type, ts.FieldType): + assert len(data_type.dims) == len(shape) + dtype = as_dace_type(data_type.dtype) + return self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) + else: + assert isinstance(data_type, ts.ScalarType) + assert len(shape) == 0 + dtype = as_dace_type(data_type) + return self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + + def _visit_deref(self, node: itir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) + + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, itir.SymRef): + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: + return self._visit_numeric_builtin(node) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") @final - def visit_Lambda(self, node: itir.Lambda) -> GtirTaskletSubgraph: - # This visitor class should never encounter `itir.Lambda` expressionsß + def visit_Lambda(self, node: itir.Lambda) -> Any: + # This visitor class should never encounter `itir.Lambda` expressions raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") + + def visit_SymRef(self, node: itir.SymRef) -> str: + raise NotImplementedError diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 9281864cc0..ee377c4137 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -161,9 +161,28 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) The translation of `SetAt` ensures that the result is written to the external storage. """ - fieldview_builder = DataflowBuilder(sdfg, state, self._field_types) - fieldview_builder.visit(stmt.expr) + dataflow_builder = DataflowBuilder(sdfg, state, self._field_types) + expr_nodes = dataflow_builder.visit_expression(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - fieldview_builder.write_to(stmt.target, stmt.domain) + target_nodes = dataflow_builder.visit_expression(stmt.target) + assert len(expr_nodes) == len(target_nodes) + + domain = dataflow_builder.visit_domain(stmt.domain) + # convert domain to dictionary to ease access to dimension boundaries + domain_map = {dim: (lb, ub) for dim, lb, ub in domain} + + for expr_node, target_node in zip(expr_nodes, target_nodes): + target_array = sdfg.arrays[target_node.data] + assert not target_array.transient + + subset = ",".join( + f"{domain_map[dim][0]}:{domain_map[dim][1]}" + for dim in self._field_types[target_node.data].dims + ) + state.add_nedge( + expr_node, + target_node, + dace.Memlet(data=target_node.data, subset=subset), + ) From 390f3b45773dd7c46367c1f8d50a2f783b5fc89b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 29 Apr 2024 22:13:01 +0200 Subject: [PATCH 017/235] Add draft of select operator --- src/gt4py/next/iterator/ir.py | 1 + .../dace_fieldview/gtir_builtin_translator.py | 111 ++++++++++++++---- .../dace_fieldview/gtir_dataflow_builder.py | 55 +++++---- .../dace_fieldview/gtir_tasklet_codegen.py | 14 ++- .../runners/dace_fieldview/gtir_to_sdfg.py | 19 +-- .../runners_tests/test_dace_fieldview.py | 66 +++++++++++ 6 files changed, 202 insertions(+), 64 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 538ac84cb8..aeb8be6f0c 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -208,6 +208,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib GTIR_BUILTINS = { *BUILTINS, "as_fieldop", # `as_fieldop(stencil)` creates field_operator from stencil + "select", # `select(cond, field_a, field_b)` creates the field on one branch or the other } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py index 4732472e0f..0bf9cf9e9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py @@ -22,9 +22,37 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( GtirTaskletCodegen, ) +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts +class GtirBuiltinScalarAccess(GtirTaskletCodegen): + _sym_name: str + _data_type: ts.ScalarType + + def __init__( + self, sdfg: dace.SDFG, state: dace.SDFGState, sym_name: str, data_type: ts.ScalarType + ): + super().__init__(sdfg, state) + self._sym_name = sym_name + self._data_type = data_type + + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + tasklet_node = self._state.add_tasklet( + f"get_{self._sym_name}", + {}, + {"__out"}, + f"__out = {self._sym_name}", + ) + name = f"{self._state.label}_var" + dtype = as_dace_type(self._data_type) + output_node = self._state.add_scalar(name, dtype, find_new_name=True, transient=True) + self._state.add_edge( + tasklet_node, "__out", output_node, None, dace.Memlet(data=output_node.data, subset="0") + ) + return [(output_node, self._data_type)] + + class GtirBuiltinAsFieldOp(GtirTaskletCodegen): _stencil: itir.Lambda _domain: dict[Dimension, tuple[str, str]] @@ -46,7 +74,7 @@ def __init__( self._domain = {dim: (lb, ub) for dim, lb, ub in domain} self._field_type = ts.FieldType([dim for dim, _, _ in domain], field_dtype) - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # generate the python code for this stencil output_connector = "__out" tlet_code = "{var} = {code}".format( @@ -59,8 +87,7 @@ def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: f"{self._domain[dim][1]} - {self._domain[dim][0]}" for dim in self._field_type.dims ] - field_name, _ = self._add_local_storage(self._field_type, field_shape) - field_node = self._state.add_access(field_name) + field_node = self._add_local_storage(self._field_type, field_shape) # create map range corresponding to the field operator domain map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self._domain.items()} @@ -77,20 +104,25 @@ def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: # require (for now) all input nodes to be data access nodes assert isinstance(arg_node, dace.nodes.AccessNode) input_nodes[arg_node.data] = arg_node - # support either single element access (general case) or full array shape - is_scalar = all(dim in self._domain for dim in arg_type.dims) - if is_scalar: - subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) - input_memlets[connector] = dace.Memlet(data=arg_node.data, subset=subset) + if isinstance(arg_type, ts.FieldType): + # support either single element access (general case) or full array shape + is_scalar = all(dim in self._domain for dim in arg_type.dims) + if is_scalar: + subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) + input_memlets[connector] = dace.Memlet( + data=arg_node.data, subset=subset, volume=1 + ) + else: + memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self._sdfg)) + memlet.volume = 1 + input_memlets[connector] = memlet else: - input_memlets[connector] = dace.Memlet.from_array( - arg_node.data, arg_node.desc(self._sdfg) - ) + input_memlets[connector] = dace.Memlet(data=arg_node.data, subset="0") # assume tasklet with single output output_index = ",".join(f"i_{dim.value}" for dim in self._field_type.dims) - output_memlets = {output_connector: dace.Memlet(data=field_name, subset=output_index)} - output_nodes = {field_name: field_node} + output_memlets = {output_connector: dace.Memlet(data=field_node.data, subset=output_index)} + output_nodes = {field_node.data: field_node} # create a tasklet inside a parallel-map scope self._state.add_mapped_tasklet( @@ -106,21 +138,54 @@ def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: return [(field_node, self._field_type)] - def visit_SymRef(self, node: itir.SymRef) -> str: - name = str(node.id) - assert name in set(str(p.id) for p in self._stencil.params) - return name - class GtirBuiltinSelect(GtirTaskletCodegen): + _true_br_builder: GtirTaskletCodegen + _false_br_builder: GtirTaskletCodegen + def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - true_br_args: Sequence[GtirTaskletCodegen], - false_br_args: Sequence[GtirTaskletCodegen], + true_br_builder: GtirTaskletCodegen, + false_br_builder: GtirTaskletCodegen, ): super().__init__(sdfg, state) - - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: - return [] + self._true_br_builder = true_br_builder + self._false_br_builder = false_br_builder + + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + true_br_args = self._true_br_builder() + false_br_args = self._false_br_builder() + assert len(true_br_args) == len(false_br_args) + + output_nodes = [] + for true_br, false_br in zip(true_br_args, false_br_args): + true_br_node, true_br_type = true_br + assert isinstance(true_br_node, dace.nodes.AccessNode) + false_br_node, false_br_type = false_br + assert isinstance(false_br_node, dace.nodes.AccessNode) + assert true_br_type == false_br_type + array_type = self._sdfg.arrays[true_br_node.data] + access_node = self._add_local_storage(true_br_type, array_type.shape) + output_nodes.append((access_node, true_br_type)) + + data_name = access_node.data + true_br_output_node = self._true_br_builder._state.add_access(data_name) + self._true_br_builder._state.add_nedge( + true_br_node, + true_br_output_node, + dace.Memlet.from_array( + true_br_output_node.data, true_br_output_node.desc(self._sdfg) + ), + ) + + false_br_output_node = self._false_br_builder._state.add_access(data_name) + self._false_br_builder._state.add_nedge( + false_br_node, + false_br_output_node, + dace.Memlet.from_array( + false_br_output_node.data, false_br_output_node.desc(self._sdfg) + ), + ) + return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 0b72ae1879..7f18de1885 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -23,8 +23,12 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translator import ( GtirBuiltinAsFieldOp as AsFieldOp, + GtirBuiltinScalarAccess as ScalarAccess, GtirBuiltinSelect as Select, ) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( + GtirTaskletCodegen, +) from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts @@ -34,15 +38,18 @@ class GtirDataflowBuilder(eve.NodeVisitor): _sdfg: dace.SDFG _head_state: dace.SDFGState - _field_types: dict[str, ts.FieldType] + _data_types: dict[str, ts.FieldType | ts.ScalarType] _node_mapping: dict[str, dace.nodes.AccessNode] def __init__( - self, sdfg: dace.SDFG, state: dace.SDFGState, field_types: dict[str, ts.FieldType] + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + data_types: dict[str, ts.FieldType | ts.ScalarType], ): self._sdfg = sdfg self._head_state = state - self._field_types = field_types + self._data_types = data_types self._node_mapping = {} def _add_local_storage( @@ -91,28 +98,15 @@ def visit_expression(self, node: itir.Expr) -> list[dace.nodes.AccessNode]: assert callable(expr_builder) results = expr_builder() expressions_nodes = [] - for node, _type in results: + for node, _ in results: assert isinstance(node, dace.nodes.AccessNode) - self._node_mapping[node.data] = node expressions_nodes.append(node) - if isinstance(_type, ts.FieldType): - self._field_types[node.data] = _type - else: - assert isinstance(_type, ts.ScalarType) + return expressions_nodes def visit_symbolic(self, node: itir.Expr) -> str: - if isinstance(node, itir.Literal): - return node.value - - elif isinstance(node, itir.SymRef): - sym = str(node.id) - assert sym in self._sdfg.symbols - return sym - - else: - # TODO: add support for symbolic expressions - return "1 > 2" + codegen = GtirTaskletCodegen(self._sdfg, self._head_state) + return codegen.visit(node) def visit_FunCall(self, node: itir.FunCall) -> Callable: if cpm.is_call_to(node.fun, "as_fieldop"): @@ -163,11 +157,11 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: self._head_state = _true_state self._node_mapping = {} - true_br_args = self.visit(fun_node.args[1]) + true_br_callable = self.visit(fun_node.args[1]) self._head_state = _false_state self._node_mapping = {} - false_br_args = self.visit(fun_node.args[2]) + false_br_callable = self.visit(fun_node.args[2]) self._head_state = _join_state self._node_mapping = {} @@ -175,8 +169,8 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: return Select( sdfg=self._sdfg, state=self._head_state, - true_br_args=true_br_args, - false_br_args=false_br_args, + true_br_builder=true_br_callable, + false_br_builder=false_br_callable, ) else: @@ -184,7 +178,12 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: def visit_SymRef(self, node: itir.SymRef) -> Callable: name = str(node.id) - access_node = self._add_access_node(name) - assert name in self._field_types - data_type = self._field_types[name] - return lambda: [(access_node, data_type)] + assert name in self._data_types + data_type = self._data_types[name] + if isinstance(data_type, ts.FieldType): + access_node = self._add_access_node(name) + return lambda: [(access_node, data_type)] + else: + # scalar symbols are passed to the SDFG as symbols + assert name in self._sdfg.symbols + return ScalarAccess(self._sdfg, self._head_state, name, data_type) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index d71ca4a0b8..b0c17d8c25 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -100,7 +100,7 @@ def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: self._sdfg = sdfg self._state = state - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType]]: + def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: """ "Creates the dataflow representing the given GTIR builtin. Returns a list of connections, where each connectio is defined as: @@ -114,12 +114,13 @@ def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.n if isinstance(data_type, ts.FieldType): assert len(data_type.dims) == len(shape) dtype = as_dace_type(data_type.dtype) - return self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) + name, _ = self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) else: assert isinstance(data_type, ts.ScalarType) assert len(shape) == 0 dtype = as_dace_type(data_type) - return self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + return self._state.add_access(name) def _visit_deref(self, node: itir.FunCall) -> str: assert len(node.args) == 1 @@ -149,5 +150,10 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: # This visitor class should never encounter `itir.Lambda` expressions raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") + @final + def visit_Literal(self, node: itir.Literal) -> str: + return node.value + + @final def visit_SymRef(self, node: itir.SymRef) -> str: - raise NotImplementedError + return str(node.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index ee377c4137..72ce442553 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -47,7 +47,7 @@ class GtirToSDFG(eve.NodeVisitor): """ _param_types: list[ts.TypeSpec] - _field_types: dict[str, ts.FieldType] + _data_types: dict[str, ts.FieldType | ts.ScalarType] _offset_providers: Mapping[str, Any] def __init__( @@ -56,7 +56,7 @@ def __init__( offset_providers: dict[str, Connectivity | Dimension], ): self._param_types = param_types - self._field_types = {} + self._data_types = {} self._offset_providers = offset_providers def _make_array_shape_and_strides( @@ -89,16 +89,16 @@ def _make_array_shape_and_strides( return shape, strides def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: + assert isinstance(type_, (ts.FieldType, ts.ScalarType)) + self._data_types[name] = type_ + if isinstance(type_, ts.FieldType): dtype = as_dace_type(type_.dtype) # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. sym_shape, sym_strides = self._make_array_shape_and_strides(name, type_.dims) sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False) - self._field_types[name] = type_ - else: - assert isinstance(type_, ts.ScalarType) dtype = as_dace_type(type_) # scalar arguments passed to the program are represented as symbols in DaCe SDFG sdfg.add_symbol(name, dtype) @@ -161,7 +161,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) The translation of `SetAt` ensures that the result is written to the external storage. """ - dataflow_builder = DataflowBuilder(sdfg, state, self._field_types) + dataflow_builder = DataflowBuilder(sdfg, state, self._data_types) expr_nodes = dataflow_builder.visit_expression(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression @@ -176,12 +176,13 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) for expr_node, target_node in zip(expr_nodes, target_nodes): target_array = sdfg.arrays[target_node.data] assert not target_array.transient + target_field_type = self._data_types[target_node.data] + assert isinstance(target_field_type, ts.FieldType) subset = ",".join( - f"{domain_map[dim][0]}:{domain_map[dim][1]}" - for dim in self._field_types[target_node.data].dims + f"{domain_map[dim][0]}:{domain_map[dim][1]}" for dim in target_field_type.dims ) - state.add_nedge( + dataflow_builder._head_state.add_nedge( expr_node, target_node, dace.Memlet(data=target_node.data, subset=subset), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 03968857f5..cf78dd3f3f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -139,3 +139,69 @@ def test_gtir_sum3(): sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + c)) + + +def test_gtir_select(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + ) + testee = itir.Program( + id="select_2sums", + function_definitions=[], + params=[ + itir.Sym(id="x"), + itir.Sym(id="y"), + itir.Sym(id="w"), + itir.Sym(id="z"), + itir.Sym(id="cond"), + itir.Sym(id="size"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("select")( + im.deref("cond"), + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )("x", "y"), + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )("y", "w"), + ) + )(), + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + c = np.random.rand(N) + d = np.empty_like(a) + + sdfg_genenerator = FieldviewGtirToSDFG( + [ + FTYPE, + FTYPE, + FTYPE, + FTYPE, + ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.INT32), + ], + OFFSET_PROVIDERS, + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + for s in [False, True]: + sdfg(cond=s, x=a, y=b, w=c, z=d, **FSYMBOLS) + assert np.allclose(d, (a + b) if s else (b + c)) From de27419f2e217382e2f6b175026331b6d2b27c2d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 09:16:32 +0200 Subject: [PATCH 018/235] Remove node mapping --- ...ator.py => gtir_builtin_field_operator.py} | 82 +------------------ .../dace_fieldview/gtir_builtin_select.py | 73 +++++++++++++++++ .../dace_fieldview/gtir_builtin_symbol_ref.py | 60 ++++++++++++++ .../dace_fieldview/gtir_dataflow_builder.py | 38 +++------ .../dace_fieldview/gtir_tasklet_codegen.py | 11 ++- .../runners/dace_fieldview/gtir_to_sdfg.py | 12 ++- 6 files changed, 163 insertions(+), 113 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtin_translator.py => gtir_builtin_field_operator.py} (57%) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py similarity index 57% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index 0bf9cf9e9a..2987092933 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -22,37 +22,9 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( GtirTaskletCodegen, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinScalarAccess(GtirTaskletCodegen): - _sym_name: str - _data_type: ts.ScalarType - - def __init__( - self, sdfg: dace.SDFG, state: dace.SDFGState, sym_name: str, data_type: ts.ScalarType - ): - super().__init__(sdfg, state) - self._sym_name = sym_name - self._data_type = data_type - - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - tasklet_node = self._state.add_tasklet( - f"get_{self._sym_name}", - {}, - {"__out"}, - f"__out = {self._sym_name}", - ) - name = f"{self._state.label}_var" - dtype = as_dace_type(self._data_type) - output_node = self._state.add_scalar(name, dtype, find_new_name=True, transient=True) - self._state.add_edge( - tasklet_node, "__out", output_node, None, dace.Memlet(data=output_node.data, subset="0") - ) - return [(output_node, self._data_type)] - - class GtirBuiltinAsFieldOp(GtirTaskletCodegen): _stencil: itir.Lambda _domain: dict[Dimension, tuple[str, str]] @@ -74,7 +46,7 @@ def __init__( self._domain = {dim: (lb, ub) for dim, lb, ub in domain} self._field_type = ts.FieldType([dim for dim, _, _ in domain], field_dtype) - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # generate the python code for this stencil output_connector = "__out" tlet_code = "{var} = {code}".format( @@ -137,55 +109,3 @@ def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]] ) return [(field_node, self._field_type)] - - -class GtirBuiltinSelect(GtirTaskletCodegen): - _true_br_builder: GtirTaskletCodegen - _false_br_builder: GtirTaskletCodegen - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - true_br_builder: GtirTaskletCodegen, - false_br_builder: GtirTaskletCodegen, - ): - super().__init__(sdfg, state) - self._true_br_builder = true_br_builder - self._false_br_builder = false_br_builder - - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - true_br_args = self._true_br_builder() - false_br_args = self._false_br_builder() - assert len(true_br_args) == len(false_br_args) - - output_nodes = [] - for true_br, false_br in zip(true_br_args, false_br_args): - true_br_node, true_br_type = true_br - assert isinstance(true_br_node, dace.nodes.AccessNode) - false_br_node, false_br_type = false_br - assert isinstance(false_br_node, dace.nodes.AccessNode) - assert true_br_type == false_br_type - array_type = self._sdfg.arrays[true_br_node.data] - access_node = self._add_local_storage(true_br_type, array_type.shape) - output_nodes.append((access_node, true_br_type)) - - data_name = access_node.data - true_br_output_node = self._true_br_builder._state.add_access(data_name) - self._true_br_builder._state.add_nedge( - true_br_node, - true_br_output_node, - dace.Memlet.from_array( - true_br_output_node.data, true_br_output_node.desc(self._sdfg) - ), - ) - - false_br_output_node = self._false_br_builder._state.add_access(data_name) - self._false_br_builder._state.add_nedge( - false_br_node, - false_br_output_node, - dace.Memlet.from_array( - false_br_output_node.data, false_br_output_node.desc(self._sdfg) - ), - ) - return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py new file mode 100644 index 0000000000..51348fea40 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import dace + +from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( + GtirTaskletCodegen, +) +from gt4py.next.type_system import type_specifications as ts + + +class GtirBuiltinSelect(GtirTaskletCodegen): + _true_br_builder: GtirTaskletCodegen + _false_br_builder: GtirTaskletCodegen + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + true_br_builder: GtirTaskletCodegen, + false_br_builder: GtirTaskletCodegen, + ): + super().__init__(sdfg, state) + self._true_br_builder = true_br_builder + self._false_br_builder = false_br_builder + + def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + true_br_args = self._true_br_builder() + false_br_args = self._false_br_builder() + assert len(true_br_args) == len(false_br_args) + + output_nodes = [] + for true_br, false_br in zip(true_br_args, false_br_args): + true_br_node, true_br_type = true_br + assert isinstance(true_br_node, dace.nodes.AccessNode) + false_br_node, false_br_type = false_br + assert isinstance(false_br_node, dace.nodes.AccessNode) + assert true_br_type == false_br_type + array_type = self._sdfg.arrays[true_br_node.data] + access_node = self._add_local_storage(true_br_type, array_type.shape) + output_nodes.append((access_node, true_br_type)) + + data_name = access_node.data + true_br_output_node = self._true_br_builder._state.add_access(data_name) + self._true_br_builder._state.add_nedge( + true_br_node, + true_br_output_node, + dace.Memlet.from_array( + true_br_output_node.data, true_br_output_node.desc(self._sdfg) + ), + ) + + false_br_output_node = self._false_br_builder._state.add_access(data_name) + self._false_br_builder._state.add_nedge( + false_br_node, + false_br_output_node, + dace.Memlet.from_array( + false_br_output_node.data, false_br_output_node.desc(self._sdfg) + ), + ) + return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py new file mode 100644 index 0000000000..ff1167ae14 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py @@ -0,0 +1,60 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import dace + +from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( + GtirTaskletCodegen, +) +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type +from gt4py.next.type_system import type_specifications as ts + + +class GtirBuiltinSymbolRef(GtirTaskletCodegen): + _sym_name: str + _sym_type: ts.FieldType | ts.ScalarType + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + sym_name: str, + data_type: ts.FieldType | ts.ScalarType, + ): + super().__init__(sdfg, state) + self._sym_name = sym_name + self._sym_type = data_type + + def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + if isinstance(self._sym_type, ts.FieldType): + sym_node = self._state.add_access(self._sym_name) + + else: + # scalar symbols are passed to the SDFG as symbols + assert self._sym_name in self._sdfg.symbols + tasklet_node = self._state.add_tasklet( + f"get_{self._sym_name}", + {}, + {"__out"}, + f"__out = {self._sym_name}", + ) + name = f"{self._state.label}_var" + dtype = as_dace_type(self._sym_type) + sym_node = self._state.add_scalar(name, dtype, find_new_name=True, transient=True) + self._state.add_edge( + tasklet_node, "__out", sym_node, None, dace.Memlet(data=sym_node.data, subset="0") + ) + + return [(sym_node, self._sym_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 7f18de1885..674a091dfa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -21,11 +21,15 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translator import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_field_operator import ( GtirBuiltinAsFieldOp as AsFieldOp, - GtirBuiltinScalarAccess as ScalarAccess, +) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_select import ( GtirBuiltinSelect as Select, ) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_symbol_ref import ( + GtirBuiltinSymbolRef as SymbolRef, +) from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( GtirTaskletCodegen, ) @@ -39,7 +43,6 @@ class GtirDataflowBuilder(eve.NodeVisitor): _sdfg: dace.SDFG _head_state: dace.SDFGState _data_types: dict[str, ts.FieldType | ts.ScalarType] - _node_mapping: dict[str, dace.nodes.AccessNode] def __init__( self, @@ -50,7 +53,6 @@ def __init__( self._sdfg = sdfg self._head_state = state self._data_types = data_types - self._node_mapping = {} def _add_local_storage( self, type_: ts.DataType, shape: list[str] @@ -65,20 +67,11 @@ def _add_local_storage( ) else: assert isinstance(type_, ts.ScalarType) - dtype = as_dace_type(type_) assert len(shape) == 0 + dtype = as_dace_type(type_) name, data = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) return name, data - def _add_access_node(self, data: str) -> dace.nodes.AccessNode: - assert data in self._sdfg.arrays - if data in self._node_mapping: - node = self._node_mapping[data] - else: - node = self._head_state.add_access(data) - self._node_mapping[data] = node - return node - def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: domain = [] assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) @@ -156,15 +149,12 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: self._sdfg.add_edge(_false_state, _join_state, dace.InterstateEdge()) self._head_state = _true_state - self._node_mapping = {} true_br_callable = self.visit(fun_node.args[1]) self._head_state = _false_state - self._node_mapping = {} false_br_callable = self.visit(fun_node.args[2]) self._head_state = _join_state - self._node_mapping = {} return Select( sdfg=self._sdfg, @@ -177,13 +167,7 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") def visit_SymRef(self, node: itir.SymRef) -> Callable: - name = str(node.id) - assert name in self._data_types - data_type = self._data_types[name] - if isinstance(data_type, ts.FieldType): - access_node = self._add_access_node(name) - return lambda: [(access_node, data_type)] - else: - # scalar symbols are passed to the SDFG as symbols - assert name in self._sdfg.symbols - return ScalarAccess(self._sdfg, self._head_state, name, data_type) + arg_name = str(node.id) + assert arg_name in self._data_types + arg_type = self._data_types[arg_name] + return SymbolRef(self._sdfg, self._head_state, arg_name, arg_type) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index b0c17d8c25..46c6bdb46b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -95,18 +95,24 @@ class ValueExpr: class GtirTaskletCodegen(codegen.TemplatedGenerator): _sdfg: dace.SDFG _state: dace.SDFGState + _nodes: list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]] def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: self._sdfg = sdfg self._state = state + self._nodes = [] + @final def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: """ "Creates the dataflow representing the given GTIR builtin. Returns a list of connections, where each connectio is defined as: tuple(node, connector_name) """ - raise NotImplementedError + if not self._nodes: + self._nodes = self._build() + assert self._nodes + return self._nodes @final def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.nodes.AccessNode: @@ -122,6 +128,9 @@ def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.n name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) return self._state.add_access(name) + def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + raise NotImplementedError + def _visit_deref(self, node: itir.FunCall) -> str: assert len(node.args) == 1 if isinstance(node.args[0], itir.SymRef): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 72ce442553..e860837f33 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -177,11 +177,15 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) target_array = sdfg.arrays[target_node.data] assert not target_array.transient target_field_type = self._data_types[target_node.data] - assert isinstance(target_field_type, ts.FieldType) - subset = ",".join( - f"{domain_map[dim][0]}:{domain_map[dim][1]}" for dim in target_field_type.dims - ) + if isinstance(target_field_type, ts.FieldType): + subset = ",".join( + f"{domain_map[dim][0]}:{domain_map[dim][1]}" for dim in target_field_type.dims + ) + else: + assert len(domain) == 0 + subset = "0" + dataflow_builder._head_state.add_nedge( expr_node, target_node, From cd900f589edf0940ab048bcc729516811092b5c3 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 10:00:00 +0200 Subject: [PATCH 019/235] Remove node mapping (fix + test case) --- .../dace_fieldview/gtir_builtin_symbol_ref.py | 20 +++++++++- .../dace_fieldview/gtir_tasklet_codegen.py | 7 +--- .../runners_tests/test_dace_fieldview.py | 37 +++++++++++++++++++ 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py index ff1167ae14..e1ac0cca3b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py @@ -13,6 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Optional + import dace from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( @@ -37,8 +39,24 @@ def __init__( self._sym_name = sym_name self._sym_type = data_type + def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: + access_nodes = [ + node + for node in self._state.nodes() + if isinstance(node, dace.nodes.AccessNode) and node.data == self._sym_name + ] + if len(access_nodes) == 0: + return None + assert len(access_nodes) == 1 + return access_nodes[0] + def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - if isinstance(self._sym_type, ts.FieldType): + sym_node = self._get_access_node() + if sym_node: + # share access node in same state + pass + + elif isinstance(self._sym_type, ts.FieldType): sym_node = self._state.add_access(self._sym_name) else: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 46c6bdb46b..700a2c3382 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -95,12 +95,10 @@ class ValueExpr: class GtirTaskletCodegen(codegen.TemplatedGenerator): _sdfg: dace.SDFG _state: dace.SDFGState - _nodes: list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]] def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: self._sdfg = sdfg self._state = state - self._nodes = [] @final def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: @@ -109,10 +107,7 @@ def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]] Returns a list of connections, where each connectio is defined as: tuple(node, connector_name) """ - if not self._nodes: - self._nodes = self._build() - assert self._nodes - return self._nodes + return self._build() @final def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.nodes.AccessNode: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index cf78dd3f3f..7cc37102c9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -88,6 +88,43 @@ def test_gtir_sum2(): assert np.allclose(c, (a + b)) +def test_gtir_sum2_sym(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + ) + testee = itir.Program( + id="sum_2fields", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="z"), itir.Sym(id="size")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )("x", "x"), + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.empty_like(a) + + sdfg_genenerator = FieldviewGtirToSDFG( + [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + sdfg(x=a, z=b, **FSYMBOLS) + assert np.allclose(b, (a + a)) + + def test_gtir_sum3(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") From 326cbb5ff878eef186017e66dabcfa3d11d302c0 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 10:04:22 +0200 Subject: [PATCH 020/235] Add test case for inlined mathematic builtins --- .../runners_tests/test_dace_fieldview.py | 37 ++++++++++++++++--- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 7cc37102c9..ebe5366ccb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -129,7 +129,7 @@ def test_gtir_sum3(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") ) - testee = itir.Program( + testee_fieldview = itir.Program( id="sum_3fields", function_definitions=[], params=[ @@ -161,6 +161,32 @@ def test_gtir_sum3(): ) ], ) + testee_inlined = itir.Program( + id="sum_3fields", + function_definitions=[], + params=[ + itir.Sym(id="x"), + itir.Sym(id="y"), + itir.Sym(id="w"), + itir.Sym(id="z"), + itir.Sym(id="size"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a", "b", "c")( + im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c"))) + ), + domain, + ) + )("x", "y", "w"), + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) a = np.random.rand(N) b = np.random.rand(N) @@ -170,12 +196,13 @@ def test_gtir_sum3(): sdfg_genenerator = FieldviewGtirToSDFG( [FTYPE, FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + for testee in [testee_fieldview, testee_inlined]: + sdfg = sdfg_genenerator.visit(testee) + assert isinstance(sdfg, dace.SDFG) - sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS) - assert np.allclose(d, (a + b + c)) + sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS) + assert np.allclose(d, (a + b + c)) def test_gtir_select(): From a10b614390a934f9528881813cddf9adc3694b55 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 11:07:44 +0200 Subject: [PATCH 021/235] Go full functional (remove SDFGState member var) --- .../gtir_builtin_field_operator.py | 2 +- .../dace_fieldview/gtir_builtin_select.py | 4 +- .../dace_fieldview/gtir_dataflow_builder.py | 64 +++++++++---------- .../dace_fieldview/gtir_tasklet_codegen.py | 6 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 9 +-- 5 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index 2987092933..27d01c0800 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -69,7 +69,7 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: input_memlets: dict[str, dace.Memlet] = {} assert len(self._args) == len(self._stencil.params) for arg, param in zip(self._args, self._stencil.params): - arg_nodes = arg() + arg_nodes, _ = arg() assert len(arg_nodes) == 1 arg_node, arg_type = arg_nodes[0] connector = str(param.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py index 51348fea40..b9d46ec856 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py @@ -37,8 +37,8 @@ def __init__( self._false_br_builder = false_br_builder def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - true_br_args = self._true_br_builder() - false_br_args = self._false_br_builder() + true_br_args, _ = self._true_br_builder() + false_br_args, _ = self._false_br_builder() assert len(true_br_args) == len(false_br_args) output_nodes = [] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 674a091dfa..fbee4a5da9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -37,27 +37,31 @@ from gt4py.next.type_system import type_specifications as ts +def unique_name(prefix: str) -> str: + unique_id = getattr(unique_name, "_unique_id", 0) # static variable + setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] + + return f"{prefix}_{unique_id}" + + class GtirDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" _sdfg: dace.SDFG - _head_state: dace.SDFGState _data_types: dict[str, ts.FieldType | ts.ScalarType] def __init__( self, sdfg: dace.SDFG, - state: dace.SDFGState, data_types: dict[str, ts.FieldType | ts.ScalarType], ): self._sdfg = sdfg - self._head_state = state self._data_types = data_types def _add_local_storage( self, type_: ts.DataType, shape: list[str] ) -> tuple[str, dace.data.Data]: - name = f"{self._head_state.label}_var" + name = unique_name("var") if isinstance(type_, ts.FieldType): dtype = as_dace_type(type_.dtype) assert len(type_.dims) == len(shape) @@ -86,22 +90,26 @@ def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: return domain - def visit_expression(self, node: itir.Expr) -> list[dace.nodes.AccessNode]: - expr_builder = self.visit(node) + def visit_expression( + self, node: itir.Expr, head_state: dace.SDFGState + ) -> tuple[list[dace.nodes.AccessNode], dace.SDFGState]: + expr_builder = self.visit(node, state=head_state) assert callable(expr_builder) - results = expr_builder() + results, head_state = expr_builder() + expressions_nodes = [] for node, _ in results: assert isinstance(node, dace.nodes.AccessNode) expressions_nodes.append(node) - return expressions_nodes + return expressions_nodes, head_state def visit_symbolic(self, node: itir.Expr) -> str: - codegen = GtirTaskletCodegen(self._sdfg, self._head_state) + state = self._sdfg.start_state + codegen = GtirTaskletCodegen(self._sdfg, state) return codegen.visit(node) - def visit_FunCall(self, node: itir.FunCall) -> Callable: + def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: if cpm.is_call_to(node.fun, "as_fieldop"): fun_node = node.fun assert len(fun_node.args) == 2 @@ -111,7 +119,7 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: assert isinstance(fun_node.args[1], itir.FunCall) field_domain = self.visit_domain(fun_node.args[1]) - stencil_args = [self.visit(arg) for arg in node.args] + stencil_args = [self.visit(arg, state=state) for arg in node.args] # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type @@ -119,7 +127,7 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: return AsFieldOp( sdfg=self._sdfg, - state=self._head_state, + state=state, stencil=fun_node.args[0], domain=field_domain, args=stencil_args, @@ -134,31 +142,23 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: cond = self.visit_symbolic(fun_node.args[0]) # use join state to terminate the dataflow on a single exit node - _join_state = self._sdfg.add_state(self._head_state.label + "_join") + join_state = self._sdfg.add_state(state.label + "_join") # expect true branch as second argument - _true_state = self._sdfg.add_state(self._head_state.label + "_true_branch") - self._sdfg.add_edge(self._head_state, _true_state, dace.InterstateEdge(condition=cond)) - self._sdfg.add_edge(_true_state, _join_state, dace.InterstateEdge()) + true_state = self._sdfg.add_state(state.label + "_true_branch") + self._sdfg.add_edge(state, true_state, dace.InterstateEdge(condition=cond)) + self._sdfg.add_edge(true_state, join_state, dace.InterstateEdge()) + true_br_callable = self.visit(fun_node.args[1], state=true_state) # and false branch as third argument - _false_state = self._sdfg.add_state(self._head_state.label + "_false_branch") - self._sdfg.add_edge( - self._head_state, _false_state, dace.InterstateEdge(condition=f"not {cond}") - ) - self._sdfg.add_edge(_false_state, _join_state, dace.InterstateEdge()) - - self._head_state = _true_state - true_br_callable = self.visit(fun_node.args[1]) - - self._head_state = _false_state - false_br_callable = self.visit(fun_node.args[2]) - - self._head_state = _join_state + false_state = self._sdfg.add_state(state.label + "_false_branch") + self._sdfg.add_edge(state, false_state, dace.InterstateEdge(condition=f"not {cond}")) + self._sdfg.add_edge(false_state, join_state, dace.InterstateEdge()) + false_br_callable = self.visit(fun_node.args[2], state=false_state) return Select( sdfg=self._sdfg, - state=self._head_state, + state=join_state, true_br_builder=true_br_callable, false_br_builder=false_br_callable, ) @@ -166,8 +166,8 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - def visit_SymRef(self, node: itir.SymRef) -> Callable: + def visit_SymRef(self, node: itir.SymRef, state: dace.SDFGState) -> Callable: arg_name = str(node.id) assert arg_name in self._data_types arg_type = self._data_types[arg_name] - return SymbolRef(self._sdfg, self._head_state, arg_name, arg_type) + return SymbolRef(self._sdfg, state, arg_name, arg_type) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 700a2c3382..77d649cea4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -101,13 +101,15 @@ def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: self._state = state @final - def __call__(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def __call__( + self, + ) -> tuple[list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]], dace.SDFGState]: """ "Creates the dataflow representing the given GTIR builtin. Returns a list of connections, where each connectio is defined as: tuple(node, connector_name) """ - return self._build() + return self._build(), self._state @final def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.nodes.AccessNode: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index e860837f33..0969e94bbd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -161,12 +161,13 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) The translation of `SetAt` ensures that the result is written to the external storage. """ - dataflow_builder = DataflowBuilder(sdfg, state, self._data_types) - expr_nodes = dataflow_builder.visit_expression(stmt.expr) + dataflow_builder = DataflowBuilder(sdfg, self._data_types) + expr_nodes, state = dataflow_builder.visit_expression(stmt.expr, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - target_nodes = dataflow_builder.visit_expression(stmt.target) + target_builder = DataflowBuilder(sdfg, self._data_types) + target_nodes, state = target_builder.visit_expression(stmt.target, state) assert len(expr_nodes) == len(target_nodes) domain = dataflow_builder.visit_domain(stmt.domain) @@ -186,7 +187,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) assert len(domain) == 0 subset = "0" - dataflow_builder._head_state.add_nedge( + state.add_nedge( expr_node, target_node, dace.Memlet(data=target_node.data, subset=subset), From aef42653c44153e5a4d2792e29555b0e59aa59ed Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 11:35:01 +0200 Subject: [PATCH 022/235] Minor edit --- .../dace_fieldview/gtir_dataflow_builder.py | 26 ------------------- .../dace_fieldview/gtir_tasklet_codegen.py | 9 ++++--- .../runners/dace_fieldview/utility.py | 7 +++++ 3 files changed, 12 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index fbee4a5da9..4f97615768 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -33,17 +33,9 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( GtirTaskletCodegen, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts -def unique_name(prefix: str) -> str: - unique_id = getattr(unique_name, "_unique_id", 0) # static variable - setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] - - return f"{prefix}_{unique_id}" - - class GtirDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" @@ -58,24 +50,6 @@ def __init__( self._sdfg = sdfg self._data_types = data_types - def _add_local_storage( - self, type_: ts.DataType, shape: list[str] - ) -> tuple[str, dace.data.Data]: - name = unique_name("var") - if isinstance(type_, ts.FieldType): - dtype = as_dace_type(type_.dtype) - assert len(type_.dims) == len(shape) - # TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used - name, data = self._sdfg.add_array( - name, shape, dtype, find_new_name=True, transient=True - ) - else: - assert isinstance(type_, ts.ScalarType) - assert len(shape) == 0 - dtype = as_dace_type(type_) - name, data = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) - return name, data - def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: domain = [] assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 77d649cea4..7e03bb9cd0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -21,7 +21,7 @@ from gt4py.eve import codegen from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name from gt4py.next.type_system import type_specifications as ts @@ -112,14 +112,15 @@ def __call__( return self._build(), self._state @final - def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.nodes.AccessNode: - name = f"{self._state.label}_var" + def _add_local_storage( + self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] + ) -> dace.nodes.AccessNode: + name = unique_name("var") if isinstance(data_type, ts.FieldType): assert len(data_type.dims) == len(shape) dtype = as_dace_type(data_type.dtype) name, _ = self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) else: - assert isinstance(data_type, ts.ScalarType) assert len(shape) == 0 dtype = as_dace_type(data_type) name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index b25521f5c3..2fcc12bc85 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -40,3 +40,10 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne for offset, table in offset_provider.items() if isinstance(table, Connectivity) } + + +def unique_name(prefix: str) -> str: + unique_id = getattr(unique_name, "_unique_id", 0) # static variable + setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] + + return f"{prefix}_{unique_id}" From 9e67dfe3aea0c9a930aa91c440cf51efa20cc222 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 11:49:38 +0200 Subject: [PATCH 023/235] Minor edit (1) --- .../gtir_builtin_field_operator.py | 2 +- .../dace_fieldview/gtir_dataflow_builder.py | 10 ++-------- .../dace_fieldview/gtir_tasklet_codegen.py | 19 ++----------------- 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index 27d01c0800..c58cab4a18 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -27,8 +27,8 @@ class GtirBuiltinAsFieldOp(GtirTaskletCodegen): _stencil: itir.Lambda - _domain: dict[Dimension, tuple[str, str]] _args: Sequence[GtirTaskletCodegen] + _domain: dict[Dimension, tuple[str, str]] _field_type: ts.FieldType def __init__( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 4f97615768..ed5e2c5540 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from dataclasses import dataclass from typing import Callable, Sequence import dace @@ -36,20 +37,13 @@ from gt4py.next.type_system import type_specifications as ts +@dataclass(frozen=True) class GtirDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" _sdfg: dace.SDFG _data_types: dict[str, ts.FieldType | ts.ScalarType] - def __init__( - self, - sdfg: dace.SDFG, - data_types: dict[str, ts.FieldType | ts.ScalarType], - ): - self._sdfg = sdfg - self._data_types = data_types - def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: domain = [] assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 7e03bb9cd0..73f910a39c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import dataclasses +from dataclasses import dataclass from typing import Any, final import dace @@ -80,26 +80,11 @@ } -@dataclasses.dataclass -class SymbolExpr: - value: dace.symbolic.SymbolicType - dtype: dace.typeclass - - -@dataclasses.dataclass -class ValueExpr: - value: dace.nodes.AccessNode - dtype: dace.typeclass - - +@dataclass(frozen=True) class GtirTaskletCodegen(codegen.TemplatedGenerator): _sdfg: dace.SDFG _state: dace.SDFGState - def __init__(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: - self._sdfg = sdfg - self._state = state - @final def __call__( self, From 4b4109e74450840bb5935383e38f47a3ca71ff03 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 16:18:45 +0200 Subject: [PATCH 024/235] Fix state handling --- .../gtir_builtin_field_operator.py | 2 +- .../dace_fieldview/gtir_builtin_select.py | 4 +- .../dace_fieldview/gtir_dataflow_builder.py | 21 +++++----- .../dace_fieldview/gtir_tasklet_codegen.py | 4 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 6 +-- .../runners_tests/test_dace_fieldview.py | 40 +++++++++++-------- 6 files changed, 44 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index c58cab4a18..2d4ac20ade 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -69,7 +69,7 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: input_memlets: dict[str, dace.Memlet] = {} assert len(self._args) == len(self._stencil.params) for arg, param in zip(self._args, self._stencil.params): - arg_nodes, _ = arg() + arg_nodes = arg() assert len(arg_nodes) == 1 arg_node, arg_type = arg_nodes[0] connector = str(param.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py index b9d46ec856..51348fea40 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py @@ -37,8 +37,8 @@ def __init__( self._false_br_builder = false_br_builder def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - true_br_args, _ = self._true_br_builder() - false_br_args, _ = self._false_br_builder() + true_br_args = self._true_br_builder() + false_br_args = self._false_br_builder() assert len(true_br_args) == len(false_br_args) output_nodes = [] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index ed5e2c5540..a1a35cd970 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -60,17 +60,17 @@ def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: def visit_expression( self, node: itir.Expr, head_state: dace.SDFGState - ) -> tuple[list[dace.nodes.AccessNode], dace.SDFGState]: + ) -> list[dace.nodes.AccessNode]: expr_builder = self.visit(node, state=head_state) assert callable(expr_builder) - results, head_state = expr_builder() + results = expr_builder() expressions_nodes = [] for node, _ in results: assert isinstance(node, dace.nodes.AccessNode) expressions_nodes.append(node) - return expressions_nodes, head_state + return expressions_nodes def visit_symbolic(self, node: itir.Expr) -> str: state = self._sdfg.start_state @@ -110,23 +110,26 @@ def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: cond = self.visit_symbolic(fun_node.args[0]) # use join state to terminate the dataflow on a single exit node - join_state = self._sdfg.add_state(state.label + "_join") + select_state = self._sdfg.add_state_before(state, state.label + "_select") + self._sdfg.remove_edge(self._sdfg.out_edges(select_state)[0]) # expect true branch as second argument true_state = self._sdfg.add_state(state.label + "_true_branch") - self._sdfg.add_edge(state, true_state, dace.InterstateEdge(condition=cond)) - self._sdfg.add_edge(true_state, join_state, dace.InterstateEdge()) + self._sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) + self._sdfg.add_edge(true_state, state, dace.InterstateEdge()) true_br_callable = self.visit(fun_node.args[1], state=true_state) # and false branch as third argument false_state = self._sdfg.add_state(state.label + "_false_branch") - self._sdfg.add_edge(state, false_state, dace.InterstateEdge(condition=f"not {cond}")) - self._sdfg.add_edge(false_state, join_state, dace.InterstateEdge()) + self._sdfg.add_edge( + select_state, false_state, dace.InterstateEdge(condition=f"not {cond}") + ) + self._sdfg.add_edge(false_state, state, dace.InterstateEdge()) false_br_callable = self.visit(fun_node.args[2], state=false_state) return Select( sdfg=self._sdfg, - state=join_state, + state=state, true_br_builder=true_br_callable, false_br_builder=false_br_callable, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 73f910a39c..37f7dd17b6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -88,13 +88,13 @@ class GtirTaskletCodegen(codegen.TemplatedGenerator): @final def __call__( self, - ) -> tuple[list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]], dace.SDFGState]: + ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: """ "Creates the dataflow representing the given GTIR builtin. Returns a list of connections, where each connectio is defined as: tuple(node, connector_name) """ - return self._build(), self._state + return self._build() @final def _add_local_storage( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 0969e94bbd..b2448e8b14 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -148,7 +148,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: self.visit(stmt, sdfg=sdfg, state=head_state) # sanity check below: each statement should have a single exit state -- aka no branches sink_states = sdfg.sink_nodes() - assert len(sink_states) == 1 + assert sink_states == [head_state] head_state = sink_states[0] sdfg.validate() @@ -162,12 +162,12 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) """ dataflow_builder = DataflowBuilder(sdfg, self._data_types) - expr_nodes, state = dataflow_builder.visit_expression(stmt.expr, state) + expr_nodes = dataflow_builder.visit_expression(stmt.expr, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field target_builder = DataflowBuilder(sdfg, self._data_types) - target_nodes, state = target_builder.visit_expression(stmt.target, state) + target_nodes = target_builder.visit_expression(stmt.target, state) assert len(expr_nodes) == len(target_nodes) domain = dataflow_builder.visit_domain(stmt.domain) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index ebe5366ccb..c740416974 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -224,22 +224,30 @@ def test_gtir_select(): body=[ itir.SetAt( expr=im.call( - im.call("select")( - im.deref("cond"), - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )("x", "y"), - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )("y", "w"), + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, ) - )(), + )( + "x", + im.call( + im.call("select")( + im.deref("cond"), + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1)), + domain, + ) + )("y"), + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1)), + domain, + ) + )("w"), + ) + )(), + ), domain=domain, target=itir.SymRef(id="z"), ) @@ -268,4 +276,4 @@ def test_gtir_select(): for s in [False, True]: sdfg(cond=s, x=a, y=b, w=c, z=d, **FSYMBOLS) - assert np.allclose(d, (a + b) if s else (b + c)) + assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) From 495fd0aefbf3efb5e24e3f547bb8448c29a46de2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 17:20:00 +0200 Subject: [PATCH 025/235] Edit comments based on review --- .../gtir_builtin_field_operator.py | 4 +++- .../dace_fieldview/gtir_dataflow_builder.py | 22 +++++++++---------- .../dace_fieldview/gtir_tasklet_codegen.py | 11 +++++++--- .../runners/dace_fieldview/gtir_to_sdfg.py | 17 ++++++-------- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index 2d4ac20ade..352eeda43e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -47,7 +47,8 @@ def __init__( self._field_type = ts.FieldType([dim for dim, _, _ in domain], field_dtype) def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - # generate the python code for this stencil + # generate a tasklet implementing the stencil and represent the field operator + # as a mapped tasklet, which will range over the field domain. output_connector = "__out" tlet_code = "{var} = {code}".format( var=output_connector, code=self.visit(self._stencil.expr) @@ -86,6 +87,7 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ) else: memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self._sdfg)) + # TODO: assume for now that all stencils (aka tasklets) perform single element access memlet.volume = 1 input_memlets[connector] = memlet else: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index a1a35cd970..2e9cc707e8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -79,14 +79,14 @@ def visit_symbolic(self, node: itir.Expr) -> str: def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: if cpm.is_call_to(node.fun, "as_fieldop"): - fun_node = node.fun - assert len(fun_node.args) == 2 + assert len(node.fun.args) == 2 + stencil_expr, domain_expr = node.fun.args # expect stencil (represented as a lambda function) as first argument - assert isinstance(fun_node.args[0], itir.Lambda) + assert isinstance(stencil_expr, itir.Lambda) # the domain of the field operator is passed as second argument - assert isinstance(fun_node.args[1], itir.FunCall) - field_domain = self.visit_domain(fun_node.args[1]) + assert isinstance(domain_expr, itir.FunCall) + field_domain = self.visit_domain(domain_expr) stencil_args = [self.visit(arg, state=state) for arg in node.args] # add local storage to compute the field operator over the given domain @@ -96,18 +96,18 @@ def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: return AsFieldOp( sdfg=self._sdfg, state=state, - stencil=fun_node.args[0], + stencil=stencil_expr, domain=field_domain, args=stencil_args, field_dtype=node_type, ) elif cpm.is_call_to(node.fun, "select"): - fun_node = node.fun - assert len(fun_node.args) == 3 + assert len(node.fun.args) == 3 + cond_expr, true_expr, false_expr = node.fun.args # expect condition as first argument - cond = self.visit_symbolic(fun_node.args[0]) + cond = self.visit_symbolic(cond_expr) # use join state to terminate the dataflow on a single exit node select_state = self._sdfg.add_state_before(state, state.label + "_select") @@ -117,7 +117,7 @@ def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: true_state = self._sdfg.add_state(state.label + "_true_branch") self._sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) self._sdfg.add_edge(true_state, state, dace.InterstateEdge()) - true_br_callable = self.visit(fun_node.args[1], state=true_state) + true_br_callable = self.visit(true_expr, state=true_state) # and false branch as third argument false_state = self._sdfg.add_state(state.label + "_false_branch") @@ -125,7 +125,7 @@ def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: select_state, false_state, dace.InterstateEdge(condition=f"not {cond}") ) self._sdfg.add_edge(false_state, state, dace.InterstateEdge()) - false_br_callable = self.visit(fun_node.args[2], state=false_state) + false_br_callable = self.visit(false_expr, state=false_state) return Select( sdfg=self._sdfg, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 37f7dd17b6..67acf6d176 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -89,10 +89,13 @@ class GtirTaskletCodegen(codegen.TemplatedGenerator): def __call__( self, ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - """ "Creates the dataflow representing the given GTIR builtin. + """Creates the dataflow representing the given GTIR builtin. - Returns a list of connections, where each connectio is defined as: - tuple(node, connector_name) + Returns a list of SDFG nodes and the associated GT4Py data types: + tuple(node, data_type) + + The GT4Py data type is useful in the case of fields, because it provides + information on the field domain (e.g. order of dimensions, types of dimensions). """ return self._build() @@ -140,6 +143,8 @@ def visit_FunCall(self, node: itir.FunCall) -> str: @final def visit_Lambda(self, node: itir.Lambda) -> Any: # This visitor class should never encounter `itir.Lambda` expressions + # because a lambda represents a stencil, which translates from iterator to value. + # In fieldview, lambdas should only be arguments to field operators (`as_field_op`). raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") @final diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index b2448e8b14..e688ec6300 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -35,15 +35,12 @@ class GtirToSDFG(eve.NodeVisitor): This class is responsible for translation of `ir.Program`, that is the top level representation of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. - Each statement is translated to a taskgraph inside a separate state. The parent SDFG and - the translation state define the statement context, implemented by `FieldviewRegion`. - Statement states are chained one after the other: potential concurrency between states should be - extracted by the DaCe SDFG transformations. - The program translation keeps track of entry and exit states: each statement is supposed to extend - the SDFG but maintain the property of single exit state (that is no branching on leaf nodes). - Branching is allowed within the context of one statement, but in that case the statement should + Each statement is translated to a taskgraph inside a separate state. Statement states are chained + one after the other: concurrency between states should be extracted by the DaCe SDFG transformations. + The translator will extend the SDFG while preserving the property of single exit state: + branching is allowed within the context of one statement, but in that case the statement should terminate with a join state; the join state will represent the head state for next statement, - that is where to continue building the SDFG. + from where to continue building the SDFG. """ _param_types: list[ts.TypeSpec] @@ -146,7 +143,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: for i, stmt in enumerate(node.body): head_state = sdfg.add_state_after(head_state, f"stmt_{i}") self.visit(stmt, sdfg=sdfg, state=head_state) - # sanity check below: each statement should have a single exit state -- aka no branches + # sanity check: each statement should result in a single exit state, i.e. only internal branches sink_states = sdfg.sink_nodes() assert sink_states == [head_state] head_state = sink_states[0] @@ -155,7 +152,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: return sdfg def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: - """Visits a statement expression and writes the local result to some external storage. + """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of taskgraph writing to local (aka transient) storage. The translation of `SetAt` ensures that the result is written to the external storage. From 0085194db8e6292d16b5573c0d9965e71536db07 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 30 Apr 2024 17:33:47 +0200 Subject: [PATCH 026/235] Add test case for nested select --- .../runners/dace_fieldview/gtir_to_sdfg.py | 2 +- .../runners_tests/test_dace_fieldview.py | 74 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index e688ec6300..68febe5465 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -145,7 +145,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: self.visit(stmt, sdfg=sdfg, state=head_state) # sanity check: each statement should result in a single exit state, i.e. only internal branches sink_states = sdfg.sink_nodes() - assert sink_states == [head_state] + assert len(sink_states) == 1 head_state = sink_states[0] sdfg.validate() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index c740416974..f15132a608 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -277,3 +277,77 @@ def test_gtir_select(): for s in [False, True]: sdfg(cond=s, x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) + + +def test_gtir_select_nested(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + ) + testee = itir.Program( + id="select_nested", + function_definitions=[], + params=[ + itir.Sym(id="x"), + itir.Sym(id="z"), + itir.Sym(id="cond_1"), + itir.Sym(id="cond_2"), + itir.Sym(id="size"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("select")( + im.deref("cond_1"), + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1)), + domain, + ) + )("x"), + im.call( + im.call("select")( + im.deref("cond_2"), + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 2)), + domain, + ) + )("x"), + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 3)), + domain, + ) + )("x"), + ) + )(), + ) + )(), + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.empty_like(a) + + sdfg_genenerator = FieldviewGtirToSDFG( + [ + FTYPE, + FTYPE, + ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.INT32), + ], + OFFSET_PROVIDERS, + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + for s1 in [False, True]: + for s2 in [False, True]: + sdfg(cond_1=s1, cond_2=s2, x=a, z=b, **FSYMBOLS) + assert np.allclose(b, (a + 1) if s1 else (a + 2) if s2 else (a + 3)) From 41e2a448a4a761b98f989f5d69073de34b544080 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 1 May 2024 11:04:27 +0200 Subject: [PATCH 027/235] Separate builtin translation from driver logic --- .../gtir_builtin_field_operator.py | 63 ++++--- .../dace_fieldview/gtir_builtin_select.py | 43 +++-- .../dace_fieldview/gtir_builtin_symbol_ref.py | 17 +- .../runners/dace_fieldview/gtir_builtins.py | 30 ++++ .../dace_fieldview/gtir_dataflow_builder.py | 158 ++++++++++-------- .../dace_fieldview/gtir_tasklet_codegen.py | 47 +----- .../runners/dace_fieldview/gtir_to_sdfg.py | 22 +-- 7 files changed, 210 insertions(+), 170 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index 352eeda43e..e4fa6ea832 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -13,63 +13,80 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Sequence +from typing import Callable import dace from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GtirTaskletCodegen, +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( + GtirDataflowBuilder, ) from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinAsFieldOp(GtirTaskletCodegen): - _stencil: itir.Lambda - _args: Sequence[GtirTaskletCodegen] - _domain: dict[Dimension, tuple[str, str]] +class GtirBuiltinAsFieldOp(GtirDataflowBuilder): + _stencil_expr: itir.Lambda + _stencil_args: list[Callable] + _field_domain: dict[Dimension, tuple[str, str]] _field_type: ts.FieldType def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - stencil: itir.Lambda, - domain: Sequence[tuple[Dimension, str, str]], - args: Sequence[GtirTaskletCodegen], - field_dtype: ts.ScalarType, + data_types: dict[str, ts.FieldType | ts.ScalarType], + node: itir.FunCall, + stencil_args: list[Callable], ): - super().__init__(sdfg, state) - self._stencil = stencil - self._args = args - self._domain = {dim: (lb, ub) for dim, lb, ub in domain} - self._field_type = ts.FieldType([dim for dim, _, _ in domain], field_dtype) + super().__init__(sdfg, state, data_types) + + assert cpm.is_call_to(node.fun, "as_fieldop") + assert len(node.fun.args) == 2 + stencil_expr, domain_expr = node.fun.args + # expect stencil (represented as a lambda function) as first argument + assert isinstance(stencil_expr, itir.Lambda) + # the domain of the field operator is passed as second argument + assert isinstance(domain_expr, itir.FunCall) + + # visit field domain + domain = self.visit_domain(domain_expr) + + # add local storage to compute the field operator over the given domain + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + + self._field_domain = {dim: (lb, ub) for dim, lb, ub in domain} + self._field_type = ts.FieldType([dim for dim, _, _ in domain], node_type) + self._stencil_expr = stencil_expr + self._stencil_args = stencil_args def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # generate a tasklet implementing the stencil and represent the field operator # as a mapped tasklet, which will range over the field domain. output_connector = "__out" tlet_code = "{var} = {code}".format( - var=output_connector, code=self.visit(self._stencil.expr) + var=output_connector, code=self.visit_symbolic(self._stencil_expr.expr) ) # allocate local (aka transient) storage for the field field_shape = [ # diff between upper and lower bound - f"{self._domain[dim][1]} - {self._domain[dim][0]}" + f"{self._field_domain[dim][1]} - {self._field_domain[dim][0]}" for dim in self._field_type.dims ] field_node = self._add_local_storage(self._field_type, field_shape) # create map range corresponding to the field operator domain - map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self._domain.items()} + map_ranges = { + f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self._field_domain.items() + } - # visit expressions passed as arguments to this stencil input_nodes: dict[str, dace.nodes.AccessNode] = {} input_memlets: dict[str, dace.Memlet] = {} - assert len(self._args) == len(self._stencil.params) - for arg, param in zip(self._args, self._stencil.params): + assert len(self._stencil_args) == len(self._stencil_expr.params) + for arg, param in zip(self._stencil_args, self._stencil_expr.params): arg_nodes = arg() assert len(arg_nodes) == 1 arg_node, arg_type = arg_nodes[0] @@ -79,7 +96,7 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: input_nodes[arg_node.data] = arg_node if isinstance(arg_type, ts.FieldType): # support either single element access (general case) or full array shape - is_scalar = all(dim in self._domain for dim in arg_type.dims) + is_scalar = all(dim in self._field_domain for dim in arg_type.dims) if is_scalar: subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) input_memlets[connector] = dace.Memlet( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py index 51348fea40..731da66e02 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py @@ -15,26 +15,49 @@ import dace -from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GtirTaskletCodegen, +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( + GtirDataflowBuilder, ) from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinSelect(GtirTaskletCodegen): - _true_br_builder: GtirTaskletCodegen - _false_br_builder: GtirTaskletCodegen +class GtirBuiltinSelect(GtirDataflowBuilder): + _true_br_builder: GtirDataflowBuilder + _false_br_builder: GtirDataflowBuilder def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - true_br_builder: GtirTaskletCodegen, - false_br_builder: GtirTaskletCodegen, + data_types: dict[str, ts.FieldType | ts.ScalarType], + node: itir.FunCall, ): - super().__init__(sdfg, state) - self._true_br_builder = true_br_builder - self._false_br_builder = false_br_builder + super().__init__(sdfg, state, data_types) + + assert cpm.is_call_to(node.fun, "select") + assert len(node.fun.args) == 3 + cond_expr, true_expr, false_expr = node.fun.args + + # expect condition as first argument + cond = self.visit_symbolic(cond_expr) + + # use join state to terminate the dataflow on a single exit node + select_state = sdfg.add_state_before(state, state.label + "_select") + sdfg.remove_edge(sdfg.out_edges(select_state)[0]) + + # expect true branch as second argument + true_state = sdfg.add_state(state.label + "_true_branch") + sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) + sdfg.add_edge(true_state, state, dace.InterstateEdge()) + self._true_br_builder = self.fork(true_state).visit(true_expr) + + # and false branch as third argument + false_state = sdfg.add_state(state.label + "_false_branch") + sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=f"not {cond}")) + sdfg.add_edge(false_state, state, dace.InterstateEdge()) + self._false_br_builder = self.fork(false_state).visit(false_expr) def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: true_br_args = self._true_br_builder() diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py index e1ac0cca3b..f950bf2e19 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py @@ -17,14 +17,15 @@ import dace -from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GtirTaskletCodegen, +from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( + GtirDataflowBuilder, ) from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinSymbolRef(GtirTaskletCodegen): +class GtirBuiltinSymbolRef(GtirDataflowBuilder): _sym_name: str _sym_type: ts.FieldType | ts.ScalarType @@ -32,12 +33,14 @@ def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - sym_name: str, - data_type: ts.FieldType | ts.ScalarType, + data_types: dict[str, ts.FieldType | ts.ScalarType], + node: itir.SymRef, ): - super().__init__(sdfg, state) + super().__init__(sdfg, state, data_types) + sym_name = str(node.id) + assert sym_name in self._data_types self._sym_name = sym_name - self._sym_type = data_type + self._sym_type = self._data_types[sym_name] def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: access_nodes = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py new file mode 100644 index 0000000000..198dab2763 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py @@ -0,0 +1,30 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_field_operator import ( + GtirBuiltinAsFieldOp as AsFieldOp, +) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_select import ( + GtirBuiltinSelect as Select, +) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_symbol_ref import ( + GtirBuiltinSymbolRef as SymbolRef, +) + + +__all__ = [ + "AsFieldOp", + "Select", + "SymbolRef", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 2e9cc707e8..9035307e68 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -14,7 +14,7 @@ from dataclasses import dataclass -from typing import Callable, Sequence +from typing import Any, Callable, final import dace @@ -22,18 +22,10 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_field_operator import ( - GtirBuiltinAsFieldOp as AsFieldOp, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_select import ( - GtirBuiltinSelect as Select, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_symbol_ref import ( - GtirBuiltinSymbolRef as SymbolRef, -) from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( GtirTaskletCodegen, ) +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name from gt4py.next.type_system import type_specifications as ts @@ -42,9 +34,48 @@ class GtirDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" _sdfg: dace.SDFG + _state: dace.SDFGState _data_types: dict[str, ts.FieldType | ts.ScalarType] - def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: + @final + def __call__( + self, + ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + """Creates the dataflow representing the given GTIR builtin. + + Returns a list of SDFG nodes and the associated GT4Py data types: + tuple(node, data_type) + + The GT4Py data type is useful in the case of fields, because it provides + information on the field domain (e.g. order of dimensions, types of dimensions). + """ + return self._build() + + @final + def _add_local_storage( + self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] + ) -> dace.nodes.AccessNode: + name = unique_name("var") + if isinstance(data_type, ts.FieldType): + assert len(data_type.dims) == len(shape) + dtype = as_dace_type(data_type.dtype) + name, _ = self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) + else: + assert len(shape) == 0 + dtype = as_dace_type(data_type) + name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + return self._state.add_access(name) + + def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + raise NotImplementedError + + def _visit_node(self, node: itir.FunCall) -> None: + raise NotImplementedError + + def fork(self, state: dace.SDFGState) -> "GtirDataflowBuilder": + return GtirDataflowBuilder(self._sdfg, state, self._data_types) + + def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: domain = [] assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) for named_range in node.args: @@ -59,9 +90,9 @@ def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]: return domain def visit_expression( - self, node: itir.Expr, head_state: dace.SDFGState - ) -> list[dace.nodes.AccessNode]: - expr_builder = self.visit(node, state=head_state) + self, node: itir.Expr + ) -> tuple[dace.SDFGState, list[dace.nodes.AccessNode]]: + expr_builder = self.visit(node) assert callable(expr_builder) results = expr_builder() @@ -70,75 +101,54 @@ def visit_expression( assert isinstance(node, dace.nodes.AccessNode) expressions_nodes.append(node) - return expressions_nodes + # sanity check: each statement should result in a single exit state, i.e. only internal branches + sink_states = self._sdfg.sink_nodes() + assert len(sink_states) == 1 + head_state = sink_states[0] + + return head_state, expressions_nodes def visit_symbolic(self, node: itir.Expr) -> str: - state = self._sdfg.start_state - codegen = GtirTaskletCodegen(self._sdfg, state) - return codegen.visit(node) + return GtirTaskletCodegen().visit(node) + + def visit_FunCall(self, node: itir.FunCall) -> Callable: + from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins + + arg_builders: list[Callable] = [] + for arg in node.args: + arg_builder = self.visit(arg) + assert callable(arg_builder) + arg_builders.append(arg_builder) - def visit_FunCall(self, node: itir.FunCall, state: dace.SDFGState) -> Callable: if cpm.is_call_to(node.fun, "as_fieldop"): - assert len(node.fun.args) == 2 - stencil_expr, domain_expr = node.fun.args - # expect stencil (represented as a lambda function) as first argument - assert isinstance(stencil_expr, itir.Lambda) - # the domain of the field operator is passed as second argument - assert isinstance(domain_expr, itir.FunCall) - - field_domain = self.visit_domain(domain_expr) - stencil_args = [self.visit(arg, state=state) for arg in node.args] - - # add local storage to compute the field operator over the given domain - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - - return AsFieldOp( - sdfg=self._sdfg, - state=state, - stencil=stencil_expr, - domain=field_domain, - args=stencil_args, - field_dtype=node_type, + return gtir_builtins.AsFieldOp( + self._sdfg, + self._state, + self._data_types, + node, + arg_builders, ) elif cpm.is_call_to(node.fun, "select"): - assert len(node.fun.args) == 3 - cond_expr, true_expr, false_expr = node.fun.args - - # expect condition as first argument - cond = self.visit_symbolic(cond_expr) - - # use join state to terminate the dataflow on a single exit node - select_state = self._sdfg.add_state_before(state, state.label + "_select") - self._sdfg.remove_edge(self._sdfg.out_edges(select_state)[0]) - - # expect true branch as second argument - true_state = self._sdfg.add_state(state.label + "_true_branch") - self._sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) - self._sdfg.add_edge(true_state, state, dace.InterstateEdge()) - true_br_callable = self.visit(true_expr, state=true_state) - - # and false branch as third argument - false_state = self._sdfg.add_state(state.label + "_false_branch") - self._sdfg.add_edge( - select_state, false_state, dace.InterstateEdge(condition=f"not {cond}") - ) - self._sdfg.add_edge(false_state, state, dace.InterstateEdge()) - false_br_callable = self.visit(false_expr, state=false_state) - - return Select( - sdfg=self._sdfg, - state=state, - true_br_builder=true_br_callable, - false_br_builder=false_br_callable, + assert len(arg_builders) == 0 + return gtir_builtins.Select( + self._sdfg, + self._state, + self._data_types, + node, ) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - def visit_SymRef(self, node: itir.SymRef, state: dace.SDFGState) -> Callable: - arg_name = str(node.id) - assert arg_name in self._data_types - arg_type = self._data_types[arg_name] - return SymbolRef(self._sdfg, state, arg_name, arg_type) + @final + def visit_Lambda(self, node: itir.Lambda) -> Any: + # This visitor class should never encounter `itir.Lambda` expressions + # because a lambda represents a stencil, which translates from iterator to value. + # In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") + + def visit_SymRef(self, node: itir.SymRef) -> Callable: + from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins + + return gtir_builtins.SymbolRef(self._sdfg, self._state, self._data_types, node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 67acf6d176..0ac029e20b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -13,16 +13,13 @@ # SPDX-License-Identifier: GPL-3.0-or-later from dataclasses import dataclass -from typing import Any, final +from typing import final -import dace import numpy as np from gt4py.eve import codegen from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name -from gt4py.next.type_system import type_specifications as ts _MATH_BUILTINS_MAPPING = { @@ -82,40 +79,7 @@ @dataclass(frozen=True) class GtirTaskletCodegen(codegen.TemplatedGenerator): - _sdfg: dace.SDFG - _state: dace.SDFGState - - @final - def __call__( - self, - ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - """Creates the dataflow representing the given GTIR builtin. - - Returns a list of SDFG nodes and the associated GT4Py data types: - tuple(node, data_type) - - The GT4Py data type is useful in the case of fields, because it provides - information on the field domain (e.g. order of dimensions, types of dimensions). - """ - return self._build() - - @final - def _add_local_storage( - self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] - ) -> dace.nodes.AccessNode: - name = unique_name("var") - if isinstance(data_type, ts.FieldType): - assert len(data_type.dims) == len(shape) - dtype = as_dace_type(data_type.dtype) - name, _ = self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) - else: - assert len(shape) == 0 - dtype = as_dace_type(data_type) - name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) - return self._state.add_access(name) - - def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - raise NotImplementedError + """Stateless class to visit pure tasklet expressions.""" def _visit_deref(self, node: itir.FunCall) -> str: assert len(node.args) == 1 @@ -140,13 +104,6 @@ def visit_FunCall(self, node: itir.FunCall) -> str: raise NotImplementedError(f"'{builtin_name}' not implemented.") raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - @final - def visit_Lambda(self, node: itir.Lambda) -> Any: - # This visitor class should never encounter `itir.Lambda` expressions - # because a lambda represents a stencil, which translates from iterator to value. - # In fieldview, lambdas should only be arguments to field operators (`as_field_op`). - raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") - @final def visit_Literal(self, node: itir.Literal) -> str: return node.value diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 68febe5465..80c5839a82 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -142,29 +142,27 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # visit one statement at a time and put it into separate state for i, stmt in enumerate(node.body): head_state = sdfg.add_state_after(head_state, f"stmt_{i}") - self.visit(stmt, sdfg=sdfg, state=head_state) - # sanity check: each statement should result in a single exit state, i.e. only internal branches - sink_states = sdfg.sink_nodes() - assert len(sink_states) == 1 - head_state = sink_states[0] + head_state = self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() return sdfg - def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + def visit_SetAt( + self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState + ) -> dace.SDFGState: """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of taskgraph writing to local (aka transient) storage. The translation of `SetAt` ensures that the result is written to the external storage. """ - dataflow_builder = DataflowBuilder(sdfg, self._data_types) - expr_nodes = dataflow_builder.visit_expression(stmt.expr, state) + dataflow_builder = DataflowBuilder(sdfg, state, self._data_types) + head_state, expr_nodes = dataflow_builder.visit_expression(stmt.expr) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - target_builder = DataflowBuilder(sdfg, self._data_types) - target_nodes = target_builder.visit_expression(stmt.target, state) + target_builder = DataflowBuilder(sdfg, head_state, self._data_types) + head_state, target_nodes = target_builder.visit_expression(stmt.target) assert len(expr_nodes) == len(target_nodes) domain = dataflow_builder.visit_domain(stmt.domain) @@ -184,8 +182,10 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) assert len(domain) == 0 subset = "0" - state.add_nedge( + head_state.add_nedge( expr_node, target_node, dace.Memlet(data=target_node.data, subset=subset), ) + + return head_state From 7148c5f4c11e6187f47cd66640565313ff395555 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 2 May 2024 09:01:26 +0200 Subject: [PATCH 028/235] Improve code comments --- .../gtir_builtin_field_operator.py | 11 ++-- .../dace_fieldview/gtir_builtin_select.py | 18 ++++++- .../dace_fieldview/gtir_builtin_symbol_ref.py | 9 +++- .../runners/dace_fieldview/gtir_builtins.py | 1 + .../dace_fieldview/gtir_dataflow_builder.py | 51 +++++++++++++++---- .../dace_fieldview/gtir_tasklet_codegen.py | 7 ++- .../runners/dace_fieldview/gtir_to_sdfg.py | 36 ++++++++----- .../runners/dace_fieldview/utility.py | 8 +++ 8 files changed, 109 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py index e4fa6ea832..f13c03b7e4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py @@ -27,6 +27,8 @@ class GtirBuiltinAsFieldOp(GtirDataflowBuilder): + """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + _stencil_expr: itir.Lambda _stencil_args: list[Callable] _field_domain: dict[Dimension, tuple[str, str]] @@ -63,14 +65,14 @@ def __init__( self._stencil_args = stencil_args def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - # generate a tasklet implementing the stencil and represent the field operator - # as a mapped tasklet, which will range over the field domain. + # generate a tasklet node implementing the stencil function and represent + # the field operator as a mapped tasklet, which will range over the field domain output_connector = "__out" tlet_code = "{var} = {code}".format( var=output_connector, code=self.visit_symbolic(self._stencil_expr.expr) ) - # allocate local (aka transient) storage for the field + # allocate local temporary storage for the result field field_shape = [ # diff between upper and lower bound f"{self._field_domain[dim][1]} - {self._field_domain[dim][0]}" @@ -104,7 +106,8 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ) else: memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self._sdfg)) - # TODO: assume for now that all stencils (aka tasklets) perform single element access + # set volume to 1 because the stencil function always performs single element access + # TODO: check validity of this assumption memlet.volume = 1 input_memlets[connector] = memlet else: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py index 731da66e02..05e85de65a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py @@ -24,6 +24,8 @@ class GtirBuiltinSelect(GtirDataflowBuilder): + """Generates the dataflow subgraph for the `select` builtin function.""" + _true_br_builder: GtirDataflowBuilder _false_br_builder: GtirDataflowBuilder @@ -43,7 +45,21 @@ def __init__( # expect condition as first argument cond = self.visit_symbolic(cond_expr) - # use join state to terminate the dataflow on a single exit node + # use current head state to terminate the dataflow, and add a entry state + # to connect the true/false branch states as follows: + # + # ------------ + # === | select | === + # || ------------ || + # \/ \/ + # ------------ ------------- + # | true | | false | + # ------------ ------------- + # || || + # || ------------ || + # ==> | head | <== + # ------------ + # select_state = sdfg.add_state_before(state, state.label + "_select") sdfg.remove_edge(sdfg.out_edges(select_state)[0]) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py index f950bf2e19..ce5e261dfc 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py @@ -26,6 +26,8 @@ class GtirBuiltinSymbolRef(GtirDataflowBuilder): + """Generates the dataflow subgraph for a `itir.SymRef` node.""" + _sym_name: str _sym_type: ts.FieldType | ts.ScalarType @@ -43,6 +45,7 @@ def __init__( self._sym_type = self._data_types[sym_name] def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: + """Returns, if present, the access node in current state for the data symbol.""" access_nodes = [ node for node in self._state.nodes() @@ -56,14 +59,16 @@ def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: sym_node = self._get_access_node() if sym_node: - # share access node in same state + # if already present in current state, share access node pass elif isinstance(self._sym_type, ts.FieldType): + # add access node to current state sym_node = self._state.add_access(self._sym_name) else: - # scalar symbols are passed to the SDFG as symbols + # scalar symbols are passed to the SDFG as symbols: build tasklet node + # to write the symbol to a scalar access node assert self._sym_name in self._sdfg.symbols tasklet_node = self._state.add_tasklet( f"get_{self._sym_name}", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py index 198dab2763..61e4a21915 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py @@ -23,6 +23,7 @@ ) +# export short names of translation classes for GTIR builtin functions __all__ = [ "AsFieldOp", "Select", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 9035307e68..c96a026655 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -41,13 +41,11 @@ class GtirDataflowBuilder(eve.NodeVisitor): def __call__( self, ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - """Creates the dataflow representing the given GTIR builtin. + """The callable interface is used by the caller to build the dataflow graph. - Returns a list of SDFG nodes and the associated GT4Py data types: - tuple(node, data_type) - - The GT4Py data type is useful in the case of fields, because it provides - information on the field domain (e.g. order of dimensions, types of dimensions). + It allows to build the dataflow graph inside a given state starting + from the innermost nodes, by propagating the intermediate results + as access nodes to temporary local storage. """ return self._build() @@ -55,6 +53,7 @@ def __call__( def _add_local_storage( self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] ) -> dace.nodes.AccessNode: + """Allocates temporary storage to be used in the local scope for intermediate results.""" name = unique_name("var") if isinstance(data_type, ts.FieldType): assert len(data_type.dims) == len(shape) @@ -67,15 +66,28 @@ def _add_local_storage( return self._state.add_access(name) def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - raise NotImplementedError + """Creates the dataflow subgraph representing a given GTIR builtin. + + This method is used by derived classes of `GtirDataflowBuilder`, + which build a specialized subgraph for a certain GTIR builtin. - def _visit_node(self, node: itir.FunCall) -> None: + Returns a list of SDFG nodes and the associated GT4Py data type: + tuple(node, data_type) + + The GT4Py data type is useful in the case of fields, because it provides + information on the field domain (e.g. order of dimensions, types of dimensions). + """ raise NotImplementedError def fork(self, state: dace.SDFGState) -> "GtirDataflowBuilder": return GtirDataflowBuilder(self._sdfg, state, self._data_types) def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: + """ + Specialized visit method for domain expressions. + + Returns a list of dimensions and the corresponding range. + """ domain = [] assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) for named_range in node.args: @@ -92,6 +104,14 @@ def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: def visit_expression( self, node: itir.Expr ) -> tuple[dace.SDFGState, list[dace.nodes.AccessNode]]: + """ + Specialized visit method for fieldview expressions. + + This method represents the entry point to visit 'Stmt' expressions. + As such, it must preserve the property of single exit state in the SDFG. + + TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? + """ expr_builder = self.visit(node) assert callable(expr_builder) results = expr_builder() @@ -109,6 +129,13 @@ def visit_expression( return head_state, expressions_nodes def visit_symbolic(self, node: itir.Expr) -> str: + """ + Specialized visit method for pure stencil expressions. + + Returns a string represnting the Python code to be used as tasklet body. + TODO: should we return a list of code strings in case of tuple returns, + one for each output value? + """ return GtirTaskletCodegen().visit(node) def visit_FunCall(self, node: itir.FunCall) -> Callable: @@ -143,9 +170,11 @@ def visit_FunCall(self, node: itir.FunCall) -> Callable: @final def visit_Lambda(self, node: itir.Lambda) -> Any: - # This visitor class should never encounter `itir.Lambda` expressions - # because a lambda represents a stencil, which translates from iterator to value. - # In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + """ + This visitor class should never encounter `itir.Lambda` expressions + because a lambda represents a stencil, which translates from iterator to values. + In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + """ raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") def visit_SymRef(self, node: itir.SymRef) -> Callable: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index 0ac029e20b..cb7ae68fcc 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -79,7 +79,12 @@ @dataclass(frozen=True) class GtirTaskletCodegen(codegen.TemplatedGenerator): - """Stateless class to visit pure tasklet expressions.""" + """ + Stateless class to visit pure tasklet expressions. + + This visitor class is responsible for building the string representing + the Python code inside a tasklet node. + """ def _visit_deref(self, node: itir.FunCall) -> str: assert len(node.args) == 1 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 80c5839a82..163c98c7d2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -36,7 +36,7 @@ class GtirToSDFG(eve.NodeVisitor): This class is responsible for translation of `ir.Program`, that is the top level representation of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. Each statement is translated to a taskgraph inside a separate state. Statement states are chained - one after the other: concurrency between states should be extracted by the DaCe SDFG transformations. + one after the other: concurrency between states should be extracted by means of SDFG analysis. The translator will extend the SDFG while preserving the property of single exit state: branching is allowed within the context of one statement, but in that case the statement should terminate with a join state; the join state will represent the head state for next statement, @@ -86,6 +86,11 @@ def _make_array_shape_and_strides( return shape, strides def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: + """ + Add external storage (aka non-transient) for data containers passed as arguments to the SDFG. + + For fields, it allocates dace arrays, while scalars are stored as SDFG symbols. + """ assert isinstance(type_, (ts.FieldType, ts.ScalarType)) self._data_types[name] = type_ @@ -101,27 +106,30 @@ def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: sdfg.add_symbol(name, dtype) def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, str]: + """ + Add temporary storage (aka transient) for data containers used as GTIR temporaries. + + Assume all temporaries to be fields, therefore represented as dace arrays. + """ raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") return {} def visit_Program(self, node: itir.Program) -> dace.SDFG: """Translates `ir.Program` to `dace.SDFG`. - First, it will allocate array and scalar storage for external (aka non-transient) - and local (aka transient) data. The local data, at this stage, is used - for temporary declarations, which should be available everywhere in the SDFG - but not outside. - Then, all statements are translated, one after the other in separate states. + First, it will allocate field and scalar storage for global data. The storage + represents global data, available everywhere in the SDFG, either containing + external data (aka non-transient data) or temporary data (aka transient data). + The temporary data is global, therefore available everywhere in the SDFG + but not outside. Then, all statements are translated, one after the other. """ if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") sdfg = dace.SDFG(node.id) - - # we use entry/exit state to keep track of entry/exit point of graph execution entry_state = sdfg.add_state("program_entry", is_start_block=True) - # declarations of temporaries result in local (aka transient) array definitions in the SDFG + # declarations of temporaries result in transient array definitions in the SDFG if node.declarations: temp_symbols: dict[str, str] = {} for decl in node.declarations: @@ -134,14 +142,16 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: else: head_state = entry_state - # add global arrays (aka non-transient) to the SDFG + # add non-transient arrays and/or SDFG symbols for the program arguments assert len(node.params) == len(self._param_types) for param, type_ in zip(node.params, self._param_types): self._add_storage(sdfg, str(param.id), type_) - # visit one statement at a time and put it into separate state + # visit one statement at a time and expand the SDFG from the current head state for i, stmt in enumerate(node.body): head_state = sdfg.add_state_after(head_state, f"stmt_{i}") + # the statement could eventually modify the head state by appending new states + # however, it should preserve the property of single exit state (aka head state) head_state = self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() @@ -152,8 +162,8 @@ def visit_SetAt( ) -> dace.SDFGState: """Visits a `SetAt` statement expression and writes the local result to some external storage. - Each statement expression results in some sort of taskgraph writing to local (aka transient) storage. - The translation of `SetAt` ensures that the result is written to the external storage. + Each statement expression results in some sort of dataflow gragh writing to temporary storage. + The translation of `SetAt` ensures that the result is written to some global storage. """ dataflow_builder = DataflowBuilder(sdfg, state, self._data_types) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 2fcc12bc85..6799d2bb2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -21,6 +21,7 @@ def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: + """Converts GT4Py scalar type to corresponding DaCe type.""" if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ elif type_.kind == ts.ScalarKind.INT32: @@ -35,6 +36,12 @@ def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: + """ + Filter offset providers of type `Connectivity`. + + In other words, filter out the cartesian offset providers. + Returns a new dictionary containing only `Connectivity` values. + """ return { offset: table for offset, table in offset_provider.items() @@ -43,6 +50,7 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne def unique_name(prefix: str) -> str: + """Generate a string containing a unique integer id, which is updated incrementally.""" unique_id = getattr(unique_name, "_unique_id", 0) # static variable setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] From 452399dddcbdd8e38d827c5e89da1cf1fa4b62b0 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 3 May 2024 16:25:51 +0200 Subject: [PATCH 029/235] Avoid inheritance: pass dataflow builder as arg to builtin translator --- .../__init__.py} | 6 +- .../gtir_builtin_field_operator.py | 59 ++++++------ .../gtir_builtin_select.py | 54 ++++++----- .../gtir_builtin_symbol_ref.py | 48 +++++----- .../gtir_builtins/gtir_builtin_translator.py | 72 +++++++++++++++ .../dace_fieldview/gtir_dataflow_builder.py | 89 ++++--------------- .../runners/dace_fieldview/gtir_to_sdfg.py | 29 +++--- 7 files changed, 194 insertions(+), 163 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtins.py => gtir_builtins/__init__.py} (89%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{ => gtir_builtins}/gtir_builtin_field_operator.py (71%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{ => gtir_builtins}/gtir_builtin_select.py (69%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{ => gtir_builtins}/gtir_builtin_symbol_ref.py (62%) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py similarity index 89% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py index 61e4a21915..293beec8f8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py @@ -12,13 +12,13 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_field_operator import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import ( GtirBuiltinAsFieldOp as AsFieldOp, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_select import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_select import ( GtirBuiltinSelect as Select, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_symbol_ref import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_symbol_ref import ( GtirBuiltinSymbolRef as SymbolRef, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py similarity index 71% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index f13c03b7e4..6e19ea68ae 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -20,29 +20,34 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( + GtirBuiltinTranslator, +) from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( GtirDataflowBuilder, ) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( + GtirTaskletCodegen, +) from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinAsFieldOp(GtirDataflowBuilder): +class GtirBuiltinAsFieldOp(GtirBuiltinTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" - _stencil_expr: itir.Lambda - _stencil_args: list[Callable] - _field_domain: dict[Dimension, tuple[str, str]] - _field_type: ts.FieldType + stencil_expr: itir.Lambda + stencil_args: list[Callable] + field_domain: dict[Dimension, tuple[str, str]] + field_type: ts.FieldType def __init__( self, - sdfg: dace.SDFG, + dataflow_builder: GtirDataflowBuilder, state: dace.SDFGState, - data_types: dict[str, ts.FieldType | ts.ScalarType], node: itir.FunCall, stencil_args: list[Callable], ): - super().__init__(sdfg, state, data_types) + super().__init__(state, dataflow_builder.sdfg) assert cpm.is_call_to(node.fun, "as_fieldop") assert len(node.fun.args) == 2 @@ -53,42 +58,40 @@ def __init__( assert isinstance(domain_expr, itir.FunCall) # visit field domain - domain = self.visit_domain(domain_expr) + domain = dataflow_builder.visit_domain(domain_expr) # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - self._field_domain = {dim: (lb, ub) for dim, lb, ub in domain} - self._field_type = ts.FieldType([dim for dim, _, _ in domain], node_type) - self._stencil_expr = stencil_expr - self._stencil_args = stencil_args + self.field_domain = {dim: (lb, ub) for dim, lb, ub in domain} + self.field_type = ts.FieldType([dim for dim, _, _ in domain], node_type) + self.stencil_expr = stencil_expr + self.stencil_args = stencil_args - def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # generate a tasklet node implementing the stencil function and represent # the field operator as a mapped tasklet, which will range over the field domain output_connector = "__out" tlet_code = "{var} = {code}".format( - var=output_connector, code=self.visit_symbolic(self._stencil_expr.expr) + var=output_connector, code=GtirTaskletCodegen().visit(self.stencil_expr.expr) ) # allocate local temporary storage for the result field field_shape = [ # diff between upper and lower bound - f"{self._field_domain[dim][1]} - {self._field_domain[dim][0]}" - for dim in self._field_type.dims + f"{self.field_domain[dim][1]} - {self.field_domain[dim][0]}" + for dim in self.field_type.dims ] - field_node = self._add_local_storage(self._field_type, field_shape) + field_node = self.add_local_storage(self.field_type, field_shape) # create map range corresponding to the field operator domain - map_ranges = { - f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self._field_domain.items() - } + map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} input_nodes: dict[str, dace.nodes.AccessNode] = {} input_memlets: dict[str, dace.Memlet] = {} - assert len(self._stencil_args) == len(self._stencil_expr.params) - for arg, param in zip(self._stencil_args, self._stencil_expr.params): + assert len(self.stencil_args) == len(self.stencil_expr.params) + for arg, param in zip(self.stencil_args, self.stencil_expr.params): arg_nodes = arg() assert len(arg_nodes) == 1 arg_node, arg_type = arg_nodes[0] @@ -98,14 +101,14 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: input_nodes[arg_node.data] = arg_node if isinstance(arg_type, ts.FieldType): # support either single element access (general case) or full array shape - is_scalar = all(dim in self._field_domain for dim in arg_type.dims) + is_scalar = all(dim in self.field_domain for dim in arg_type.dims) if is_scalar: subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) input_memlets[connector] = dace.Memlet( data=arg_node.data, subset=subset, volume=1 ) else: - memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self._sdfg)) + memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) # set volume to 1 because the stencil function always performs single element access # TODO: check validity of this assumption memlet.volume = 1 @@ -114,12 +117,12 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: input_memlets[connector] = dace.Memlet(data=arg_node.data, subset="0") # assume tasklet with single output - output_index = ",".join(f"i_{dim.value}" for dim in self._field_type.dims) + output_index = ",".join(f"i_{dim.value}" for dim in self.field_type.dims) output_memlets = {output_connector: dace.Memlet(data=field_node.data, subset=output_index)} output_nodes = {field_node.data: field_node} # create a tasklet inside a parallel-map scope - self._state.add_mapped_tasklet( + self.head_state.add_mapped_tasklet( "tasklet", map_ranges, input_memlets, @@ -130,4 +133,4 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: external_edges=True, ) - return [(field_node, self._field_type)] + return [(field_node, self.field_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py similarity index 69% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py index 05e85de65a..a9e0b8da9e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py @@ -13,37 +13,42 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Callable + import dace from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( + GtirBuiltinTranslator, +) from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( GtirDataflowBuilder, ) from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinSelect(GtirDataflowBuilder): +class GtirBuiltinSelect(GtirBuiltinTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" - _true_br_builder: GtirDataflowBuilder - _false_br_builder: GtirDataflowBuilder + true_br_builder: Callable + false_br_builder: Callable def __init__( self, - sdfg: dace.SDFG, + dataflow_builder: GtirDataflowBuilder, state: dace.SDFGState, - data_types: dict[str, ts.FieldType | ts.ScalarType], node: itir.FunCall, ): - super().__init__(sdfg, state, data_types) + super().__init__(state, dataflow_builder.sdfg) + sdfg = dataflow_builder.sdfg assert cpm.is_call_to(node.fun, "select") assert len(node.fun.args) == 3 cond_expr, true_expr, false_expr = node.fun.args # expect condition as first argument - cond = self.visit_symbolic(cond_expr) + cond = dataflow_builder.visit_symbolic(cond_expr) # use current head state to terminate the dataflow, and add a entry state # to connect the true/false branch states as follows: @@ -67,19 +72,28 @@ def __init__( true_state = sdfg.add_state(state.label + "_true_branch") sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) sdfg.add_edge(true_state, state, dace.InterstateEdge()) - self._true_br_builder = self.fork(true_state).visit(true_expr) + self.true_br_builder = dataflow_builder.visit(true_expr, head_state=true_state) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=f"not {cond}")) sdfg.add_edge(false_state, state, dace.InterstateEdge()) - self._false_br_builder = self.fork(false_state).visit(false_expr) + self.false_br_builder = dataflow_builder.visit(false_expr, head_state=false_state) - def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - true_br_args = self._true_br_builder() - false_br_args = self._false_br_builder() + def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + true_br_args = self.true_br_builder() + false_br_args = self.false_br_builder() assert len(true_br_args) == len(false_br_args) + # retrieve true/false states as predecessors of head state + branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) + assert len(branch_states) == 2 + if branch_states[0].label.endswith("_true_branch"): + true_state, false_state = branch_states + else: + false_state, true_state = branch_states + + output_nodes = [] for true_br, false_br in zip(true_br_args, false_br_args): true_br_node, true_br_type = true_br @@ -87,26 +101,26 @@ def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: false_br_node, false_br_type = false_br assert isinstance(false_br_node, dace.nodes.AccessNode) assert true_br_type == false_br_type - array_type = self._sdfg.arrays[true_br_node.data] - access_node = self._add_local_storage(true_br_type, array_type.shape) + array_type = self.sdfg.arrays[true_br_node.data] + access_node = self.add_local_storage(true_br_type, array_type.shape) output_nodes.append((access_node, true_br_type)) data_name = access_node.data - true_br_output_node = self._true_br_builder._state.add_access(data_name) - self._true_br_builder._state.add_nedge( + true_br_output_node = true_state.add_access(data_name) + true_state.add_nedge( true_br_node, true_br_output_node, dace.Memlet.from_array( - true_br_output_node.data, true_br_output_node.desc(self._sdfg) + true_br_output_node.data, true_br_output_node.desc(self.sdfg) ), ) - false_br_output_node = self._false_br_builder._state.add_access(data_name) - self._false_br_builder._state.add_nedge( + false_br_output_node = false_state.add_access(data_name) + false_state.add_nedge( false_br_node, false_br_output_node, dace.Memlet.from_array( - false_br_output_node.data, false_br_output_node.desc(self._sdfg) + false_br_output_node.data, false_br_output_node.desc(self.sdfg) ), ) return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py similarity index 62% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py index ce5e261dfc..2a94b8dc87 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py @@ -18,6 +18,9 @@ import dace from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( + GtirBuiltinTranslator, +) from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( GtirDataflowBuilder, ) @@ -25,62 +28,61 @@ from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinSymbolRef(GtirDataflowBuilder): +class GtirBuiltinSymbolRef(GtirBuiltinTranslator): """Generates the dataflow subgraph for a `itir.SymRef` node.""" - _sym_name: str - _sym_type: ts.FieldType | ts.ScalarType + sym_name: str + sym_type: ts.FieldType | ts.ScalarType def __init__( self, - sdfg: dace.SDFG, + dataflow_builder: GtirDataflowBuilder, state: dace.SDFGState, - data_types: dict[str, ts.FieldType | ts.ScalarType], node: itir.SymRef, ): - super().__init__(sdfg, state, data_types) + super().__init__(state, dataflow_builder.sdfg) sym_name = str(node.id) - assert sym_name in self._data_types - self._sym_name = sym_name - self._sym_type = self._data_types[sym_name] + assert sym_name in dataflow_builder.data_types + self.sym_name = sym_name + self.sym_type = dataflow_builder.data_types[sym_name] def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: """Returns, if present, the access node in current state for the data symbol.""" access_nodes = [ node - for node in self._state.nodes() - if isinstance(node, dace.nodes.AccessNode) and node.data == self._sym_name + for node in self.head_state.nodes() + if isinstance(node, dace.nodes.AccessNode) and node.data == self.sym_name ] if len(access_nodes) == 0: return None assert len(access_nodes) == 1 return access_nodes[0] - def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: sym_node = self._get_access_node() if sym_node: # if already present in current state, share access node pass - elif isinstance(self._sym_type, ts.FieldType): + elif isinstance(self.sym_type, ts.FieldType): # add access node to current state - sym_node = self._state.add_access(self._sym_name) + sym_node = self.head_state.add_access(self.sym_name) else: # scalar symbols are passed to the SDFG as symbols: build tasklet node # to write the symbol to a scalar access node - assert self._sym_name in self._sdfg.symbols - tasklet_node = self._state.add_tasklet( - f"get_{self._sym_name}", + assert self.sym_name in self.sdfg.symbols + tasklet_node = self.head_state.add_tasklet( + f"get_{self.sym_name}", {}, {"__out"}, - f"__out = {self._sym_name}", + f"__out = {self.sym_name}", ) - name = f"{self._state.label}_var" - dtype = as_dace_type(self._sym_type) - sym_node = self._state.add_scalar(name, dtype, find_new_name=True, transient=True) - self._state.add_edge( + name = f"{self.head_state.label}_var" + dtype = as_dace_type(self.sym_type) + sym_node = self.head_state.add_scalar(name, dtype, find_new_name=True, transient=True) + self.head_state.add_edge( tasklet_node, "__out", sym_node, None, dace.Memlet(data=sym_node.data, subset="0") ) - return [(sym_node, self._sym_type)] + return [(sym_node, self.sym_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py new file mode 100644 index 0000000000..4104833214 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -0,0 +1,72 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from abc import abstractmethod +from dataclasses import dataclass +from typing import final +from gt4py import eve + +import dace +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name +from gt4py.next.type_system import type_specifications as ts + + +@dataclass(frozen=True) +class GtirBuiltinTranslator(eve.NodeVisitor): + head_state: dace.SDFGState + sdfg: dace.SDFG + + @final + def __call__( + self, + ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + """The callable interface is used to build the dataflow graph. + + It allows to build the dataflow graph inside a given state starting + from the innermost nodes, by propagating the intermediate results + as access nodes to temporary local storage. + """ + return self.build() + + @final + def add_local_storage( + self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] + ) -> dace.nodes.AccessNode: + """Allocates temporary storage to be used in the local scope for intermediate results.""" + name = unique_name("var") + if isinstance(data_type, ts.FieldType): + assert len(data_type.dims) == len(shape) + dtype = as_dace_type(data_type.dtype) + name, _ = self.sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) + else: + assert len(shape) == 0 + dtype = as_dace_type(data_type) + name, _ = self.sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + return self.head_state.add_access(name) + + @abstractmethod + def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + """Creates the dataflow subgraph representing a given GTIR builtin. + + This method is used by derived classes of `GtirDataflowBuilder`, + which build a specialized subgraph for a certain GTIR builtin. + + Returns a list of SDFG nodes and the associated GT4Py data type: + tuple(node, data_type) + + The GT4Py data type is useful in the case of fields, because it provides + information on the field domain (e.g. order of dimensions, types of dimensions). + """ + raise NotImplementedError \ No newline at end of file diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index c96a026655..03223b61c0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -25,7 +25,6 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( GtirTaskletCodegen, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name from gt4py.next.type_system import type_specifications as ts @@ -33,54 +32,8 @@ class GtirDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" - _sdfg: dace.SDFG - _state: dace.SDFGState - _data_types: dict[str, ts.FieldType | ts.ScalarType] - - @final - def __call__( - self, - ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - """The callable interface is used by the caller to build the dataflow graph. - - It allows to build the dataflow graph inside a given state starting - from the innermost nodes, by propagating the intermediate results - as access nodes to temporary local storage. - """ - return self._build() - - @final - def _add_local_storage( - self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] - ) -> dace.nodes.AccessNode: - """Allocates temporary storage to be used in the local scope for intermediate results.""" - name = unique_name("var") - if isinstance(data_type, ts.FieldType): - assert len(data_type.dims) == len(shape) - dtype = as_dace_type(data_type.dtype) - name, _ = self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) - else: - assert len(shape) == 0 - dtype = as_dace_type(data_type) - name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) - return self._state.add_access(name) - - def _build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - """Creates the dataflow subgraph representing a given GTIR builtin. - - This method is used by derived classes of `GtirDataflowBuilder`, - which build a specialized subgraph for a certain GTIR builtin. - - Returns a list of SDFG nodes and the associated GT4Py data type: - tuple(node, data_type) - - The GT4Py data type is useful in the case of fields, because it provides - information on the field domain (e.g. order of dimensions, types of dimensions). - """ - raise NotImplementedError - - def fork(self, state: dace.SDFGState) -> "GtirDataflowBuilder": - return GtirDataflowBuilder(self._sdfg, state, self._data_types) + sdfg: dace.SDFG + data_types: dict[str, ts.FieldType | ts.ScalarType] def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: """ @@ -102,8 +55,8 @@ def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: return domain def visit_expression( - self, node: itir.Expr - ) -> tuple[dace.SDFGState, list[dace.nodes.AccessNode]]: + self, node: itir.Expr, head_state: dace.SDFGState + ) -> list[dace.nodes.AccessNode]: """ Specialized visit method for fieldview expressions. @@ -112,7 +65,7 @@ def visit_expression( TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? """ - expr_builder = self.visit(node) + expr_builder = self.visit(node, head_state=head_state) assert callable(expr_builder) results = expr_builder() @@ -121,12 +74,13 @@ def visit_expression( assert isinstance(node, dace.nodes.AccessNode) expressions_nodes.append(node) - # sanity check: each statement should result in a single exit state, i.e. only internal branches - sink_states = self._sdfg.sink_nodes() + # sanity check: each statement should preserve the property of single exit state (aka head state), + # i.e. eventually only introduce internal branches, and keep the same head state + sink_states = self.sdfg.sink_nodes() assert len(sink_states) == 1 - head_state = sink_states[0] + assert sink_states[0] == head_state - return head_state, expressions_nodes + return expressions_nodes def visit_symbolic(self, node: itir.Expr) -> str: """ @@ -138,32 +92,21 @@ def visit_symbolic(self, node: itir.Expr) -> str: """ return GtirTaskletCodegen().visit(node) - def visit_FunCall(self, node: itir.FunCall) -> Callable: + def visit_FunCall(self, node: itir.FunCall, head_state: dace.SDFGState) -> Callable: from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins arg_builders: list[Callable] = [] for arg in node.args: - arg_builder = self.visit(arg) + arg_builder = self.visit(arg, head_state=head_state) assert callable(arg_builder) arg_builders.append(arg_builder) if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtins.AsFieldOp( - self._sdfg, - self._state, - self._data_types, - node, - arg_builders, - ) + return gtir_builtins.AsFieldOp(self, head_state, node, arg_builders) elif cpm.is_call_to(node.fun, "select"): assert len(arg_builders) == 0 - return gtir_builtins.Select( - self._sdfg, - self._state, - self._data_types, - node, - ) + return gtir_builtins.Select(self, head_state, node) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -177,7 +120,7 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: """ raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") - def visit_SymRef(self, node: itir.SymRef) -> Callable: + def visit_SymRef(self, node: itir.SymRef, head_state: dace.SDFGState) -> Callable: from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins - return gtir_builtins.SymbolRef(self._sdfg, self._state, self._data_types, node) + return gtir_builtins.SymbolRef(self, head_state, node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 163c98c7d2..cb03fbf385 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -24,11 +24,15 @@ from gt4py import eve from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( + GtirDataflowBuilder as DataflowBuilder, +) +from gt4py.next.program_processors.runners.dace_fieldview.utility import ( + as_dace_type, + filter_connectivities, +) from gt4py.next.type_system import type_specifications as ts -from .gtir_dataflow_builder import GtirDataflowBuilder as DataflowBuilder -from .utility import as_dace_type, filter_connectivities - class GtirToSDFG(eve.NodeVisitor): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -150,29 +154,24 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # visit one statement at a time and expand the SDFG from the current head state for i, stmt in enumerate(node.body): head_state = sdfg.add_state_after(head_state, f"stmt_{i}") - # the statement could eventually modify the head state by appending new states - # however, it should preserve the property of single exit state (aka head state) - head_state = self.visit(stmt, sdfg=sdfg, state=head_state) + self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() return sdfg - def visit_SetAt( - self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dace.SDFGState: + def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. The translation of `SetAt` ensures that the result is written to some global storage. """ - dataflow_builder = DataflowBuilder(sdfg, state, self._data_types) - head_state, expr_nodes = dataflow_builder.visit_expression(stmt.expr) + dataflow_builder = DataflowBuilder(sdfg, self._data_types) + expr_nodes = dataflow_builder.visit_expression(stmt.expr, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - target_builder = DataflowBuilder(sdfg, head_state, self._data_types) - head_state, target_nodes = target_builder.visit_expression(stmt.target) + target_nodes = dataflow_builder.visit_expression(stmt.target, state) assert len(expr_nodes) == len(target_nodes) domain = dataflow_builder.visit_domain(stmt.domain) @@ -192,10 +191,8 @@ def visit_SetAt( assert len(domain) == 0 subset = "0" - head_state.add_nedge( + state.add_nedge( expr_node, target_node, dace.Memlet(data=target_node.data, subset=subset), ) - - return head_state From e4042261692463ae627f3e11acadfae1c043bd3b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 3 May 2024 17:14:08 +0200 Subject: [PATCH 030/235] Codestyle review changes --- .../dace_fieldview/gtir_builtins/__init__.py | 6 +-- .../gtir_builtin_field_operator.py | 15 ++++--- .../gtir_builtins/gtir_builtin_select.py | 16 ++++---- .../gtir_builtins/gtir_builtin_symbol_ref.py | 8 ++-- .../gtir_builtins/gtir_builtin_translator.py | 8 ++-- .../dace_fieldview/gtir_dataflow_builder.py | 6 +-- .../dace_fieldview/gtir_tasklet_codegen.py | 15 +++---- .../runners/dace_fieldview/gtir_to_sdfg.py | 41 +++++++++---------- .../runners/dace_fieldview/utility.py | 24 ++++++----- .../runners_tests/test_dace_fieldview.py | 2 +- 10 files changed, 67 insertions(+), 74 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py index 293beec8f8..c99b418eae 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py @@ -13,13 +13,13 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import ( - GtirBuiltinAsFieldOp as AsFieldOp, + GTIRBuiltinAsFieldOp as AsFieldOp, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_select import ( - GtirBuiltinSelect as Select, + GTIRBuiltinSelect as Select, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_symbol_ref import ( - GtirBuiltinSymbolRef as SymbolRef, + GTIRBuiltinSymbolRef as SymbolRef, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 6e19ea68ae..6aa3ed9832 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -21,18 +21,18 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GtirBuiltinTranslator, + GTIRBuiltinTranslator, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GtirDataflowBuilder, + GTIRDataflowBuilder, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GtirTaskletCodegen, + GTIRTaskletCodegen, ) from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinAsFieldOp(GtirBuiltinTranslator): +class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" stencil_expr: itir.Lambda @@ -42,7 +42,7 @@ class GtirBuiltinAsFieldOp(GtirBuiltinTranslator): def __init__( self, - dataflow_builder: GtirDataflowBuilder, + dataflow_builder: GTIRDataflowBuilder, state: dace.SDFGState, node: itir.FunCall, stencil_args: list[Callable], @@ -74,7 +74,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # the field operator as a mapped tasklet, which will range over the field domain output_connector = "__out" tlet_code = "{var} = {code}".format( - var=output_connector, code=GtirTaskletCodegen().visit(self.stencil_expr.expr) + var=output_connector, code=GTIRTaskletCodegen().visit(self.stencil_expr.expr) ) # allocate local temporary storage for the result field @@ -90,8 +90,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: input_nodes: dict[str, dace.nodes.AccessNode] = {} input_memlets: dict[str, dace.Memlet] = {} - assert len(self.stencil_args) == len(self.stencil_expr.params) - for arg, param in zip(self.stencil_args, self.stencil_expr.params): + for arg, param in zip(self.stencil_args, self.stencil_expr.params, strict=True): arg_nodes = arg() assert len(arg_nodes) == 1 arg_node, arg_type = arg_nodes[0] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py index a9e0b8da9e..086e85ab45 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py @@ -20,15 +20,15 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GtirBuiltinTranslator, + GTIRBuiltinTranslator, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GtirDataflowBuilder, + GTIRDataflowBuilder, ) from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinSelect(GtirBuiltinTranslator): +class GTIRBuiltinSelect(GTIRBuiltinTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" true_br_builder: Callable @@ -36,7 +36,7 @@ class GtirBuiltinSelect(GtirBuiltinTranslator): def __init__( self, - dataflow_builder: GtirDataflowBuilder, + dataflow_builder: GTIRDataflowBuilder, state: dace.SDFGState, node: itir.FunCall, ): @@ -81,10 +81,6 @@ def __init__( self.false_br_builder = dataflow_builder.visit(false_expr, head_state=false_state) def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - true_br_args = self.true_br_builder() - false_br_args = self.false_br_builder() - assert len(true_br_args) == len(false_br_args) - # retrieve true/false states as predecessors of head state branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) assert len(branch_states) == 2 @@ -93,9 +89,11 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: else: false_state, true_state = branch_states + true_br_args = self.true_br_builder() + false_br_args = self.false_br_builder() output_nodes = [] - for true_br, false_br in zip(true_br_args, false_br_args): + for true_br, false_br in zip(true_br_args, false_br_args, strict=True): true_br_node, true_br_type = true_br assert isinstance(true_br_node, dace.nodes.AccessNode) false_br_node, false_br_type = false_br diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py index 2a94b8dc87..3562f1070f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py @@ -19,16 +19,16 @@ from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GtirBuiltinTranslator, + GTIRBuiltinTranslator, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GtirDataflowBuilder, + GTIRDataflowBuilder, ) from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts -class GtirBuiltinSymbolRef(GtirBuiltinTranslator): +class GTIRBuiltinSymbolRef(GTIRBuiltinTranslator): """Generates the dataflow subgraph for a `itir.SymRef` node.""" sym_name: str @@ -36,7 +36,7 @@ class GtirBuiltinSymbolRef(GtirBuiltinTranslator): def __init__( self, - dataflow_builder: GtirDataflowBuilder, + dataflow_builder: GTIRDataflowBuilder, state: dace.SDFGState, node: itir.SymRef, ): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 4104833214..90a9a5dff0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -16,15 +16,16 @@ from abc import abstractmethod from dataclasses import dataclass from typing import final -from gt4py import eve import dace + +from gt4py import eve from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) -class GtirBuiltinTranslator(eve.NodeVisitor): +class GTIRBuiltinTranslator(eve.NodeVisitor): head_state: dace.SDFGState sdfg: dace.SDFG @@ -60,7 +61,7 @@ def add_local_storage( def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: """Creates the dataflow subgraph representing a given GTIR builtin. - This method is used by derived classes of `GtirDataflowBuilder`, + This method is used by derived classes of `GTIRBuiltinTranslator`, which build a specialized subgraph for a certain GTIR builtin. Returns a list of SDFG nodes and the associated GT4Py data type: @@ -69,4 +70,3 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: The GT4Py data type is useful in the case of fields, because it provides information on the field domain (e.g. order of dimensions, types of dimensions). """ - raise NotImplementedError \ No newline at end of file diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py index 03223b61c0..07805dccc2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py @@ -23,13 +23,13 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GtirTaskletCodegen, + GTIRTaskletCodegen, ) from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) -class GtirDataflowBuilder(eve.NodeVisitor): +class GTIRDataflowBuilder(eve.NodeVisitor): """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" sdfg: dace.SDFG @@ -90,7 +90,7 @@ def visit_symbolic(self, node: itir.Expr) -> str: TODO: should we return a list of code strings in case of tuple returns, one for each output value? """ - return GtirTaskletCodegen().visit(node) + return GTIRTaskletCodegen().visit(node) def visit_FunCall(self, node: itir.FunCall, head_state: dace.SDFGState) -> Callable: from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py index cb7ae68fcc..7473ffdb20 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py @@ -13,11 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from dataclasses import dataclass -from typing import final import numpy as np from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -78,7 +78,7 @@ @dataclass(frozen=True) -class GtirTaskletCodegen(codegen.TemplatedGenerator): +class GTIRTaskletCodegen(codegen.TemplatedGenerator): """ Stateless class to visit pure tasklet expressions. @@ -86,6 +86,9 @@ class GtirTaskletCodegen(codegen.TemplatedGenerator): the Python code inside a tasklet node. """ + Literal = as_fmt("{value}") + SymRef = as_fmt("{id}") + def _visit_deref(self, node: itir.FunCall) -> str: assert len(node.args) == 1 if isinstance(node.args[0], itir.SymRef): @@ -108,11 +111,3 @@ def visit_FunCall(self, node: itir.FunCall) -> str: else: raise NotImplementedError(f"'{builtin_name}' not implemented.") raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - - @final - def visit_Literal(self, node: itir.Literal) -> str: - return node.value - - @final - def visit_SymRef(self, node: itir.SymRef) -> str: - return str(node.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index cb03fbf385..fdfd58d0d2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -25,7 +25,7 @@ from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GtirDataflowBuilder as DataflowBuilder, + GTIRDataflowBuilder as DataflowBuilder, ) from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, @@ -34,7 +34,7 @@ from gt4py.next.type_system import type_specifications as ts -class GtirToSDFG(eve.NodeVisitor): +class GTIRToSDFG(eve.NodeVisitor): """Provides translation capability from a GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation @@ -47,13 +47,13 @@ class GtirToSDFG(eve.NodeVisitor): from where to continue building the SDFG. """ - _param_types: list[ts.TypeSpec] + _param_types: list[ts.DataType] _data_types: dict[str, ts.FieldType | ts.ScalarType] _offset_providers: Mapping[str, Any] def __init__( self, - param_types: list[ts.TypeSpec], + param_types: list[ts.DataType], offset_providers: dict[str, Connectivity | Dimension], ): self._param_types = param_types @@ -78,36 +78,38 @@ def _make_array_shape_and_strides( neighbor_tables = filter_connectivities(self._offset_providers) shape = [ ( - # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain neighbor_tables[dim.value].max_neighbors if dim.kind == DimensionKind.LOCAL - # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain + # we reuse the same symbol for field size passed as scalar argument to the gt4py program else dace.symbol(f"__{name}_size_{i}", dtype) ) for i, dim in enumerate(dims) ] - strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i, _ in enumerate(dims)] + strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i in range(len(dims))] return shape, strides - def _add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec) -> None: + def _add_storage(self, sdfg: dace.SDFG, name: str, data_type: ts.DataType) -> None: """ Add external storage (aka non-transient) for data containers passed as arguments to the SDFG. For fields, it allocates dace arrays, while scalars are stored as SDFG symbols. """ - assert isinstance(type_, (ts.FieldType, ts.ScalarType)) - self._data_types[name] = type_ - - if isinstance(type_, ts.FieldType): - dtype = as_dace_type(type_.dtype) + if isinstance(data_type, ts.FieldType): + dtype = as_dace_type(data_type.dtype) # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, type_.dims) + sym_shape, sym_strides = self._make_array_shape_and_strides(name, data_type.dims) sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False) - else: - dtype = as_dace_type(type_) + elif isinstance(data_type, ts.ScalarType): + dtype = as_dace_type(data_type) # scalar arguments passed to the program are represented as symbols in DaCe SDFG sdfg.add_symbol(name, dtype) + else: + raise RuntimeError(f"Data type '{type(data_type)}' not supported.") + + # TODO: unclear why mypy complains about incompatible types + assert isinstance(data_type, (ts.FieldType, ts.ScalarType)) + self._data_types[name] = data_type def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, str]: """ @@ -116,7 +118,6 @@ def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, Assume all temporaries to be fields, therefore represented as dace arrays. """ raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") - return {} def visit_Program(self, node: itir.Program) -> dace.SDFG: """Translates `ir.Program` to `dace.SDFG`. @@ -147,8 +148,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: head_state = entry_state # add non-transient arrays and/or SDFG symbols for the program arguments - assert len(node.params) == len(self._param_types) - for param, type_ in zip(node.params, self._param_types): + for param, type_ in zip(node.params, self._param_types, strict=True): self._add_storage(sdfg, str(param.id), type_) # visit one statement at a time and expand the SDFG from the current head state @@ -172,13 +172,12 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field target_nodes = dataflow_builder.visit_expression(stmt.target, state) - assert len(expr_nodes) == len(target_nodes) domain = dataflow_builder.visit_domain(stmt.domain) # convert domain to dictionary to ease access to dimension boundaries domain_map = {dim: (lb, ub) for dim, lb, ub in domain} - for expr_node, target_node in zip(expr_nodes, target_nodes): + for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): target_array = sdfg.arrays[target_node.data] assert not target_array.transient target_field_type = self._data_types[target_node.data] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 6799d2bb2d..ac70241098 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -22,17 +22,19 @@ def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" - if type_.kind == ts.ScalarKind.BOOL: - return dace.bool_ - elif type_.kind == ts.ScalarKind.INT32: - return dace.int32 - elif type_.kind == ts.ScalarKind.INT64: - return dace.int64 - elif type_.kind == ts.ScalarKind.FLOAT32: - return dace.float32 - elif type_.kind == ts.ScalarKind.FLOAT64: - return dace.float64 - raise ValueError(f"Scalar type '{type_}' not supported.") + match type_.kind: + case ts.ScalarKind.BOOL: + return dace.bool_ + case ts.ScalarKind.INT32: + return dace.int32 + case ts.ScalarKind.INT64: + return dace.int64 + case ts.ScalarKind.FLOAT32: + return dace.float32 + case ts.ScalarKind.FLOAT64: + return dace.float64 + case _: + raise ValueError(f"Scalar type '{type_}' not supported.") def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index f15132a608..7113ed054c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -22,7 +22,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( - GtirToSDFG as FieldviewGtirToSDFG, + GTIRToSDFG as FieldviewGtirToSDFG, ) from gt4py.next.type_system import type_specifications as ts From bb0dfac5632b5064e7a5fcc5f7d7ebbe6f7b7282 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 6 May 2024 09:43:09 +0200 Subject: [PATCH 031/235] Remove circular dependency for builtin translators --- .../dace_fieldview/gtir_builtins/__init__.py | 4 + .../gtir_builtins/gtir_builtin_domain.py | 55 ++++++++ .../gtir_builtin_field_operator.py | 17 +-- .../gtir_builtins/gtir_builtin_select.py | 18 +-- .../gtir_builtins/gtir_builtin_symbol_ref.py | 15 +-- .../gtir_builtins/gtir_builtin_translator.py | 91 ++++++++++++- .../dace_fieldview/gtir_dataflow_builder.py | 126 ------------------ .../runners/dace_fieldview/gtir_to_sdfg.py | 80 +++++++++-- 8 files changed, 241 insertions(+), 165 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py index c99b418eae..8376d7357e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py @@ -12,6 +12,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_domain import ( + GTIRBuiltinDomain as FieldDomain, +) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import ( GTIRBuiltinAsFieldOp as AsFieldOp, ) @@ -26,6 +29,7 @@ # export short names of translation classes for GTIR builtin functions __all__ = [ "AsFieldOp", + "FieldDomain", "Select", "SymbolRef", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py new file mode 100644 index 0000000000..063137a8ca --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py @@ -0,0 +1,55 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from dataclasses import dataclass +from typing import final + +import dace +import numpy as np + +from gt4py import eve +from gt4py.next.common import Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import GTIRBuiltinTranslator +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name +from gt4py.next.type_system import type_specifications as ts + + +@dataclass(frozen=True) +class GTIRBuiltinDomain(GTIRBuiltinTranslator): + + def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + raise NotImplementedError + + def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: + """ + Specialized visit method for domain expressions. + + Returns a list of dimensions and the corresponding range. + """ + assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) + + domain = [] + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, itir.AxisLiteral) + dim = Dimension(axis.value) + bounds = [self.visit(arg) for arg in named_range.args[1:3]] + domain.append((dim, bounds[0], bounds[1])) + + return domain diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 6aa3ed9832..01280562e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -20,15 +20,12 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_domain import ( + GTIRBuiltinDomain, +) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GTIRDataflowBuilder, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GTIRTaskletCodegen, -) from gt4py.next.type_system import type_specifications as ts @@ -42,12 +39,12 @@ class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): def __init__( self, - dataflow_builder: GTIRDataflowBuilder, + sdfg: dace.SDFG, state: dace.SDFGState, node: itir.FunCall, stencil_args: list[Callable], ): - super().__init__(state, dataflow_builder.sdfg) + super().__init__(sdfg, state) assert cpm.is_call_to(node.fun, "as_fieldop") assert len(node.fun.args) == 2 @@ -58,7 +55,7 @@ def __init__( assert isinstance(domain_expr, itir.FunCall) # visit field domain - domain = dataflow_builder.visit_domain(domain_expr) + domain = GTIRBuiltinDomain(sdfg, state).visit_domain(domain_expr) # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type @@ -74,7 +71,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # the field operator as a mapped tasklet, which will range over the field domain output_connector = "__out" tlet_code = "{var} = {code}".format( - var=output_connector, code=GTIRTaskletCodegen().visit(self.stencil_expr.expr) + var=output_connector, code=self.visit(self.stencil_expr.expr) ) # allocate local temporary storage for the result field diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py index 086e85ab45..058ce568fd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py @@ -17,14 +17,12 @@ import dace +from gt4py import eve from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GTIRDataflowBuilder, -) from gt4py.next.type_system import type_specifications as ts @@ -36,19 +34,19 @@ class GTIRBuiltinSelect(GTIRBuiltinTranslator): def __init__( self, - dataflow_builder: GTIRDataflowBuilder, + sdfg: dace.SDFG, state: dace.SDFGState, + dataflow_builder: eve.NodeVisitor, node: itir.FunCall, ): - super().__init__(state, dataflow_builder.sdfg) - sdfg = dataflow_builder.sdfg + super().__init__(sdfg, state) assert cpm.is_call_to(node.fun, "select") assert len(node.fun.args) == 3 cond_expr, true_expr, false_expr = node.fun.args # expect condition as first argument - cond = dataflow_builder.visit_symbolic(cond_expr) + cond = self.visit(cond_expr) # use current head state to terminate the dataflow, and add a entry state # to connect the true/false branch states as follows: @@ -72,13 +70,15 @@ def __init__( true_state = sdfg.add_state(state.label + "_true_branch") sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) sdfg.add_edge(true_state, state, dace.InterstateEdge()) - self.true_br_builder = dataflow_builder.visit(true_expr, head_state=true_state) + self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=f"not {cond}")) sdfg.add_edge(false_state, state, dace.InterstateEdge()) - self.false_br_builder = dataflow_builder.visit(false_expr, head_state=false_state) + self.false_br_builder = dataflow_builder.visit( + false_expr, sdfg=sdfg, head_state=false_state + ) def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # retrieve true/false states as predecessors of head state diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py index 3562f1070f..75b7f2f106 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py @@ -17,13 +17,9 @@ import dace -from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GTIRDataflowBuilder, -) from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts @@ -36,15 +32,14 @@ class GTIRBuiltinSymbolRef(GTIRBuiltinTranslator): def __init__( self, - dataflow_builder: GTIRDataflowBuilder, + sdfg: dace.SDFG, state: dace.SDFGState, - node: itir.SymRef, + sym_name: str, + sym_type: ts.FieldType | ts.ScalarType, ): - super().__init__(state, dataflow_builder.sdfg) - sym_name = str(node.id) - assert sym_name in dataflow_builder.data_types + super().__init__(sdfg, state) self.sym_name = sym_name - self.sym_type = dataflow_builder.data_types[sym_name] + self.sym_type = sym_type def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: """Returns, if present, the access node in current state for the data symbol.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 90a9a5dff0..baddcf386d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -18,16 +18,74 @@ from typing import final import dace +import numpy as np from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name from gt4py.next.type_system import type_specifications as ts +_MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + @dataclass(frozen=True) class GTIRBuiltinTranslator(eve.NodeVisitor): - head_state: dace.SDFGState sdfg: dace.SDFG + head_state: dace.SDFGState @final def __call__( @@ -70,3 +128,34 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: The GT4Py data type is useful in the case of fields, because it provides information on the field domain (e.g. order of dimensions, types of dimensions). """ + + def _visit_deref(self, node: itir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = [self.visit(arg) for arg in node.args] + return fmt.format(*args) + + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, itir.SymRef): + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: + return self._visit_numeric_builtin(node) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + + @final + def visit_Literal(self, node: itir.Literal) -> str: + return node.value + + @final + def visit_SymRef(self, node: itir.SymRef) -> str: + return str(node.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py deleted file mode 100644 index 07805dccc2..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow_builder.py +++ /dev/null @@ -1,126 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -from dataclasses import dataclass -from typing import Any, Callable, final - -import dace - -import gt4py.eve as eve -from gt4py.next.common import Dimension -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import ( - GTIRTaskletCodegen, -) -from gt4py.next.type_system import type_specifications as ts - - -@dataclass(frozen=True) -class GTIRDataflowBuilder(eve.NodeVisitor): - """Translates a GTIR `ir.Stmt` node to a dataflow graph.""" - - sdfg: dace.SDFG - data_types: dict[str, ts.FieldType | ts.ScalarType] - - def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: - """ - Specialized visit method for domain expressions. - - Returns a list of dimensions and the corresponding range. - """ - domain = [] - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, itir.AxisLiteral) - dim = Dimension(axis.value) - bounds = [self.visit_symbolic(arg) for arg in named_range.args[1:3]] - domain.append((dim, bounds[0], bounds[1])) - - return domain - - def visit_expression( - self, node: itir.Expr, head_state: dace.SDFGState - ) -> list[dace.nodes.AccessNode]: - """ - Specialized visit method for fieldview expressions. - - This method represents the entry point to visit 'Stmt' expressions. - As such, it must preserve the property of single exit state in the SDFG. - - TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? - """ - expr_builder = self.visit(node, head_state=head_state) - assert callable(expr_builder) - results = expr_builder() - - expressions_nodes = [] - for node, _ in results: - assert isinstance(node, dace.nodes.AccessNode) - expressions_nodes.append(node) - - # sanity check: each statement should preserve the property of single exit state (aka head state), - # i.e. eventually only introduce internal branches, and keep the same head state - sink_states = self.sdfg.sink_nodes() - assert len(sink_states) == 1 - assert sink_states[0] == head_state - - return expressions_nodes - - def visit_symbolic(self, node: itir.Expr) -> str: - """ - Specialized visit method for pure stencil expressions. - - Returns a string represnting the Python code to be used as tasklet body. - TODO: should we return a list of code strings in case of tuple returns, - one for each output value? - """ - return GTIRTaskletCodegen().visit(node) - - def visit_FunCall(self, node: itir.FunCall, head_state: dace.SDFGState) -> Callable: - from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins - - arg_builders: list[Callable] = [] - for arg in node.args: - arg_builder = self.visit(arg, head_state=head_state) - assert callable(arg_builder) - arg_builders.append(arg_builder) - - if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtins.AsFieldOp(self, head_state, node, arg_builders) - - elif cpm.is_call_to(node.fun, "select"): - assert len(arg_builders) == 0 - return gtir_builtins.Select(self, head_state, node) - - else: - raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - - @final - def visit_Lambda(self, node: itir.Lambda) -> Any: - """ - This visitor class should never encounter `itir.Lambda` expressions - because a lambda represents a stencil, which translates from iterator to values. - In fieldview, lambdas should only be arguments to field operators (`as_field_op`). - """ - raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") - - def visit_SymRef(self, node: itir.SymRef, head_state: dace.SDFGState) -> Callable: - from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins - - return gtir_builtins.SymbolRef(self, head_state, node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index fdfd58d0d2..5a8b40f27f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -17,16 +17,15 @@ Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Mapping, Sequence +from typing import Any, Callable, Mapping, Sequence, final import dace -from gt4py import eve +import gt4py.eve as eve from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.runners.dace_fieldview.gtir_dataflow_builder import ( - GTIRDataflowBuilder as DataflowBuilder, -) +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, filter_connectivities, @@ -119,6 +118,34 @@ def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, """ raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") + def _visit_expression( + self, node: itir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState + ) -> list[dace.nodes.AccessNode]: + """ + Specialized visit method for fieldview expressions. + + This method represents the entry point to visit 'Stmt' expressions. + As such, it must preserve the property of single exit state in the SDFG. + + TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? + """ + expr_builder = self.visit(node, sdfg=sdfg, head_state=head_state) + assert callable(expr_builder) + results = expr_builder() + + expressions_nodes = [] + for node, _ in results: + assert isinstance(node, dace.nodes.AccessNode) + expressions_nodes.append(node) + + # sanity check: each statement should preserve the property of single exit state (aka head state), + # i.e. eventually only introduce internal branches, and keep the same head state + sink_states = sdfg.sink_nodes() + assert len(sink_states) == 1 + assert sink_states[0] == head_state + + return expressions_nodes + def visit_Program(self, node: itir.Program) -> dace.SDFG: """Translates `ir.Program` to `dace.SDFG`. @@ -166,14 +193,13 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) The translation of `SetAt` ensures that the result is written to some global storage. """ - dataflow_builder = DataflowBuilder(sdfg, self._data_types) - expr_nodes = dataflow_builder.visit_expression(stmt.expr, state) + expr_nodes = self._visit_expression(stmt.expr, sdfg, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field - target_nodes = dataflow_builder.visit_expression(stmt.target, state) + target_nodes = self._visit_expression(stmt.target, sdfg, state) - domain = dataflow_builder.visit_domain(stmt.domain) + domain = gtir_builtins.FieldDomain(sdfg, state).visit_domain(stmt.domain) # convert domain to dictionary to ease access to dimension boundaries domain_map = {dim: (lb, ub) for dim, lb, ub in domain} @@ -195,3 +221,39 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) target_node, dace.Memlet(data=target_node.data, subset=subset), ) + + def visit_FunCall( + self, node: itir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState + ) -> Callable: + arg_builders: list[Callable] = [] + for arg in node.args: + arg_builder = self.visit(arg, sdfg=sdfg, head_state=head_state) + assert callable(arg_builder) + arg_builders.append(arg_builder) + + if cpm.is_call_to(node.fun, "as_fieldop"): + return gtir_builtins.AsFieldOp(sdfg, head_state, node, arg_builders) + + elif cpm.is_call_to(node.fun, "select"): + assert len(arg_builders) == 0 + return gtir_builtins.Select(sdfg, head_state, self, node) + + else: + raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") + + @final + def visit_Lambda(self, node: itir.Lambda) -> Any: + """ + This visitor class should never encounter `itir.Lambda` expressions + because a lambda represents a stencil, which translates from iterator to values. + In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + """ + raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") + + def visit_SymRef( + self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState + ) -> Callable: + sym_name = str(node.id) + assert sym_name in self._data_types + sym_type = self._data_types[sym_name] + return gtir_builtins.SymbolRef(sdfg, head_state, sym_name, sym_type) From 412cd5da101373cb777c19bfe6b9008d21075438 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 6 May 2024 09:45:27 +0200 Subject: [PATCH 032/235] Fix formatting --- .../dace_fieldview/gtir_builtins/gtir_builtin_domain.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py index 063137a8ca..1a745dff53 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py @@ -14,23 +14,20 @@ from dataclasses import dataclass -from typing import final import dace -import numpy as np -from gt4py import eve from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import GTIRBuiltinTranslator -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( + GTIRBuiltinTranslator, +) from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) class GTIRBuiltinDomain(GTIRBuiltinTranslator): - def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: raise NotImplementedError From 651de5c333eaf0da3888d26891c3eab7632e5e12 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 6 May 2024 09:52:59 +0200 Subject: [PATCH 033/235] Minor edit --- .../gtir_builtins/gtir_builtin_translator.py | 16 +++++-------- .../runners/dace_fieldview/gtir_to_sdfg.py | 24 +++++++++---------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index baddcf386d..444670d4ee 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -20,7 +20,8 @@ import dace import numpy as np -from gt4py import eve +from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name @@ -83,10 +84,13 @@ @dataclass(frozen=True) -class GTIRBuiltinTranslator(eve.NodeVisitor): +class GTIRBuiltinTranslator(codegen.TemplatedGenerator): sdfg: dace.SDFG head_state: dace.SDFGState + Literal = as_fmt("{value}") + SymRef = as_fmt("{id}") + @final def __call__( self, @@ -151,11 +155,3 @@ def visit_FunCall(self, node: itir.FunCall) -> str: else: raise NotImplementedError(f"'{builtin_name}' not implemented.") raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - - @final - def visit_Literal(self, node: itir.Literal) -> str: - return node.value - - @final - def visit_SymRef(self, node: itir.SymRef) -> str: - return str(node.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 5a8b40f27f..1f18c5c3fd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -46,18 +46,18 @@ class GTIRToSDFG(eve.NodeVisitor): from where to continue building the SDFG. """ - _param_types: list[ts.DataType] - _data_types: dict[str, ts.FieldType | ts.ScalarType] - _offset_providers: Mapping[str, Any] + data_types: dict[str, ts.FieldType | ts.ScalarType] + param_types: list[ts.DataType] + offset_providers: Mapping[str, Any] def __init__( self, param_types: list[ts.DataType], offset_providers: dict[str, Connectivity | Dimension], ): - self._param_types = param_types - self._data_types = {} - self._offset_providers = offset_providers + self.data_types = {} + self.param_types = param_types + self.offset_providers = offset_providers def _make_array_shape_and_strides( self, name: str, dims: Sequence[Dimension] @@ -74,7 +74,7 @@ def _make_array_shape_and_strides( The output tuple fields are arrays of dace symbolic expressions. """ dtype = dace.int32 - neighbor_tables = filter_connectivities(self._offset_providers) + neighbor_tables = filter_connectivities(self.offset_providers) shape = [ ( neighbor_tables[dim.value].max_neighbors @@ -108,7 +108,7 @@ def _add_storage(self, sdfg: dace.SDFG, name: str, data_type: ts.DataType) -> No # TODO: unclear why mypy complains about incompatible types assert isinstance(data_type, (ts.FieldType, ts.ScalarType)) - self._data_types[name] = data_type + self.data_types[name] = data_type def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, str]: """ @@ -175,7 +175,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: head_state = entry_state # add non-transient arrays and/or SDFG symbols for the program arguments - for param, type_ in zip(node.params, self._param_types, strict=True): + for param, type_ in zip(node.params, self.param_types, strict=True): self._add_storage(sdfg, str(param.id), type_) # visit one statement at a time and expand the SDFG from the current head state @@ -206,7 +206,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): target_array = sdfg.arrays[target_node.data] assert not target_array.transient - target_field_type = self._data_types[target_node.data] + target_field_type = self.data_types[target_node.data] if isinstance(target_field_type, ts.FieldType): subset = ",".join( @@ -254,6 +254,6 @@ def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState ) -> Callable: sym_name = str(node.id) - assert sym_name in self._data_types - sym_type = self._data_types[sym_name] + assert sym_name in self.data_types + sym_type = self.data_types[sym_name] return gtir_builtins.SymbolRef(sdfg, head_state, sym_name, sym_type) From dcf3eab0cc4b0271a3bb2866c0b893c37a62361c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 6 May 2024 15:17:06 +0200 Subject: [PATCH 034/235] Add support to translate each builtin call to a tasklet node --- .../gtir_builtins/gtir_builtin_domain.py | 15 ++- .../gtir_builtin_field_operator.py | 57 ++++----- .../gtir_builtins/gtir_builtin_select.py | 6 +- .../gtir_builtins/gtir_builtin_symbol_ref.py | 5 +- .../gtir_builtins/gtir_builtin_translator.py | 103 +++++++++++++--- .../dace_fieldview/gtir_tasklet_codegen.py | 113 ------------------ .../runners/dace_fieldview/gtir_to_sdfg.py | 22 +--- .../runners/dace_fieldview/utility.py | 3 +- .../runners_tests/test_dace_fieldview.py | 12 +- 9 files changed, 148 insertions(+), 188 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py index 1a745dff53..4e3f39594b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py @@ -13,8 +13,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass - import dace from gt4py.next.common import Dimension @@ -22,11 +20,12 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, + LiteralExpr, + SymbolExpr, ) from gt4py.next.type_system import type_specifications as ts -@dataclass(frozen=True) class GTIRBuiltinDomain(GTIRBuiltinTranslator): def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: raise NotImplementedError @@ -46,7 +45,15 @@ def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: axis = named_range.args[0] assert isinstance(axis, itir.AxisLiteral) dim = Dimension(axis.value) - bounds = [self.visit(arg) for arg in named_range.args[1:3]] + bounds = [] + for arg in named_range.args[1:3]: + bound = self.visit(arg) + if isinstance(bound, SymbolExpr): + assert bound.data in self.sdfg.symbols + bounds.append(bound.data) + else: + assert isinstance(bound, LiteralExpr) + bounds.append(bound.value) domain.append((dim, bounds[0], bounds[1])) return domain diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 01280562e5..d285724c1f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -25,7 +25,9 @@ ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, + ValueExpr, ) +from gt4py.next.program_processors.runners.dace_fieldview.utility import unique_name from gt4py.next.type_system import type_specifications as ts @@ -67,12 +69,12 @@ def __init__( self.stencil_args = stencil_args def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + assert len(self.input_connections) == 0 + # generate a tasklet node implementing the stencil function and represent # the field operator as a mapped tasklet, which will range over the field domain - output_connector = "__out" - tlet_code = "{var} = {code}".format( - var=output_connector, code=self.visit(self.stencil_expr.expr) - ) + output_expr = self.visit(self.stencil_expr.expr) + assert isinstance(output_expr, ValueExpr) # allocate local temporary storage for the result field field_shape = [ @@ -82,51 +84,50 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ] field_node = self.add_local_storage(self.field_type, field_shape) - # create map range corresponding to the field operator domain - map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} - - input_nodes: dict[str, dace.nodes.AccessNode] = {} + data_nodes: dict[str, dace.nodes.AccessNode] = {} input_memlets: dict[str, dace.Memlet] = {} for arg, param in zip(self.stencil_args, self.stencil_expr.params, strict=True): arg_nodes = arg() assert len(arg_nodes) == 1 arg_node, arg_type = arg_nodes[0] - connector = str(param.id) + data = str(param.id) # require (for now) all input nodes to be data access nodes assert isinstance(arg_node, dace.nodes.AccessNode) - input_nodes[arg_node.data] = arg_node + data_nodes[data] = arg_node if isinstance(arg_type, ts.FieldType): # support either single element access (general case) or full array shape is_scalar = all(dim in self.field_domain for dim in arg_type.dims) if is_scalar: subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) - input_memlets[connector] = dace.Memlet( - data=arg_node.data, subset=subset, volume=1 - ) + input_memlets[data] = dace.Memlet(data=arg_node.data, subset=subset, volume=1) else: memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) # set volume to 1 because the stencil function always performs single element access # TODO: check validity of this assumption memlet.volume = 1 - input_memlets[connector] = memlet + input_memlets[data] = memlet else: - input_memlets[connector] = dace.Memlet(data=arg_node.data, subset="0") + input_memlets[data] = dace.Memlet(data=arg_node.data, subset="0") # assume tasklet with single output output_index = ",".join(f"i_{dim.value}" for dim in self.field_type.dims) - output_memlets = {output_connector: dace.Memlet(data=field_node.data, subset=output_index)} - output_nodes = {field_node.data: field_node} - - # create a tasklet inside a parallel-map scope - self.head_state.add_mapped_tasklet( - "tasklet", - map_ranges, - input_memlets, - tlet_code, - output_memlets, - input_nodes=input_nodes, - output_nodes=output_nodes, - external_edges=True, + output_memlet = dace.Memlet(data=field_node.data, subset=output_index) + + # create map range corresponding to the field operator domain + map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} + me, mx = self.head_state.add_map(unique_name("map"), map_ranges) + + for (input_node, input_connector), input_param in self.input_connections: + assert input_param in data_nodes + self.head_state.add_memlet_path( + data_nodes[input_param], + me, + input_node, + dst_conn=input_connector, + memlet=input_memlets[input_param], + ) + self.head_state.add_memlet_path( + output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet ) return [(field_node, self.field_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py index 058ce568fd..1c8a31ad36 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py @@ -22,6 +22,7 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, + SymbolExpr, ) from gt4py.next.type_system import type_specifications as ts @@ -47,6 +48,7 @@ def __init__( # expect condition as first argument cond = self.visit(cond_expr) + assert isinstance(cond, SymbolExpr) # use current head state to terminate the dataflow, and add a entry state # to connect the true/false branch states as follows: @@ -68,13 +70,13 @@ def __init__( # expect true branch as second argument true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) + sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond.data)) sdfg.add_edge(true_state, state, dace.InterstateEdge()) self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=f"not {cond}")) + sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=f"not {cond.data}")) sdfg.add_edge(false_state, state, dace.InterstateEdge()) self.false_br_builder = dataflow_builder.visit( false_expr, sdfg=sdfg, head_state=false_state diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py index 75b7f2f106..35ca173369 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py @@ -20,7 +20,6 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type from gt4py.next.type_system import type_specifications as ts @@ -73,9 +72,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: {"__out"}, f"__out = {self.sym_name}", ) - name = f"{self.head_state.label}_var" - dtype = as_dace_type(self.sym_type) - sym_node = self.head_state.add_scalar(name, dtype, find_new_name=True, transient=True) + sym_node = self.add_local_storage(self.sym_type, shape=[]) self.head_state.add_edge( tasklet_node, "__out", sym_node, None, dace.Memlet(data=sym_node.data, subset="0") ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 444670d4ee..81b9e4f03e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -15,13 +15,12 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import final +from typing import Any, TypeAlias, final import dace import numpy as np -from gt4py.eve import codegen -from gt4py.eve.codegen import FormatTemplate as as_fmt +from gt4py import eve from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name @@ -84,12 +83,32 @@ @dataclass(frozen=True) -class GTIRBuiltinTranslator(codegen.TemplatedGenerator): +class LiteralExpr: + value: dace.symbolic.SymbolicType + + +@dataclass(frozen=True) +class SymbolExpr: + data: str + + +@dataclass(frozen=True) +class ValueExpr: + node: dace.nodes.Tasklet + connector: str + + +class GTIRBuiltinTranslator(eve.NodeVisitor): + TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] + sdfg: dace.SDFG head_state: dace.SDFGState + input_connections: list[TaskletConnector] - Literal = as_fmt("{value}") - SymRef = as_fmt("{id}") + def __init__(self, sdfg: dace.SDFG, head_state: dace.SDFGState): + self.sdfg = sdfg + self.head_state = head_state + self.input_connections = [] @final def __call__( @@ -133,25 +152,79 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: information on the field domain (e.g. order of dimensions, types of dimensions). """ - def _visit_deref(self, node: itir.FunCall) -> str: + def _visit_deref(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: assert len(node.args) == 1 if isinstance(node.args[0], itir.SymRef): return self.visit(node.args[0]) raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = [self.visit(arg) for arg in node.args] - return fmt.format(*args) - - def visit_FunCall(self, node: itir.FunCall) -> str: + @final + def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif isinstance(node.fun, itir.SymRef): + # create a tasklet node implementing the builtin function + inputs = {} + inp_data = set() + inp_nodes = set() + node_internals = [] + for i, arg in enumerate(node.args): + arg_expr = self.visit(arg) + if isinstance(arg_expr, LiteralExpr): + node_internals.append(arg_expr.value) + else: + connector = f"__inp_{i}" + if isinstance(arg_expr, ValueExpr): + inputs[connector] = arg_expr + else: + assert isinstance(arg_expr, SymbolExpr) + inp_data.add((connector, arg_expr.data)) + inp_nodes.add(connector) + node_internals.append(connector) + builtin_name = str(node.fun.id) if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) + fmt = _MATH_BUILTINS_MAPPING[builtin_name] + code = fmt.format(*node_internals) else: raise NotImplementedError(f"'{builtin_name}' not implemented.") + + out_connector = "__out" + tasklet_node = self.head_state.add_tasklet( + unique_name("tasklet"), + inp_nodes, + {out_connector}, + "{} = {}".format(out_connector, code), + ) + for input_conn, inp_expr in inputs.items(): + self.head_state.add_edge( + inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() + ) + self.input_connections.extend( + ((tasklet_node, connector), data) for connector, data in inp_data + ) + return ValueExpr(tasklet_node, out_connector) + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + + @final + def visit_Lambda(self, node: itir.Lambda) -> Any: + """ + This visitor class should never encounter `itir.Lambda` expressions + because a lambda represents a stencil, which operates from iterator to values. + In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + """ + raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GTIRBuiltinTranslator'.") + + @final + def visit_Literal(self, node: itir.Literal) -> LiteralExpr: + return LiteralExpr(node.value) + + @final + def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: + """ + Symbol references are mapped to tasklet connectors that access some kind of data. + """ + sym_name = str(node.id) + return SymbolExpr(sym_name) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py deleted file mode 100644 index 7473ffdb20..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_tasklet_codegen.py +++ /dev/null @@ -1,113 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from dataclasses import dataclass - -import numpy as np - -from gt4py.eve import codegen -from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -@dataclass(frozen=True) -class GTIRTaskletCodegen(codegen.TemplatedGenerator): - """ - Stateless class to visit pure tasklet expressions. - - This visitor class is responsible for building the string representing - the Python code inside a tasklet node. - """ - - Literal = as_fmt("{value}") - SymRef = as_fmt("{id}") - - def _visit_deref(self, node: itir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - return fmt.format(*args) - - def visit_FunCall(self, node: itir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - elif isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 1f18c5c3fd..90f10ff8ac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -12,12 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later """ -Class to lower GTIR to a DaCe SDFG. +Class to lower GTIR to DaCe SDFG. Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Callable, Mapping, Sequence, final +from typing import Any, Callable, Mapping, Sequence import dace @@ -68,10 +68,8 @@ def _make_array_shape_and_strides( For local dimensions, the size is known at compile-time and therefore the corresponding array shape dimension is set to an integer literal value. - Returns - ------- - tuple(shape, strides) - The output tuple fields are arrays of dace symbolic expressions. + Returns: + Two list of symbols, one for the shape and another for the strides of the array. """ dtype = dace.int32 neighbor_tables = filter_connectivities(self.offset_providers) @@ -190,7 +188,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. - The translation of `SetAt` ensures that the result is written to some global storage. + The translation of `SetAt` ensures that the result is written back to some global storage. """ expr_nodes = self._visit_expression(stmt.expr, sdfg, state) @@ -225,6 +223,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) def visit_FunCall( self, node: itir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState ) -> Callable: + # first visit the argument nodes arg_builders: list[Callable] = [] for arg in node.args: arg_builder = self.visit(arg, sdfg=sdfg, head_state=head_state) @@ -241,15 +240,6 @@ def visit_FunCall( else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - @final - def visit_Lambda(self, node: itir.Lambda) -> Any: - """ - This visitor class should never encounter `itir.Lambda` expressions - because a lambda represents a stencil, which translates from iterator to values. - In fieldview, lambdas should only be arguments to field operators (`as_field_op`). - """ - raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GtirTaskletCodegen'.") - def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState ) -> Callable: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index ac70241098..5671539662 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -20,7 +20,7 @@ from gt4py.next.type_system import type_specifications as ts -def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: +def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" match type_.kind: case ts.ScalarKind.BOOL: @@ -53,6 +53,7 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne def unique_name(prefix: str) -> str: """Generate a string containing a unique integer id, which is updated incrementally.""" + unique_id = getattr(unique_name, "_unique_id", 0) # static variable setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 7113ed054c..b9b2db9c54 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -218,6 +218,7 @@ def test_gtir_select(): itir.Sym(id="w"), itir.Sym(id="z"), itir.Sym(id="cond"), + itir.Sym(id="scalar"), itir.Sym(id="size"), ], declarations=[], @@ -235,16 +236,16 @@ def test_gtir_select(): im.deref("cond"), im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 1)), + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), domain, ) - )("y"), + )("y", "scalar"), im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 1)), + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), domain, ) - )("w"), + )("w", "scalar"), ) )(), ), @@ -266,6 +267,7 @@ def test_gtir_select(): FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.FLOAT64), ts.ScalarType(ts.ScalarKind.INT32), ], OFFSET_PROVIDERS, @@ -275,7 +277,7 @@ def test_gtir_select(): assert isinstance(sdfg, dace.SDFG) for s in [False, True]: - sdfg(cond=s, x=a, y=b, w=c, z=d, **FSYMBOLS) + sdfg(cond=s, scalar=1, x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) From 7e6909e061f3b3699b82843fd371c6e7bc81abfe Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 08:54:34 +0200 Subject: [PATCH 035/235] Resolve dace warnings --- .../gtir_builtins/gtir_builtin_domain.py | 13 +++++++------ .../gtir_builtins/gtir_builtin_translator.py | 5 ++++- .../runners_tests/test_dace_fieldview.py | 6 +++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py index 4e3f39594b..d7145a4461 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py @@ -47,13 +47,14 @@ def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: dim = Dimension(axis.value) bounds = [] for arg in named_range.args[1:3]: - bound = self.visit(arg) - if isinstance(bound, SymbolExpr): - assert bound.data in self.sdfg.symbols - bounds.append(bound.data) + if isinstance(arg, itir.Literal): + val = arg.value else: - assert isinstance(bound, LiteralExpr) - bounds.append(bound.value) + arg_expr = self.visit(arg) + assert isinstance(arg_expr, SymbolExpr) + val = arg_expr.data + assert val in self.sdfg.symbols + bounds.append(val) domain.append((dim, bounds[0], bounds[1])) return domain diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 81b9e4f03e..79e01bffe3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -219,7 +219,10 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: @final def visit_Literal(self, node: itir.Literal) -> LiteralExpr: - return LiteralExpr(node.value) + cast_sym = str(as_dace_type(node.type)) + cast_fmt = _MATH_BUILTINS_MAPPING[cast_sym] + typed_value = cast_fmt.format(node.value) + return LiteralExpr(typed_value) @final def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index b9b2db9c54..a044bdc4db 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -277,7 +277,7 @@ def test_gtir_select(): assert isinstance(sdfg, dace.SDFG) for s in [False, True]: - sdfg(cond=s, scalar=1, x=a, y=b, w=c, z=d, **FSYMBOLS) + sdfg(cond=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) @@ -351,5 +351,5 @@ def test_gtir_select_nested(): for s1 in [False, True]: for s2 in [False, True]: - sdfg(cond_1=s1, cond_2=s2, x=a, z=b, **FSYMBOLS) - assert np.allclose(b, (a + 1) if s1 else (a + 2) if s2 else (a + 3)) + sdfg(cond_1=np.bool_(s1), cond_2=np.bool_(s2), x=a, z=b, **FSYMBOLS) + assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0)) From 2b07cc52c373eb9d3098e8e8b9dc6b65af48149b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 09:09:15 +0200 Subject: [PATCH 036/235] Remove bultin translator for domain expressions --- .../dace_fieldview/gtir_builtins/__init__.py | 4 -- .../gtir_builtins/gtir_builtin_domain.py | 60 ------------------- .../gtir_builtin_field_operator.py | 12 ++-- .../runners/dace_fieldview/gtir_to_sdfg.py | 8 +-- .../runners/dace_fieldview/utility.py | 31 +++++++++- 5 files changed, 39 insertions(+), 76 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py index 8376d7357e..c99b418eae 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py @@ -12,9 +12,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_domain import ( - GTIRBuiltinDomain as FieldDomain, -) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import ( GTIRBuiltinAsFieldOp as AsFieldOp, ) @@ -29,7 +26,6 @@ # export short names of translation classes for GTIR builtin functions __all__ = [ "AsFieldOp", - "FieldDomain", "Select", "SymbolRef", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py deleted file mode 100644 index d7145a4461..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_domain.py +++ /dev/null @@ -1,60 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import dace - -from gt4py.next.common import Dimension -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GTIRBuiltinTranslator, - LiteralExpr, - SymbolExpr, -) -from gt4py.next.type_system import type_specifications as ts - - -class GTIRBuiltinDomain(GTIRBuiltinTranslator): - def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - raise NotImplementedError - - def visit_domain(self, node: itir.Expr) -> list[tuple[Dimension, str, str]]: - """ - Specialized visit method for domain expressions. - - Returns a list of dimensions and the corresponding range. - """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - - domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, itir.AxisLiteral) - dim = Dimension(axis.value) - bounds = [] - for arg in named_range.args[1:3]: - if isinstance(arg, itir.Literal): - val = arg.value - else: - arg_expr = self.visit(arg) - assert isinstance(arg_expr, SymbolExpr) - val = arg_expr.data - assert val in self.sdfg.symbols - bounds.append(val) - domain.append((dim, bounds[0], bounds[1])) - - return domain diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index d285724c1f..3d550c9fad 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -20,14 +20,11 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_domain import ( - GTIRBuiltinDomain, -) from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ValueExpr, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import unique_name +from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name from gt4py.next.type_system import type_specifications as ts @@ -57,14 +54,15 @@ def __init__( assert isinstance(domain_expr, itir.FunCall) # visit field domain - domain = GTIRBuiltinDomain(sdfg, state).visit_domain(domain_expr) + domain = get_domain(domain_expr) + sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - self.field_domain = {dim: (lb, ub) for dim, lb, ub in domain} - self.field_type = ts.FieldType([dim for dim, _, _ in domain], node_type) + self.field_domain = domain + self.field_type = ts.FieldType(sorted_domain_dims, node_type) self.stencil_expr = stencil_expr self.stencil_args = stencil_args diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 90f10ff8ac..92bf966293 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -29,6 +29,7 @@ from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, filter_connectivities, + get_domain, ) from gt4py.next.type_system import type_specifications as ts @@ -197,9 +198,8 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) # in case the statement returns more than one field target_nodes = self._visit_expression(stmt.target, sdfg, state) - domain = gtir_builtins.FieldDomain(sdfg, state).visit_domain(stmt.domain) - # convert domain to dictionary to ease access to dimension boundaries - domain_map = {dim: (lb, ub) for dim, lb, ub in domain} + # convert domain expression to dictionary to ease access to dimension boundaries + domain = get_domain(stmt.domain) for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): target_array = sdfg.arrays[target_node.data] @@ -208,7 +208,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) if isinstance(target_field_type, ts.FieldType): subset = ",".join( - f"{domain_map[dim][0]}:{domain_map[dim][1]}" for dim in target_field_type.dims + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_field_type.dims ) else: assert len(domain) == 0 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 5671539662..4e13d7d7ca 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,7 +16,9 @@ import dace -from gt4py.next.common import Connectivity +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts @@ -51,6 +53,33 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne } +def get_domain(node: itir.Expr) -> dict[Dimension, tuple[str, str]]: + """ + Specialized visit method for domain expressions. + + Returns a list of dimensions and the corresponding range. + """ + assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) + + domain = {} + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, itir.AxisLiteral) + dim = Dimension(axis.value) + bounds = [] + for arg in named_range.args[1:3]: + if isinstance(arg, itir.Literal): + val = arg.value + else: + val = str(arg) + bounds.append(val) + domain[dim] = (bounds[0], bounds[1]) + + return domain + + def unique_name(prefix: str) -> str: """Generate a string containing a unique integer id, which is updated incrementally.""" From 2370fa6992ddc64647da973c81f801da4bca09a5 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 09:56:11 +0200 Subject: [PATCH 037/235] Remove bultin translator for domain expressions (1) --- .../gtir_builtins/gtir_builtin_field_operator.py | 6 +++--- .../gtir_builtins/gtir_builtin_translator.py | 6 ++++++ .../runners/dace_fieldview/utility.py | 12 ++++++------ 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 3d550c9fad..128c197ac3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -33,7 +33,7 @@ class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): stencil_expr: itir.Lambda stencil_args: list[Callable] - field_domain: dict[Dimension, tuple[str, str]] + field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] field_type: ts.FieldType def __init__( @@ -53,8 +53,8 @@ def __init__( # the domain of the field operator is passed as second argument assert isinstance(domain_expr, itir.FunCall) - # visit field domain domain = get_domain(domain_expr) + # define field domain with all dimensions in alphabetical order sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) # add local storage to compute the field operator over the given domain @@ -77,7 +77,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # allocate local temporary storage for the result field field_shape = [ # diff between upper and lower bound - f"{self.field_domain[dim][1]} - {self.field_domain[dim][0]}" + self.field_domain[dim][1] - self.field_domain[dim][0] for dim in self.field_type.dims ] field_node = self.add_local_storage(self.field_type, field_shape) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 79e01bffe3..069289b19a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -84,16 +84,22 @@ @dataclass(frozen=True) class LiteralExpr: + """Any symbolic expression that can be evaluated at compile time.""" + value: dace.symbolic.SymbolicType @dataclass(frozen=True) class SymbolExpr: + """The data access to a scalar or field through a symbolic reference.""" + data: str @dataclass(frozen=True) class ValueExpr: + """The result of a computation provided by a tasklet node.""" + node: dace.nodes.Tasklet connector: str diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 4e13d7d7ca..bf74d582a7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -53,7 +53,9 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne } -def get_domain(node: itir.Expr) -> dict[Dimension, tuple[str, str]]: +def get_domain( + node: itir.Expr, +) -> dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Specialized visit method for domain expressions. @@ -70,11 +72,9 @@ def get_domain(node: itir.Expr) -> dict[Dimension, tuple[str, str]]: dim = Dimension(axis.value) bounds = [] for arg in named_range.args[1:3]: - if isinstance(arg, itir.Literal): - val = arg.value - else: - val = str(arg) - bounds.append(val) + str_val = str(arg) + sym_val = dace.symbolic.SymExpr(str_val) + bounds.append(sym_val) domain[dim] = (bounds[0], bounds[1]) return domain From 8e801df26b51ead9472fa63c4453cd0eee480104 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 15:22:38 +0200 Subject: [PATCH 038/235] Refactor --- .../gtir_builtin_field_operator.py | 177 +++++++++++++++++- .../gtir_builtins/gtir_builtin_select.py | 9 +- .../gtir_builtins/gtir_builtin_translator.py | 172 +---------------- .../runners/dace_fieldview/gtir_to_sdfg.py | 20 +- .../runners/dace_fieldview/utility.py | 20 ++ 5 files changed, 213 insertions(+), 185 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 128c197ac3..7b9701ee58 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -13,28 +13,115 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable +from dataclasses import dataclass +from typing import Callable, TypeAlias import dace +import numpy as np -from gt4py.next.common import Dimension +from gt4py import eve +from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, - ValueExpr, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name +from gt4py.next.program_processors.runners.dace_fieldview.utility import ( + as_dace_type, + get_domain, + unique_name, +) from gt4py.next.type_system import type_specifications as ts -class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): +_MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + +@dataclass(frozen=True) +class LiteralExpr: + """Any symbolic expression that can be evaluated at compile time.""" + + value: dace.symbolic.SymbolicType + + +@dataclass(frozen=True) +class SymbolExpr: + """The data access to a scalar or field through a symbolic reference.""" + + data: str + + +@dataclass(frozen=True) +class ValueExpr: + """The result of a computation provided by a tasklet node.""" + + node: dace.nodes.Tasklet + connector: str + + +class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator, eve.NodeVisitor): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] + stencil_expr: itir.Lambda stencil_args: list[Callable] field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] field_type: ts.FieldType + input_connections: list[TaskletConnector] + offset_provider: dict[str, Connectivity | Dimension] def __init__( self, @@ -42,8 +129,11 @@ def __init__( state: dace.SDFGState, node: itir.FunCall, stencil_args: list[Callable], + offset_provider: dict[str, Connectivity | Dimension], ): super().__init__(sdfg, state) + self.input_connections = [] + self.offset_provider = offset_provider assert cpm.is_call_to(node.fun, "as_fieldop") assert len(node.fun.args) == 2 @@ -129,3 +219,80 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ) return [(field_node, self.field_type)] + + def _visit_shift(self, node: itir.FunCall) -> str: + raise NotImplementedError + + def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: + inp_tasklets = {} + inp_symbols = set() + inp_connectors = set() + node_internals = [] + for i, arg in enumerate(node.args): + arg_expr = self.visit(arg) + if isinstance(arg_expr, LiteralExpr): + # use the value without adding any connector + node_internals.append(arg_expr.value) + else: + if isinstance(arg_expr, ValueExpr): + # the value is the result of a tasklet node + connector = f"__inp_{i}" + inp_tasklets[connector] = arg_expr + else: + # the value is the result of a tasklet node + assert isinstance(arg_expr, SymbolExpr) + connector = f"__inp_{arg_expr.data}" + inp_symbols.add((connector, arg_expr.data)) + inp_connectors.add(connector) + node_internals.append(connector) + + if cpm.is_call_to(node, "deref"): + assert len(inp_tasklets) == 0 + assert len(inp_symbols) == 1 + _, data = inp_symbols.pop() + return SymbolExpr(data) + + elif cpm.is_call_to(node.fun, "shift"): + code = self._visit_shift(node.fun) + + elif isinstance(node.fun, itir.SymRef): + # create a tasklet node implementing the builtin function + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: + fmt = _MATH_BUILTINS_MAPPING[builtin_name] + code = fmt.format(*node_internals) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + + out_connector = "__out" + tasklet_node = self.head_state.add_tasklet( + unique_name("tasklet"), + inp_connectors, + {out_connector}, + "{} = {}".format(out_connector, code), + ) + + else: + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + + for input_conn, inp_expr in inp_tasklets.items(): + self.head_state.add_edge( + inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() + ) + self.input_connections.extend( + ((tasklet_node, connector), data) for connector, data in inp_symbols + ) + return ValueExpr(tasklet_node, out_connector) + + def visit_Literal(self, node: itir.Literal) -> LiteralExpr: + cast_sym = str(as_dace_type(node.type)) + cast_fmt = _MATH_BUILTINS_MAPPING[cast_sym] + typed_value = cast_fmt.format(node.value) + return LiteralExpr(typed_value) + + def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: + """ + Symbol references are mapped to tasklet connectors that access some kind of data. + """ + sym_name = str(node.id) + return SymbolExpr(sym_name) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py index 1c8a31ad36..7f7f9bacbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py @@ -22,8 +22,8 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, - SymbolExpr, ) +from gt4py.next.program_processors.runners.dace_fieldview.utility import get_symbolic_expr from gt4py.next.type_system import type_specifications as ts @@ -47,8 +47,7 @@ def __init__( cond_expr, true_expr, false_expr = node.fun.args # expect condition as first argument - cond = self.visit(cond_expr) - assert isinstance(cond, SymbolExpr) + cond = get_symbolic_expr(cond_expr) # use current head state to terminate the dataflow, and add a entry state # to connect the true/false branch states as follows: @@ -70,13 +69,13 @@ def __init__( # expect true branch as second argument true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond.data)) + sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) sdfg.add_edge(true_state, state, dace.InterstateEdge()) self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=f"not {cond.data}")) + sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=(f"not {cond}"))) sdfg.add_edge(false_state, state, dace.InterstateEdge()) self.false_br_builder = dataflow_builder.visit( false_expr, sdfg=sdfg, head_state=false_state diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 069289b19a..4063fe80be 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -15,106 +15,18 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, TypeAlias, final +from typing import final import dace -import numpy as np -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name from gt4py.next.type_system import type_specifications as ts -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -@dataclass(frozen=True) -class LiteralExpr: - """Any symbolic expression that can be evaluated at compile time.""" - - value: dace.symbolic.SymbolicType - - -@dataclass(frozen=True) -class SymbolExpr: - """The data access to a scalar or field through a symbolic reference.""" - - data: str - - @dataclass(frozen=True) -class ValueExpr: - """The result of a computation provided by a tasklet node.""" - - node: dace.nodes.Tasklet - connector: str - - -class GTIRBuiltinTranslator(eve.NodeVisitor): - TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] - +class GTIRBuiltinTranslator: sdfg: dace.SDFG head_state: dace.SDFGState - input_connections: list[TaskletConnector] - - def __init__(self, sdfg: dace.SDFG, head_state: dace.SDFGState): - self.sdfg = sdfg - self.head_state = head_state - self.input_connections = [] @final def __call__( @@ -157,83 +69,3 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: The GT4Py data type is useful in the case of fields, because it provides information on the field domain (e.g. order of dimensions, types of dimensions). """ - - def _visit_deref(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - @final - def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - - elif isinstance(node.fun, itir.SymRef): - # create a tasklet node implementing the builtin function - inputs = {} - inp_data = set() - inp_nodes = set() - node_internals = [] - for i, arg in enumerate(node.args): - arg_expr = self.visit(arg) - if isinstance(arg_expr, LiteralExpr): - node_internals.append(arg_expr.value) - else: - connector = f"__inp_{i}" - if isinstance(arg_expr, ValueExpr): - inputs[connector] = arg_expr - else: - assert isinstance(arg_expr, SymbolExpr) - inp_data.add((connector, arg_expr.data)) - inp_nodes.add(connector) - node_internals.append(connector) - - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - fmt = _MATH_BUILTINS_MAPPING[builtin_name] - code = fmt.format(*node_internals) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - - out_connector = "__out" - tasklet_node = self.head_state.add_tasklet( - unique_name("tasklet"), - inp_nodes, - {out_connector}, - "{} = {}".format(out_connector, code), - ) - for input_conn, inp_expr in inputs.items(): - self.head_state.add_edge( - inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() - ) - self.input_connections.extend( - ((tasklet_node, connector), data) for connector, data in inp_data - ) - return ValueExpr(tasklet_node, out_connector) - - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - - @final - def visit_Lambda(self, node: itir.Lambda) -> Any: - """ - This visitor class should never encounter `itir.Lambda` expressions - because a lambda represents a stencil, which operates from iterator to values. - In fieldview, lambdas should only be arguments to field operators (`as_field_op`). - """ - raise RuntimeError("Unexpected 'itir.Lambda' node encountered by 'GTIRBuiltinTranslator'.") - - @final - def visit_Literal(self, node: itir.Literal) -> LiteralExpr: - cast_sym = str(as_dace_type(node.type)) - cast_fmt = _MATH_BUILTINS_MAPPING[cast_sym] - typed_value = cast_fmt.format(node.value) - return LiteralExpr(typed_value) - - @final - def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: - """ - Symbol references are mapped to tasklet connectors that access some kind of data. - """ - sym_name = str(node.id) - return SymbolExpr(sym_name) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 92bf966293..79625a15ee 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -49,16 +49,16 @@ class GTIRToSDFG(eve.NodeVisitor): data_types: dict[str, ts.FieldType | ts.ScalarType] param_types: list[ts.DataType] - offset_providers: Mapping[str, Any] + offset_provider: dict[str, Connectivity | Dimension] def __init__( self, param_types: list[ts.DataType], - offset_providers: dict[str, Connectivity | Dimension], + offset_provider: dict[str, Connectivity | Dimension], ): self.data_types = {} self.param_types = param_types - self.offset_providers = offset_providers + self.offset_provider = offset_provider def _make_array_shape_and_strides( self, name: str, dims: Sequence[Dimension] @@ -73,7 +73,7 @@ def _make_array_shape_and_strides( Two list of symbols, one for the shape and another for the strides of the array. """ dtype = dace.int32 - neighbor_tables = filter_connectivities(self.offset_providers) + neighbor_tables = filter_connectivities(self.offset_provider) shape = [ ( neighbor_tables[dim.value].max_neighbors @@ -231,7 +231,9 @@ def visit_FunCall( arg_builders.append(arg_builder) if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtins.AsFieldOp(sdfg, head_state, node, arg_builders) + return gtir_builtins.AsFieldOp( + sdfg, head_state, node, arg_builders, self.offset_provider + ) elif cpm.is_call_to(node.fun, "select"): assert len(arg_builders) == 0 @@ -240,6 +242,14 @@ def visit_FunCall( else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") + def visit_Lambda(self, node: itir.Lambda) -> Any: + """ + This visitor class should never encounter `itir.Lambda` expressions + because a lambda represents a stencil, which operates from iterator to values. + In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + """ + raise RuntimeError("Unexpected 'itir.Lambda' node encountered in GTIR.") + def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState ) -> Callable: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index bf74d582a7..06034f80e8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,12 +16,26 @@ import dace +from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts +class SymbolicTranslator(codegen.TemplatedGenerator): + SymRef = as_fmt("{id}") + Literal = as_fmt("{value}") + + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + assert len(node.args) == 1 + return self.visit(node.args[0]) + + raise RuntimeError(f"Unexpected 'FunCall' expression ({node}).") + + def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" match type_.kind: @@ -80,6 +94,12 @@ def get_domain( return domain +def get_symbolic_expr( + node: itir.Expr, +) -> str: + return SymbolicTranslator().visit(node) + + def unique_name(prefix: str) -> str: """Generate a string containing a unique integer id, which is updated incrementally.""" From 812a6e5cf91bc604887df9f4a350aece9a9f0e76 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 15:35:27 +0200 Subject: [PATCH 039/235] Minor edit --- .../gtir_builtin_field_operator.py | 63 ++-------------- .../runners/dace_fieldview/gtir_types.py | 71 +++++++++++++++++++ .../runners/dace_fieldview/utility.py | 33 ++++++--- 3 files changed, 99 insertions(+), 68 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 7b9701ee58..03519b19d2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -17,7 +17,6 @@ from typing import Callable, TypeAlias import dace -import numpy as np from gt4py import eve from gt4py.next.common import Connectivity, Dimension @@ -26,6 +25,7 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_types import MATH_BUILTINS_MAPPING from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, get_domain, @@ -34,61 +34,6 @@ from gt4py.next.type_system import type_specifications as ts -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - @dataclass(frozen=True) class LiteralExpr: """Any symbolic expression that can be evaluated at compile time.""" @@ -258,8 +203,8 @@ def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: elif isinstance(node.fun, itir.SymRef): # create a tasklet node implementing the builtin function builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - fmt = _MATH_BUILTINS_MAPPING[builtin_name] + if builtin_name in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin_name] code = fmt.format(*node_internals) else: raise NotImplementedError(f"'{builtin_name}' not implemented.") @@ -286,7 +231,7 @@ def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: def visit_Literal(self, node: itir.Literal) -> LiteralExpr: cast_sym = str(as_dace_type(node.type)) - cast_fmt = _MATH_BUILTINS_MAPPING[cast_sym] + cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] typed_value = cast_fmt.format(node.value) return LiteralExpr(typed_value) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py new file mode 100644 index 0000000000..9e190e1395 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py @@ -0,0 +1,71 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import numpy as np + + +MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 06034f80e8..b6f9cf07de 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -21,6 +21,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_types import MATH_BUILTINS_MAPPING from gt4py.next.type_system import type_specifications as ts @@ -28,12 +29,28 @@ class SymbolicTranslator(codegen.TemplatedGenerator): SymRef = as_fmt("{id}") Literal = as_fmt("{value}") - def visit_FunCall(self, node: itir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - assert len(node.args) == 1 + def _visit_deref(self, node: itir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) - raise RuntimeError(f"Unexpected 'FunCall' expression ({node}).") + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, itir.SymRef): + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + return self._visit_numeric_builtin(node) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: @@ -86,17 +103,15 @@ def get_domain( dim = Dimension(axis.value) bounds = [] for arg in named_range.args[1:3]: - str_val = str(arg) - sym_val = dace.symbolic.SymExpr(str_val) + sym_str = get_symbolic_expr(arg) + sym_val = dace.symbolic.SymExpr(sym_str) bounds.append(sym_val) domain[dim] = (bounds[0], bounds[1]) return domain -def get_symbolic_expr( - node: itir.Expr, -) -> str: +def get_symbolic_expr(node: itir.Expr) -> str: return SymbolicTranslator().visit(node) From 1d0b50bdc57e845e973f71b1f7c996376446d3f1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 17:39:51 +0200 Subject: [PATCH 040/235] Extract ITIR visitor to separate class --- .../gtir_builtin_field_operator.py | 125 ++------------- .../runners/dace_fieldview/gtir_to_tasklet.py | 151 ++++++++++++++++++ 2 files changed, 161 insertions(+), 115 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 03519b19d2..2191277cd2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -13,50 +13,25 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass from typing import Callable, TypeAlias import dace -from gt4py import eve from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_types import MATH_BUILTINS_MAPPING -from gt4py.next.program_processors.runners.dace_fieldview.utility import ( - as_dace_type, - get_domain, - unique_name, +from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( + GTIRToTasklet, + ValueExpr, ) +from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name from gt4py.next.type_system import type_specifications as ts -@dataclass(frozen=True) -class LiteralExpr: - """Any symbolic expression that can be evaluated at compile time.""" - - value: dace.symbolic.SymbolicType - - -@dataclass(frozen=True) -class SymbolExpr: - """The data access to a scalar or field through a symbolic reference.""" - - data: str - - -@dataclass(frozen=True) -class ValueExpr: - """The result of a computation provided by a tasklet node.""" - - node: dace.nodes.Tasklet - connector: str - - -class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator, eve.NodeVisitor): +class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] @@ -65,7 +40,6 @@ class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator, eve.NodeVisitor): stencil_args: list[Callable] field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] field_type: ts.FieldType - input_connections: list[TaskletConnector] offset_provider: dict[str, Connectivity | Dimension] def __init__( @@ -77,7 +51,6 @@ def __init__( offset_provider: dict[str, Connectivity | Dimension], ): super().__init__(sdfg, state) - self.input_connections = [] self.offset_provider = offset_provider assert cpm.is_call_to(node.fun, "as_fieldop") @@ -102,11 +75,10 @@ def __init__( self.stencil_args = stencil_args def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - assert len(self.input_connections) == 0 - # generate a tasklet node implementing the stencil function and represent # the field operator as a mapped tasklet, which will range over the field domain - output_expr = self.visit(self.stencil_expr.expr) + taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) + input_connections, output_expr = taskgen.visit(self.stencil_expr) assert isinstance(output_expr, ValueExpr) # allocate local temporary storage for the result field @@ -150,13 +122,13 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} me, mx = self.head_state.add_map(unique_name("map"), map_ranges) - for (input_node, input_connector), input_param in self.input_connections: + for input_expr, input_param in input_connections: assert input_param in data_nodes self.head_state.add_memlet_path( data_nodes[input_param], me, - input_node, - dst_conn=input_connector, + input_expr.node, + dst_conn=input_expr.connector, memlet=input_memlets[input_param], ) self.head_state.add_memlet_path( @@ -164,80 +136,3 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ) return [(field_node, self.field_type)] - - def _visit_shift(self, node: itir.FunCall) -> str: - raise NotImplementedError - - def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: - inp_tasklets = {} - inp_symbols = set() - inp_connectors = set() - node_internals = [] - for i, arg in enumerate(node.args): - arg_expr = self.visit(arg) - if isinstance(arg_expr, LiteralExpr): - # use the value without adding any connector - node_internals.append(arg_expr.value) - else: - if isinstance(arg_expr, ValueExpr): - # the value is the result of a tasklet node - connector = f"__inp_{i}" - inp_tasklets[connector] = arg_expr - else: - # the value is the result of a tasklet node - assert isinstance(arg_expr, SymbolExpr) - connector = f"__inp_{arg_expr.data}" - inp_symbols.add((connector, arg_expr.data)) - inp_connectors.add(connector) - node_internals.append(connector) - - if cpm.is_call_to(node, "deref"): - assert len(inp_tasklets) == 0 - assert len(inp_symbols) == 1 - _, data = inp_symbols.pop() - return SymbolExpr(data) - - elif cpm.is_call_to(node.fun, "shift"): - code = self._visit_shift(node.fun) - - elif isinstance(node.fun, itir.SymRef): - # create a tasklet node implementing the builtin function - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[builtin_name] - code = fmt.format(*node_internals) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - - out_connector = "__out" - tasklet_node = self.head_state.add_tasklet( - unique_name("tasklet"), - inp_connectors, - {out_connector}, - "{} = {}".format(out_connector, code), - ) - - else: - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - - for input_conn, inp_expr in inp_tasklets.items(): - self.head_state.add_edge( - inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() - ) - self.input_connections.extend( - ((tasklet_node, connector), data) for connector, data in inp_symbols - ) - return ValueExpr(tasklet_node, out_connector) - - def visit_Literal(self, node: itir.Literal) -> LiteralExpr: - cast_sym = str(as_dace_type(node.type)) - cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] - typed_value = cast_fmt.format(node.value) - return LiteralExpr(typed_value) - - def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: - """ - Symbol references are mapped to tasklet connectors that access some kind of data. - """ - sym_name = str(node.id) - return SymbolExpr(sym_name) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py new file mode 100644 index 0000000000..45dd8c278e --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -0,0 +1,151 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from dataclasses import dataclass + +import dace + +from gt4py import eve +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview.gtir_types import MATH_BUILTINS_MAPPING +from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name + + +@dataclass(frozen=True) +class LiteralExpr: + """Any symbolic expression that can be evaluated at compile time.""" + + value: dace.symbolic.SymbolicType + + +@dataclass(frozen=True) +class SymbolExpr: + """The data access to a scalar or field through a symbolic reference.""" + + data: str + + +@dataclass(frozen=True) +class ValueExpr: + """The result of a computation provided by a tasklet node.""" + + node: dace.nodes.Tasklet + connector: str + + +class GTIRToTasklet(eve.NodeVisitor): + """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + + sdfg: dace.SDFG + state: dace.SDFGState + input_connections: list[tuple[ValueExpr, str]] + offset_provider: dict[str, Connectivity | Dimension] + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + offset_provider: dict[str, Connectivity | Dimension], + ): + self.sdfg = sdfg + self.state = state + self.input_connections = [] + self.offset_provider = offset_provider + + def _visit_shift(self, node: itir.FunCall) -> str: + assert len(node.args) == 2 + raise NotImplementedError + + def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: + inp_tasklets = {} + inp_symbols = set() + inp_connectors = set() + node_internals = [] + for i, arg in enumerate(node.args): + arg_expr = self.visit(arg) + if isinstance(arg_expr, LiteralExpr): + # use the value without adding any connector + node_internals.append(arg_expr.value) + else: + if isinstance(arg_expr, ValueExpr): + # the value is the result of a tasklet node + connector = f"__inp_{i}" + inp_tasklets[connector] = arg_expr + else: + # the value is the result of a tasklet node + assert isinstance(arg_expr, SymbolExpr) + connector = f"__inp_{arg_expr.data}" + inp_symbols.add((connector, arg_expr.data)) + inp_connectors.add(connector) + node_internals.append(connector) + + if cpm.is_call_to(node, "deref"): + assert len(inp_tasklets) == 0 + assert len(inp_symbols) == 1 + _, data = inp_symbols.pop() + return SymbolExpr(data) + + elif cpm.is_call_to(node.fun, "shift"): + code = self._visit_shift(node.fun) + + elif isinstance(node.fun, itir.SymRef): + # create a tasklet node implementing the builtin function + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin_name] + code = fmt.format(*node_internals) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + + out_connector = "__out" + tasklet_node = self.state.add_tasklet( + unique_name("tasklet"), + inp_connectors, + {out_connector}, + "{} = {}".format(out_connector, code), + ) + + else: + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + + for input_conn, inp_expr in inp_tasklets.items(): + self.state.add_edge( + inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() + ) + self.input_connections.extend( + (ValueExpr(tasklet_node, connector), data) for connector, data in inp_symbols + ) + return ValueExpr(tasklet_node, out_connector) + + def visit_Lambda(self, node: itir.Lambda) -> tuple[list[tuple[ValueExpr, str]], ValueExpr]: + assert len(self.input_connections) == 0 + output_expr = self.visit(node.expr) + assert isinstance(output_expr, ValueExpr) + return self.input_connections, output_expr + + def visit_Literal(self, node: itir.Literal) -> LiteralExpr: + cast_sym = str(as_dace_type(node.type)) + cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] + typed_value = cast_fmt.format(node.value) + return LiteralExpr(typed_value) + + def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: + """ + Symbol references are mapped to tasklet connectors that access some kind of data. + """ + sym_name = str(node.id) + return SymbolExpr(sym_name) From 97a1d229e327a0b652e87129e4fc077c676ba8de Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 18:36:23 +0200 Subject: [PATCH 041/235] Code refactoring --- .../{gtir_types.py => gtir_python_codegen.py} | 34 ++++++++++++++++++ .../runners/dace_fieldview/gtir_to_tasklet.py | 4 ++- .../runners/dace_fieldview/utility.py | 36 +++---------------- 3 files changed, 41 insertions(+), 33 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_types.py => gtir_python_codegen.py} (60%) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py similarity index 60% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 9e190e1395..a839bbbd3f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_types.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -14,7 +14,12 @@ import numpy as np +from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm MATH_BUILTINS_MAPPING = { "abs": "abs({})", @@ -69,3 +74,32 @@ "mod": "({} % {})", "not_": "(not {})", # ~ is not bitwise in numpy } + + + +class GTIRPythonCodegen(codegen.TemplatedGenerator): + SymRef = as_fmt("{id}") + Literal = as_fmt("{value}") + + def _visit_deref(self, node: itir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) + + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, itir.SymRef): + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + return self._visit_numeric_builtin(node) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 45dd8c278e..f538198eb9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -21,7 +21,9 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_types import MATH_BUILTINS_MAPPING +from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( + MATH_BUILTINS_MAPPING, +) from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index b6f9cf07de..2e8a205ca8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,43 +16,15 @@ import dace -from gt4py.eve import codegen -from gt4py.eve.codegen import FormatTemplate as as_fmt from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_types import MATH_BUILTINS_MAPPING +from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( + GTIRPythonCodegen, +) from gt4py.next.type_system import type_specifications as ts -class SymbolicTranslator(codegen.TemplatedGenerator): - SymRef = as_fmt("{id}") - Literal = as_fmt("{value}") - - def _visit_deref(self, node: itir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: - assert isinstance(node.fun, itir.SymRef) - fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - return fmt.format(*args) - - def visit_FunCall(self, node: itir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - elif isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - - def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" match type_.kind: @@ -112,7 +84,7 @@ def get_domain( def get_symbolic_expr(node: itir.Expr) -> str: - return SymbolicTranslator().visit(node) + return GTIRPythonCodegen().visit(node) def unique_name(prefix: str) -> str: From a30cc7d7383aa9f68d25a9089710c08b4c7d15cb Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 7 May 2024 18:38:43 +0200 Subject: [PATCH 042/235] Fix formatting --- .../runners/dace_fieldview/gtir_python_codegen.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index a839bbbd3f..478b5d3af8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -14,13 +14,13 @@ import numpy as np + from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt - -from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm + MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -76,7 +76,6 @@ } - class GTIRPythonCodegen(codegen.TemplatedGenerator): SymRef = as_fmt("{id}") Literal = as_fmt("{value}") From f595b01d6be098b637ec2918de8f9a9dee2ed4e1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 10 May 2024 13:21:43 +0200 Subject: [PATCH 043/235] Add IteratorExpr type --- .../gtir_builtin_field_operator.py | 103 ++++--- .../runners/dace_fieldview/gtir_to_tasklet.py | 280 +++++++++++++----- .../runners_tests/test_dace_fieldview.py | 54 +++- 3 files changed, 319 insertions(+), 118 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 2191277cd2..87498c4d86 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -25,12 +25,23 @@ ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( GTIRToTasklet, + IteratorExpr, + SymbolExpr, + TaskletExpr, ValueExpr, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name +from gt4py.next.program_processors.runners.dace_fieldview.utility import ( + as_dace_type, + get_domain, + unique_name, +) from gt4py.next.type_system import type_specifications as ts +# Define type of variables used for field indexing +_INDEX_DTYPE = dace.int64 + + class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" @@ -75,11 +86,42 @@ def __init__( self.stencil_args = stencil_args def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - # generate a tasklet node implementing the stencil function and represent - # the field operator as a mapped tasklet, which will range over the field domain + # first visit the list of arguments and build a symbol map + data_types: dict[str, ts.FieldType | ts.ScalarType] = {} + stencil_args: list[IteratorExpr | ValueExpr] = [] + for arg in self.stencil_args: + arg_nodes = arg() + assert len(arg_nodes) == 1 + arg_node, arg_type = arg_nodes[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(arg_node, dace.nodes.AccessNode) + data_types[arg_node.data] = arg_type + + if isinstance(arg_type, ts.ScalarType): + dtype = as_dace_type(arg_type) + scalar_arg = ValueExpr(arg_node, dtype) + stencil_args.append(scalar_arg) + else: + assert isinstance(arg_type, ts.FieldType) + dtype = as_dace_type(arg_type.dtype) + indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] = { + dim.value: SymbolExpr(f"i_{dim.value}", _INDEX_DTYPE) + for dim in arg_type.dims + if dim in self.field_domain + } + iterator_arg = IteratorExpr( + arg_node, + [dim.value for dim in arg_type.dims], + [0] * len(arg_type.dims), + indices, + dtype, + ) + stencil_args.append(iterator_arg) + + # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr) - assert isinstance(output_expr, ValueExpr) + input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + assert isinstance(output_expr, TaskletExpr) # allocate local temporary storage for the result field field_shape = [ @@ -89,31 +131,6 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ] field_node = self.add_local_storage(self.field_type, field_shape) - data_nodes: dict[str, dace.nodes.AccessNode] = {} - input_memlets: dict[str, dace.Memlet] = {} - for arg, param in zip(self.stencil_args, self.stencil_expr.params, strict=True): - arg_nodes = arg() - assert len(arg_nodes) == 1 - arg_node, arg_type = arg_nodes[0] - data = str(param.id) - # require (for now) all input nodes to be data access nodes - assert isinstance(arg_node, dace.nodes.AccessNode) - data_nodes[data] = arg_node - if isinstance(arg_type, ts.FieldType): - # support either single element access (general case) or full array shape - is_scalar = all(dim in self.field_domain for dim in arg_type.dims) - if is_scalar: - subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) - input_memlets[data] = dace.Memlet(data=arg_node.data, subset=subset, volume=1) - else: - memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) - # set volume to 1 because the stencil function always performs single element access - # TODO: check validity of this assumption - memlet.volume = 1 - input_memlets[data] = memlet - else: - input_memlets[data] = dace.Memlet(data=arg_node.data, subset="0") - # assume tasklet with single output output_index = ",".join(f"i_{dim.value}" for dim in self.field_type.dims) output_memlet = dace.Memlet(data=field_node.data, subset=output_index) @@ -122,14 +139,28 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} me, mx = self.head_state.add_map(unique_name("map"), map_ranges) - for input_expr, input_param in input_connections: - assert input_param in data_nodes + for arg_node, (lambda_node, lambda_connector, offset) in input_connections: + assert arg_node.data in data_types + arg_type = data_types[arg_node.data] + if len(lambda_node.in_connectors) != 1: + # indirection tasklet with explicit indexes + memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) + elif isinstance(arg_type, ts.ScalarType): + memlet = dace.Memlet(data=arg_node.data, subset="0", volume=1) + else: + # read one field element through memlet subset + assert all(dim in self.field_domain for dim in arg_type.dims) + subset = ",".join( + f"i_{dim.value} + ({offset})" + for dim, off in zip(arg_type.dims, offset, strict=True) + ) + memlet = dace.Memlet(data=arg_node.data, subset=subset, volume=1) self.head_state.add_memlet_path( - data_nodes[input_param], + arg_node, me, - input_expr.node, - dst_conn=input_expr.connector, - memlet=input_memlets[input_param], + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, ) self.head_state.add_memlet_path( output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index f538198eb9..e209177e0b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -13,7 +13,9 @@ # SPDX-License-Identifier: GPL-3.0-or-later +import itertools from dataclasses import dataclass +from typing import Optional import dace @@ -25,28 +27,43 @@ MATH_BUILTINS_MAPPING, ) from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name +from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) -class LiteralExpr: - """Any symbolic expression that can be evaluated at compile time.""" +class SymbolExpr: + """Any symbolic expression that is constant in the context of current SDFG.""" value: dace.symbolic.SymbolicType + dtype: dace.typeclass @dataclass(frozen=True) -class SymbolExpr: - """The data access to a scalar or field through a symbolic reference.""" +class TaskletExpr: + """Result of the computation provided by a tasklet node.""" - data: str + node: dace.nodes.Tasklet + connector: str + dtype: dace.typeclass @dataclass(frozen=True) class ValueExpr: - """The result of a computation provided by a tasklet node.""" + """Data provided by a scalar access node.""" - node: dace.nodes.Tasklet - connector: str + value: dace.nodes.AccessNode + dtype: dace.typeclass + + +@dataclass(frozen=True) +class IteratorExpr: + """Iterator to access the field provided by an array access node.""" + + field: dace.nodes.AccessNode + dimensions: list[str] + offset: list[int] + indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] + dtype: dace.typeclass class GTIRToTasklet(eve.NodeVisitor): @@ -54,8 +71,9 @@ class GTIRToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - input_connections: list[tuple[ValueExpr, str]] + input_connections: list[tuple[dace.nodes.AccessNode, tuple[dace.nodes.Tasklet, str, list[int]]]] offset_provider: dict[str, Connectivity | Dimension] + symbol_map: dict[str, IteratorExpr | SymbolExpr | ValueExpr] def __init__( self, @@ -67,87 +85,199 @@ def __init__( self.state = state self.input_connections = [] self.offset_provider = offset_provider + self.symbol_map = {} + + def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: + assert len(node.args) == 1 + it = self.visit(node.args[0]) + + if isinstance(it, SymbolExpr): + cast_sym = str(it.dtype) + cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] + deref_node = self.state.add_tasklet( + "deref_symbol", {}, {"val"}, code=f"val = {cast_fmt.format(it.value)}" + ) + deref_expr = TaskletExpr(deref_node, "val", it.dtype) + + elif isinstance(it, ValueExpr): + deref_node = self.state.add_tasklet( + "deref_scalar", {"scalar"}, {"val"}, code="val = scalar" + ) + deref_expr = TaskletExpr(deref_node, "val", it.dtype) + + # add new termination point for the data access + self.input_connections.append((it.value, (deref_node, "scalar", []))) + + else: + if all(isinstance(index, SymbolExpr) for index in it.indices.values()): + # use direct field access through memlet subset + deref_node = self.state.add_tasklet( + "deref_field", {"field"}, {"val"}, code="val = field" + ) + else: + index_connectors = [ + f"i_{dim}" + for dim, index in it.indices.items() + if isinstance(index, TaskletExpr | ValueError) + ] + sorted_indices = [it.indices[dim] for dim in it.dimensions] + index_internals = ",".join( + index.value if isinstance(index, SymbolExpr) else f"i_{dim}" + for dim, index in zip(it.dimensions, sorted_indices) + ) + deref_node = self.state.add_tasklet( + "deref_field_indirecton", + set("field", *index_connectors), + {"val"}, + code=f"val = field[{index_internals}]", + ) + for dim, index_expr in it.indices.items(): + deref_connector = f"i_{dim}" + if isinstance(index_expr, TaskletExpr): + self.state.add_edge( + index_expr.node, + index_expr.connector, + deref_node, + deref_connector, + dace.Memlet(data=index_expr.node.data, subset="0"), + ) + elif isinstance(index_expr, ValueExpr): + self.state.add_edge( + index_expr.value, + None, + deref_node, + deref_connector, + dace.Memlet(data=index_expr.value.data, subset="0"), + ) + else: + assert isinstance(index_expr, SymbolExpr) + + deref_expr = TaskletExpr(deref_node, "val", it.dtype) + + # add new termination point for this field parameter + self.input_connections.append((it.field, (deref_node, "field", it.offset))) + + return deref_expr + + def _split_shift_args( + self, args: list[itir.Expr] + ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: + pairs = [args[i : i + 2] for i in range(0, len(args), 2)] + assert len(pairs) >= 1 + assert all(len(pair) == 2 for pair in pairs) + return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None + + def _make_shift_for_rest(self, rest: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: + return itir.FunCall( + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), + args=[iterator], + ) + + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: + shift_node = node.fun + assert isinstance(shift_node, itir.FunCall) + + # the iterator to be shifted is the argument to the function node + head, tail = self._split_shift_args(shift_node.args) + if tail: + it = self.visit(self._make_shift_for_rest(tail, node.args[0])) + else: + it = self.visit(node.args[0]) + assert isinstance(it, IteratorExpr) + + # first argument of the shift node is the offset provider + assert isinstance(head[0], itir.OffsetLiteral) + offset = head[0].value + assert isinstance(offset, str) + offset_provider = self.offset_provider[offset] + # second argument should be the offset value + if isinstance(head[1], itir.OffsetLiteral): + offset_value = head[1].value + assert isinstance(offset_value, int) + else: + raise NotImplementedError("Dynamic offset not supported.") + + if isinstance(offset_provider, Dimension): + # cartesian offset along one dimension + dim_index = it.dimensions.index(offset_provider.value) + new_offset = [ + prev_offset + offset_value if i == dim_index else prev_offset + for i, prev_offset in enumerate(it.offset) + ] + shifted_it = IteratorExpr(it.field, it.dimensions, new_offset, it.indices, it.dtype) + else: + # shift in unstructured domain by means of a neighbor table + raise NotImplementedError + + return shifted_it + + def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + + elif cpm.is_call_to(node.fun, "shift"): + return self._visit_shift(node) - def _visit_shift(self, node: itir.FunCall) -> str: - assert len(node.args) == 2 - raise NotImplementedError + else: + assert isinstance(node.fun, itir.SymRef) - def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: - inp_tasklets = {} - inp_symbols = set() - inp_connectors = set() node_internals = [] + node_connections = {} for i, arg in enumerate(node.args): arg_expr = self.visit(arg) - if isinstance(arg_expr, LiteralExpr): - # use the value without adding any connector + if isinstance(arg_expr, SymbolExpr): + # use the argument value without adding any connector node_internals.append(arg_expr.value) else: - if isinstance(arg_expr, ValueExpr): - # the value is the result of a tasklet node - connector = f"__inp_{i}" - inp_tasklets[connector] = arg_expr - else: - # the value is the result of a tasklet node - assert isinstance(arg_expr, SymbolExpr) - connector = f"__inp_{arg_expr.data}" - inp_symbols.add((connector, arg_expr.data)) - inp_connectors.add(connector) + assert isinstance(arg_expr, TaskletExpr) + # the argument value is the result of a tasklet node + connector = f"__inp_{i}" + node_connections[connector] = arg_expr node_internals.append(connector) - if cpm.is_call_to(node, "deref"): - assert len(inp_tasklets) == 0 - assert len(inp_symbols) == 1 - _, data = inp_symbols.pop() - return SymbolExpr(data) - - elif cpm.is_call_to(node.fun, "shift"): - code = self._visit_shift(node.fun) - - elif isinstance(node.fun, itir.SymRef): - # create a tasklet node implementing the builtin function - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[builtin_name] - code = fmt.format(*node_internals) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - - out_connector = "__out" - tasklet_node = self.state.add_tasklet( - unique_name("tasklet"), - inp_connectors, - {out_connector}, - "{} = {}".format(out_connector, code), - ) + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + # create a tasklet node implementing the builtin function + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin_name] + code = fmt.format(*node_internals) else: - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + raise NotImplementedError(f"'{builtin_name}' not implemented.") + + out_connector = "result" + tasklet_node = self.state.add_tasklet( + unique_name("tasklet"), + node_connections.keys(), + {out_connector}, + "{} = {}".format(out_connector, code), + ) - for input_conn, inp_expr in inp_tasklets.items(): + for connector, arg_expr in node_connections.items(): self.state.add_edge( - inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() + arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() ) - self.input_connections.extend( - (ValueExpr(tasklet_node, connector), data) for connector, data in inp_symbols - ) - return ValueExpr(tasklet_node, out_connector) - def visit_Lambda(self, node: itir.Lambda) -> tuple[list[tuple[ValueExpr, str]], ValueExpr]: - assert len(self.input_connections) == 0 + dtype = as_dace_type(node_type) + return TaskletExpr(tasklet_node, "result", dtype) + + def visit_Lambda( + self, node: itir.Lambda, args: list[IteratorExpr | SymbolExpr | ValueExpr] + ) -> tuple[ + list[tuple[dace.nodes.AccessNode, tuple[dace.nodes.Tasklet, str, list[int]]]], TaskletExpr + ]: + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg output_expr = self.visit(node.expr) - assert isinstance(output_expr, ValueExpr) + assert isinstance(output_expr, TaskletExpr) return self.input_connections, output_expr - def visit_Literal(self, node: itir.Literal) -> LiteralExpr: - cast_sym = str(as_dace_type(node.type)) - cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] - typed_value = cast_fmt.format(node.value) - return LiteralExpr(typed_value) - - def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: - """ - Symbol references are mapped to tasklet connectors that access some kind of data. - """ - sym_name = str(node.id) - return SymbolExpr(sym_name) + def visit_Literal(self, node: itir.Literal) -> SymbolExpr: + dtype = as_dace_type(node.type) + return SymbolExpr(node.value, dtype) + + def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | SymbolExpr | ValueExpr: + param = str(node.id) + assert param in self.symbol_map + return self.symbol_map[param] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index a044bdc4db..c898ced98c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -30,12 +30,13 @@ import pytest +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import IDim + dace = pytest.importorskip("dace") N = 10 -DIM = Dimension("D") -FTYPE = ts.FieldType(dims=[DIM], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +FTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) FSYMBOLS = dict( __w_size_0=N, __w_stride_0=1, @@ -52,7 +53,7 @@ def test_gtir_sum2(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="sum_2fields", @@ -90,7 +91,7 @@ def test_gtir_sum2(): def test_gtir_sum2_sym(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="sum_2fields", @@ -127,7 +128,7 @@ def test_gtir_sum2_sym(): def test_gtir_sum3(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee_fieldview = itir.Program( id="sum_3fields", @@ -207,7 +208,7 @@ def test_gtir_sum3(): def test_gtir_select(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="select_2sums", @@ -283,7 +284,7 @@ def test_gtir_select(): def test_gtir_select_nested(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="select_nested", @@ -353,3 +354,42 @@ def test_gtir_select_nested(): for s2 in [False, True]: sdfg(cond_1=np.bool_(s1), cond_2=np.bool_(s2), x=a, z=b, **FSYMBOLS) assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0)) + + +def test_gtir_cartesian_shift(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + ) + testee = itir.Program( + id="caresian_shift", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="size")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref(im.shift("IDim", 1)("a")), 1)), + domain, + ) + )("x"), + domain=domain, + target=itir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N + 1) + b = np.empty(N) + + sdfg_genenerator = FieldviewGtirToSDFG( + [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={"IDim": IDim} + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + FSYMBOLS_tmp = FSYMBOLS.copy() + FSYMBOLS_tmp["__x_size_0"] = N + 1 + sdfg(x=a, y=b, **FSYMBOLS_tmp) + assert np.allclose(a[1:] + 1, b) From a6bcb6cd9e23138daaffc1923c5f477dcb672af1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 10 May 2024 18:20:25 +0200 Subject: [PATCH 044/235] Indirection shift implemented as tasklet node --- .../gtir_builtin_field_operator.py | 24 +++-- .../runners/dace_fieldview/gtir_to_sdfg.py | 22 +++- .../runners/dace_fieldview/gtir_to_tasklet.py | 46 ++++++-- .../runners/dace_fieldview/utility.py | 4 + .../runners_tests/test_dace_fieldview.py | 102 +++++++++++++++--- 5 files changed, 164 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 87498c4d86..f9847ccaee 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -35,7 +35,7 @@ get_domain, unique_name, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_specifications as ts, type_translation as tt # Define type of variables used for field indexing @@ -139,10 +139,20 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} me, mx = self.head_state.add_map(unique_name("map"), map_ranges) - for arg_node, (lambda_node, lambda_connector, offset) in input_connections: - assert arg_node.data in data_types - arg_type = data_types[arg_node.data] - if len(lambda_node.in_connectors) != 1: + for arg_node, (lambda_node, lambda_connector, index_offset) in input_connections: + offset = None + if arg_node.data.startswith("__connectivity"): + self.sdfg.arrays[arg_node.data].transient = False + offset = arg_node.data.removeprefix("__connectivity") + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, Connectivity) + type_ = tt.from_type_hint(offset_provider.index_type) + assert isinstance(type_, ts.ScalarType) + arg_type = ts.FieldType([offset_provider.origin_axis], type_) + else: + assert arg_node.data in data_types + arg_type = data_types[arg_node.data] + if len(lambda_node.in_connectors) != 1 or offset is not None: # indirection tasklet with explicit indexes memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) elif isinstance(arg_type, ts.ScalarType): @@ -151,8 +161,8 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: # read one field element through memlet subset assert all(dim in self.field_domain for dim in arg_type.dims) subset = ",".join( - f"i_{dim.value} + ({offset})" - for dim, off in zip(arg_type.dims, offset, strict=True) + f"i_{dim.value} + ({off})" + for dim, off in zip(arg_type.dims, index_offset, strict=True) ) memlet = dace.Memlet(data=arg_node.data, subset=subset, volume=1) self.head_state.add_memlet_path( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 79625a15ee..96ffd2c70e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -28,10 +28,11 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, + connectivity_identifier, filter_connectivities, get_domain, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_specifications as ts, type_translation as tt class GTIRToSDFG(eve.NodeVisitor): @@ -86,7 +87,9 @@ def _make_array_shape_and_strides( strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i in range(len(dims))] return shape, strides - def _add_storage(self, sdfg: dace.SDFG, name: str, data_type: ts.DataType) -> None: + def _add_storage( + self, sdfg: dace.SDFG, name: str, data_type: ts.DataType, transient: bool = False + ) -> None: """ Add external storage (aka non-transient) for data containers passed as arguments to the SDFG. @@ -97,7 +100,7 @@ def _add_storage(self, sdfg: dace.SDFG, name: str, data_type: ts.DataType) -> No # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. sym_shape, sym_strides = self._make_array_shape_and_strides(name, data_type.dims) - sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False) + sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=transient) elif isinstance(data_type, ts.ScalarType): dtype = as_dace_type(data_type) # scalar arguments passed to the program are represented as symbols in DaCe SDFG @@ -177,6 +180,19 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: for param, type_ in zip(node.params, self.param_types, strict=True): self._add_storage(sdfg, str(param.id), type_) + # add SDFG storage for connectivity tables + for offset, offset_provider in filter_connectivities(self.offset_provider).items(): + scalar_kind = tt.get_scalar_kind(offset_provider.index_type) + local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + type_ = ts.FieldType( + [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + ) + # We store all connectivity tables as transient arrays here; later, while building + # the field operator expressions, we change to non transient the tables + # that are actually used. This way, we avoid adding SDFG arguments for + # the connectivity tabkes that are not used. + self._add_storage(sdfg, connectivity_identifier(offset), type_, transient=True) + # visit one statement at a time and expand the SDFG from the current head state for i, stmt in enumerate(node.body): head_state = sdfg.add_state_after(head_state, f"stmt_{i}") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index e209177e0b..f60c3b5552 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -26,8 +26,12 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( MATH_BUILTINS_MAPPING, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.program_processors.runners.dace_fieldview.utility import ( + as_dace_type, + connectivity_identifier, + unique_name, +) +from gt4py.next.type_system import type_specifications as ts, type_translation as tt @dataclass(frozen=True) @@ -116,30 +120,30 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: ) else: index_connectors = [ - f"i_{dim}" + f"__i_{dim}" for dim, index in it.indices.items() if isinstance(index, TaskletExpr | ValueError) ] sorted_indices = [it.indices[dim] for dim in it.dimensions] index_internals = ",".join( - index.value if isinstance(index, SymbolExpr) else f"i_{dim}" + index.value if isinstance(index, SymbolExpr) else f"__i_{dim}" for dim, index in zip(it.dimensions, sorted_indices) ) deref_node = self.state.add_tasklet( "deref_field_indirecton", - set("field", *index_connectors), + {"field"} | set(index_connectors), {"val"}, code=f"val = field[{index_internals}]", ) for dim, index_expr in it.indices.items(): - deref_connector = f"i_{dim}" + deref_connector = f"__i_{dim}" if isinstance(index_expr, TaskletExpr): self.state.add_edge( index_expr.node, index_expr.connector, deref_node, deref_connector, - dace.Memlet(data=index_expr.node.data, subset="0"), + dace.Memlet(), ) elif isinstance(index_expr, ValueExpr): self.state.add_edge( @@ -207,7 +211,33 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_it = IteratorExpr(it.field, it.dimensions, new_offset, it.indices, it.dtype) else: # shift in unstructured domain by means of a neighbor table - raise NotImplementedError + origin_dim = offset_provider.origin_axis.value + assert origin_dim not in it.indices + neighbor_dim = offset_provider.neighbor_axis.value + assert neighbor_dim not in it.indices + dim_index = it.dimensions.index(neighbor_dim) + + shift_node = self.state.add_tasklet( + "shift", + {"table"}, + {"target_index"}, + f"target_index = table[i_{origin_dim}, {offset_value}]", + ) + + offset_table = connectivity_identifier(offset) + offset_table_node = self.state.add_access(offset_table) + self.input_connections.append((offset_table_node, (shift_node, "table", [0, 0]))) + + scalar_kind = tt.get_scalar_kind(offset_provider.index_type) + dtype = as_dace_type(ts.ScalarType(scalar_kind)) + + shifted_it = IteratorExpr( + it.field, + [origin_dim if i == dim_index else dim for i, dim in enumerate(it.dimensions)], + it.offset, + {origin_dim: TaskletExpr(shift_node, "target_index", dtype)}, + it.dtype, + ) return shifted_it diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 2e8a205ca8..8bef7e9be5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -42,6 +42,10 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: raise ValueError(f"Scalar type '{type_}' not supported.") +def connectivity_identifier(name: str) -> str: + return f"__connectivity_{name}" + + def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: """ Filter offset providers of type `Connectivity`. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index c898ced98c..9c0b9c64b0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -18,7 +18,7 @@ """ from typing import Union -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import Connectivity, Dimension, NeighborTable from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( @@ -30,14 +30,31 @@ import pytest -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import IDim +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + V2E, + Edge, + IDim, + MeshDescriptor, + Vertex, + simple_mesh, +) +from next_tests.integration_tests.cases import EField, IFloatField, VField dace = pytest.importorskip("dace") N = 10 -FTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +SIMPLE_MESH: MeshDescriptor = simple_mesh() FSYMBOLS = dict( + __edges_size_0=SIMPLE_MESH.num_edges, + __edges_stride_0=1, + __vertices_size_0=SIMPLE_MESH.num_vertices, + __vertices_stride_0=1, + nedges=SIMPLE_MESH.num_edges, + nvertices=SIMPLE_MESH.num_vertices, __w_size_0=N, __w_stride_0=1, __x_size_0=N, @@ -48,7 +65,6 @@ __z_stride_0=1, size=N, ) -OFFSET_PROVIDERS: dict[str, Connectivity | Dimension] = {} def test_gtir_sum2(): @@ -79,7 +95,7 @@ def test_gtir_sum2(): c = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} ) sdfg = sdfg_genenerator.visit(testee) @@ -116,7 +132,8 @@ def test_gtir_sum2_sym(): b = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) @@ -195,7 +212,8 @@ def test_gtir_sum3(): d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider={}, ) for testee in [testee_fieldview, testee_inlined]: @@ -263,15 +281,15 @@ def test_gtir_select(): sdfg_genenerator = FieldviewGtirToSDFG( [ - FTYPE, - FTYPE, - FTYPE, - FTYPE, + IFTYPE, + IFTYPE, + IFTYPE, + IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.FLOAT64), ts.ScalarType(ts.ScalarKind.INT32), ], - OFFSET_PROVIDERS, + offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) @@ -338,13 +356,13 @@ def test_gtir_select_nested(): sdfg_genenerator = FieldviewGtirToSDFG( [ - FTYPE, - FTYPE, + IFTYPE, + IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.INT32), ], - OFFSET_PROVIDERS, + offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) @@ -383,7 +401,7 @@ def test_gtir_cartesian_shift(): b = np.empty(N) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={"IDim": IDim} + [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={"IDim": IDim} ) sdfg = sdfg_genenerator.visit(testee) @@ -393,3 +411,55 @@ def test_gtir_cartesian_shift(): FSYMBOLS_tmp["__x_size_0"] = N + 1 sdfg(x=a, y=b, **FSYMBOLS_tmp) assert np.allclose(a[1:] + 1, b) + + +def test_gtir_connectivity_shift(): + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices") + ) + testee = itir.Program( + id="connectivity_shift", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nedges"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", 1)("it"))), + vertex_domain, + ) + )("edges"), + domain=vertex_domain, + target=itir.SymRef(id="vertices"), + ) + ], + ) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v = np.empty(SIMPLE_MESH.num_vertices) + + sdfg_genenerator = FieldviewGtirToSDFG( + [EFTYPE, VFTYPE, ts.ScalarType(ts.ScalarKind.INT32), ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider=SIMPLE_MESH.offset_provider, + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, NeighborTable) + + sdfg( + edges=e, + vertices=v, + __connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + ____connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + ____connectivity_V2E_stride_1=1, + ) + assert np.allclose(v, e[connectivity_V2E.table[:, 1]]) From 738da278f3cb8875de95e2dc434312cbd5c7c7c6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 13 May 2024 08:51:20 +0200 Subject: [PATCH 045/235] Add ConnectivityExpr type --- .../gtir_builtin_field_operator.py | 45 ++++++++-------- .../runners/dace_fieldview/gtir_to_tasklet.py | 52 +++++++++---------- .../runners_tests/test_dace_fieldview.py | 2 + 3 files changed, 50 insertions(+), 49 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index f9847ccaee..31081aa710 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -17,13 +17,14 @@ import dace -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( + ConnectivityExpr, GTIRToTasklet, IteratorExpr, SymbolExpr, @@ -32,10 +33,12 @@ ) from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, + connectivity_identifier, + filter_connectivities, get_domain, unique_name, ) -from gt4py.next.type_system import type_specifications as ts, type_translation as tt +from gt4py.next.type_system import type_specifications as ts # Define type of variables used for field indexing @@ -104,7 +107,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: else: assert isinstance(arg_type, ts.FieldType) dtype = as_dace_type(arg_type.dtype) - indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] = { + indices: dict[str, SymbolExpr | TaskletExpr | ConnectivityExpr] = { dim.value: SymbolExpr(f"i_{dim.value}", _INDEX_DTYPE) for dim in arg_type.dims if dim in self.field_domain @@ -118,6 +121,17 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ) stencil_args.append(iterator_arg) + for offset, connectivity in filter_connectivities(self.offset_provider).items(): + table_name = connectivity_identifier(offset) + arg_type = ts.FieldType( + [ + connectivity.origin_axis, + Dimension(f"neighbor_{connectivity.neighbor_axis}", DimensionKind.LOCAL), + ], + _INDEX_DTYPE, + ) + data_types[table_name] = arg_type + # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) @@ -140,28 +154,17 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: me, mx = self.head_state.add_map(unique_name("map"), map_ranges) for arg_node, (lambda_node, lambda_connector, index_offset) in input_connections: - offset = None - if arg_node.data.startswith("__connectivity"): - self.sdfg.arrays[arg_node.data].transient = False - offset = arg_node.data.removeprefix("__connectivity") - offset_provider = self.offset_provider[offset] - assert isinstance(offset_provider, Connectivity) - type_ = tt.from_type_hint(offset_provider.index_type) - assert isinstance(type_, ts.ScalarType) - arg_type = ts.FieldType([offset_provider.origin_axis], type_) - else: - assert arg_node.data in data_types - arg_type = data_types[arg_node.data] - if len(lambda_node.in_connectors) != 1 or offset is not None: - # indirection tasklet with explicit indexes - memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) - elif isinstance(arg_type, ts.ScalarType): + assert arg_node.data in data_types + arg_type = data_types[arg_node.data] + if isinstance(arg_type, ts.ScalarType): memlet = dace.Memlet(data=arg_node.data, subset="0", volume=1) + elif lambda_node.label == "deref_field_indirection" and lambda_connector == "field": + # indirection tasklet with explicit indexes besides the field argument + memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) else: # read one field element through memlet subset - assert all(dim in self.field_domain for dim in arg_type.dims) subset = ",".join( - f"i_{dim.value} + ({off})" + f"{off}" if dim.kind == DimensionKind.LOCAL else f"i_{dim.value} + ({off})" for dim, off in zip(arg_type.dims, index_offset, strict=True) ) memlet = dace.Memlet(data=arg_node.data, subset=subset, volume=1) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index f60c3b5552..04bd787022 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -31,7 +31,7 @@ connectivity_identifier, unique_name, ) -from gt4py.next.type_system import type_specifications as ts, type_translation as tt +from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) @@ -59,6 +59,15 @@ class ValueExpr: dtype: dace.typeclass +@dataclass(frozen=True) +class ConnectivityExpr: + """Provides access to connectivity table by means of memlet.""" + + table: dace.nodes.AccessNode + # TODO: should value be `int | str` instead? + value: int + + @dataclass(frozen=True) class IteratorExpr: """Iterator to access the field provided by an array access node.""" @@ -66,7 +75,7 @@ class IteratorExpr: field: dace.nodes.AccessNode dimensions: list[str] offset: list[int] - indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] + indices: dict[str, SymbolExpr | TaskletExpr | ConnectivityExpr] dtype: dace.typeclass @@ -119,10 +128,11 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: "deref_field", {"field"}, {"val"}, code="val = field" ) else: + assert all(dim in it.dimensions for dim in it.indices.keys()) index_connectors = [ f"__i_{dim}" for dim, index in it.indices.items() - if isinstance(index, TaskletExpr | ValueError) + if isinstance(index, TaskletExpr | ConnectivityExpr) ] sorted_indices = [it.indices[dim] for dim in it.dimensions] index_internals = ",".join( @@ -130,7 +140,7 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: for dim, index in zip(it.dimensions, sorted_indices) ) deref_node = self.state.add_tasklet( - "deref_field_indirecton", + "deref_field_indirection", {"field"} | set(index_connectors), {"val"}, code=f"val = field[{index_internals}]", @@ -145,13 +155,9 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: deref_connector, dace.Memlet(), ) - elif isinstance(index_expr, ValueExpr): - self.state.add_edge( - index_expr.value, - None, - deref_node, - deref_connector, - dace.Memlet(data=index_expr.value.data, subset="0"), + elif isinstance(index_expr, ConnectivityExpr): + self.input_connections.append( + (index_expr.table, (deref_node, deref_connector, [0, index_expr.value])) ) else: assert isinstance(index_expr, SymbolExpr) @@ -214,28 +220,18 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: origin_dim = offset_provider.origin_axis.value assert origin_dim not in it.indices neighbor_dim = offset_provider.neighbor_axis.value - assert neighbor_dim not in it.indices - dim_index = it.dimensions.index(neighbor_dim) - - shift_node = self.state.add_tasklet( - "shift", - {"table"}, - {"target_index"}, - f"target_index = table[i_{origin_dim}, {offset_value}]", - ) - + assert neighbor_dim in it.dimensions offset_table = connectivity_identifier(offset) + # initially, the storage for the connectivty tables is created as transient + # when the tables are used, the storage is changed to non-transient, + # so the corresponding arrays are supposed to be allocated by the SDFG caller + self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) - self.input_connections.append((offset_table_node, (shift_node, "table", [0, 0]))) - - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - dtype = as_dace_type(ts.ScalarType(scalar_kind)) - shifted_it = IteratorExpr( it.field, - [origin_dim if i == dim_index else dim for i, dim in enumerate(it.dimensions)], + [origin_dim if dim == neighbor_dim else dim for dim in it.dimensions], it.offset, - {origin_dim: TaskletExpr(shift_node, "target_index", dtype)}, + {origin_dim: ConnectivityExpr(offset_table_node, offset_value)}, it.dtype, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 9c0b9c64b0..7417fd0d6f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -459,6 +459,8 @@ def test_gtir_connectivity_shift(): vertices=v, __connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, + ____connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, + ____connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, ____connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, ____connectivity_V2E_stride_1=1, ) From e5494d892ed5c780915b521df69da0427cc51bf5 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 13 May 2024 10:13:17 +0200 Subject: [PATCH 046/235] Remove ConnectivityExpr type, use ValueExpr instead --- .../gtir_builtin_field_operator.py | 50 +++---- .../runners/dace_fieldview/gtir_to_tasklet.py | 129 ++++++++++-------- .../runners/dace_fieldview/utility.py | 2 +- .../runners_tests/test_dace_fieldview.py | 10 +- 4 files changed, 96 insertions(+), 95 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 31081aa710..2805312037 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -17,14 +17,13 @@ import dace -from gt4py.next.common import Connectivity, Dimension, DimensionKind +from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( - ConnectivityExpr, GTIRToTasklet, IteratorExpr, SymbolExpr, @@ -33,8 +32,6 @@ ) from gt4py.next.program_processors.runners.dace_fieldview.utility import ( as_dace_type, - connectivity_identifier, - filter_connectivities, get_domain, unique_name, ) @@ -89,8 +86,8 @@ def __init__( self.stencil_args = stencil_args def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + dimension_index_fmt = "i_{dim}" # first visit the list of arguments and build a symbol map - data_types: dict[str, ts.FieldType | ts.ScalarType] = {} stencil_args: list[IteratorExpr | ValueExpr] = [] for arg in self.stencil_args: arg_nodes = arg() @@ -98,19 +95,17 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: arg_node, arg_type = arg_nodes[0] # require all argument nodes to be data access nodes (no symbols) assert isinstance(arg_node, dace.nodes.AccessNode) - data_types[arg_node.data] = arg_type if isinstance(arg_type, ts.ScalarType): dtype = as_dace_type(arg_type) - scalar_arg = ValueExpr(arg_node, dtype) + scalar_arg = ValueExpr(arg_node, [0], dtype) stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) dtype = as_dace_type(arg_type.dtype) - indices: dict[str, SymbolExpr | TaskletExpr | ConnectivityExpr] = { + indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] = { dim.value: SymbolExpr(f"i_{dim.value}", _INDEX_DTYPE) - for dim in arg_type.dims - if dim in self.field_domain + for dim in self.field_domain.keys() } iterator_arg = IteratorExpr( arg_node, @@ -121,17 +116,6 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ) stencil_args.append(iterator_arg) - for offset, connectivity in filter_connectivities(self.offset_provider).items(): - table_name = connectivity_identifier(offset) - arg_type = ts.FieldType( - [ - connectivity.origin_axis, - Dimension(f"neighbor_{connectivity.neighbor_axis}", DimensionKind.LOCAL), - ], - _INDEX_DTYPE, - ) - data_types[table_name] = arg_type - # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) @@ -146,28 +130,26 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: field_node = self.add_local_storage(self.field_type, field_shape) # assume tasklet with single output - output_index = ",".join(f"i_{dim.value}" for dim in self.field_type.dims) + output_index = ",".join( + dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims + ) output_memlet = dace.Memlet(data=field_node.data, subset=output_index) # create map range corresponding to the field operator domain - map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} + map_ranges = { + dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" + for dim, (lb, ub) in self.field_domain.items() + } me, mx = self.head_state.add_map(unique_name("map"), map_ranges) - for arg_node, (lambda_node, lambda_connector, index_offset) in input_connections: - assert arg_node.data in data_types - arg_type = data_types[arg_node.data] - if isinstance(arg_type, ts.ScalarType): - memlet = dace.Memlet(data=arg_node.data, subset="0", volume=1) - elif lambda_node.label == "deref_field_indirection" and lambda_connector == "field": + for arg_node, lambda_node, lambda_connector, data_index in input_connections: + if lambda_node.label == "deref_field_indirection" and lambda_connector == "field": # indirection tasklet with explicit indexes besides the field argument memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) else: # read one field element through memlet subset - subset = ",".join( - f"{off}" if dim.kind == DimensionKind.LOCAL else f"i_{dim.value} + ({off})" - for dim, off in zip(arg_type.dims, index_offset, strict=True) - ) - memlet = dace.Memlet(data=arg_node.data, subset=subset, volume=1) + data_subset = ",".join(str(index) for index in data_index) + memlet = dace.Memlet(data=arg_node.data, subset=data_subset, volume=1) self.head_state.add_memlet_path( arg_node, me, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 04bd787022..020a9ca8f3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -31,7 +31,7 @@ connectivity_identifier, unique_name, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_specifications as ts, type_translation as tt @dataclass(frozen=True) @@ -55,27 +55,19 @@ class TaskletExpr: class ValueExpr: """Data provided by a scalar access node.""" - value: dace.nodes.AccessNode + data: dace.nodes.AccessNode + subset: list[dace.symbolic.SymbolicType] dtype: dace.typeclass -@dataclass(frozen=True) -class ConnectivityExpr: - """Provides access to connectivity table by means of memlet.""" - - table: dace.nodes.AccessNode - # TODO: should value be `int | str` instead? - value: int - - @dataclass(frozen=True) class IteratorExpr: """Iterator to access the field provided by an array access node.""" field: dace.nodes.AccessNode dimensions: list[str] - offset: list[int] - indices: dict[str, SymbolExpr | TaskletExpr | ConnectivityExpr] + offset: list[dace.symbolic.SymbolicType] + indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] dtype: dace.typeclass @@ -84,7 +76,9 @@ class GTIRToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - input_connections: list[tuple[dace.nodes.AccessNode, tuple[dace.nodes.Tasklet, str, list[int]]]] + input_connections: list[ + tuple[dace.nodes.AccessNode, dace.nodes.Tasklet, str, list[dace.symbolic.SymbolicType]] + ] offset_provider: dict[str, Connectivity | Dimension] symbol_map: dict[str, IteratorExpr | SymbolExpr | ValueExpr] @@ -100,7 +94,7 @@ def __init__( self.offset_provider = offset_provider self.symbol_map = {} - def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: + def _visit_deref(self, node: itir.FunCall) -> TaskletExpr | ValueExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) @@ -110,33 +104,30 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: deref_node = self.state.add_tasklet( "deref_symbol", {}, {"val"}, code=f"val = {cast_fmt.format(it.value)}" ) - deref_expr = TaskletExpr(deref_node, "val", it.dtype) - - elif isinstance(it, ValueExpr): - deref_node = self.state.add_tasklet( - "deref_scalar", {"scalar"}, {"val"}, code="val = scalar" - ) - deref_expr = TaskletExpr(deref_node, "val", it.dtype) + return TaskletExpr(deref_node, "val", it.dtype) - # add new termination point for the data access - self.input_connections.append((it.value, (deref_node, "scalar", []))) - - else: + elif isinstance(it, IteratorExpr): if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # use direct field access through memlet subset - deref_node = self.state.add_tasklet( - "deref_field", {"field"}, {"val"}, code="val = field" - ) + data_index = [ + dace.symbolic.SymExpr(it.indices[dim].value) + off # type: ignore[union-attr] + for dim, off in zip(it.dimensions, it.offset, strict=True) + ] + return ValueExpr(it.field, data_index, it.dtype) + else: - assert all(dim in it.dimensions for dim in it.indices.keys()) + input_connector_fmt = "__inp_{dim}" + assert all(dim in it.indices.keys() for dim in it.dimensions) index_connectors = [ - f"__i_{dim}" + input_connector_fmt.format(dim=dim) for dim, index in it.indices.items() - if isinstance(index, TaskletExpr | ConnectivityExpr) + if not isinstance(index, SymbolExpr) ] sorted_indices = [it.indices[dim] for dim in it.dimensions] index_internals = ",".join( - index.value if isinstance(index, SymbolExpr) else f"__i_{dim}" + index.value + if isinstance(index, SymbolExpr) + else input_connector_fmt.format(dim=dim) for dim, index in zip(it.dimensions, sorted_indices) ) deref_node = self.state.add_tasklet( @@ -145,8 +136,11 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: {"val"}, code=f"val = field[{index_internals}]", ) + # add new termination point for this field parameter + self.input_connections.append((it.field, deref_node, "field", it.offset)) + for dim, index_expr in it.indices.items(): - deref_connector = f"__i_{dim}" + deref_connector = input_connector_fmt.format(dim=dim) if isinstance(index_expr, TaskletExpr): self.state.add_edge( index_expr.node, @@ -155,19 +149,23 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr: deref_connector, dace.Memlet(), ) - elif isinstance(index_expr, ConnectivityExpr): + elif isinstance(index_expr, ValueExpr): self.input_connections.append( - (index_expr.table, (deref_node, deref_connector, [0, index_expr.value])) + ( + index_expr.data, + deref_node, + deref_connector, + index_expr.subset, + ) ) else: assert isinstance(index_expr, SymbolExpr) - deref_expr = TaskletExpr(deref_node, "val", it.dtype) + return TaskletExpr(deref_node, "val", it.dtype) - # add new termination point for this field parameter - self.input_connections.append((it.field, (deref_node, "field", it.offset))) - - return deref_expr + else: + assert isinstance(it, ValueExpr) + return it def _split_shift_args( self, args: list[itir.Expr] @@ -218,10 +216,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: else: # shift in unstructured domain by means of a neighbor table origin_dim = offset_provider.origin_axis.value - assert origin_dim not in it.indices + assert origin_dim in it.indices + origin_index = it.indices[origin_dim] + assert isinstance(origin_index, SymbolExpr) neighbor_dim = offset_provider.neighbor_axis.value assert neighbor_dim in it.dimensions offset_table = connectivity_identifier(offset) + index_scalar_type = ts.ScalarType(tt.get_scalar_kind(offset_provider.index_type)) + offset_dtype = as_dace_type(index_scalar_type) # initially, the storage for the connectivty tables is created as transient # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller @@ -231,13 +233,17 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: it.field, [origin_dim if dim == neighbor_dim else dim for dim in it.dimensions], it.offset, - {origin_dim: ConnectivityExpr(offset_table_node, offset_value)}, + { + origin_dim: ValueExpr( + offset_table_node, [origin_index.value, offset_value], offset_dtype + ) + }, it.dtype, ) return shifted_it - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr: + def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | ValueExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -248,18 +254,18 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr: assert isinstance(node.fun, itir.SymRef) node_internals = [] - node_connections = {} + node_connections: dict[str, TaskletExpr | ValueExpr] = {} for i, arg in enumerate(node.args): arg_expr = self.visit(arg) - if isinstance(arg_expr, SymbolExpr): - # use the argument value without adding any connector - node_internals.append(arg_expr.value) - else: - assert isinstance(arg_expr, TaskletExpr) - # the argument value is the result of a tasklet node + if isinstance(arg_expr, TaskletExpr | ValueExpr): + # the argument value is the result of a tasklet node or direct field access connector = f"__inp_{i}" node_connections[connector] = arg_expr node_internals.append(connector) + else: + assert isinstance(arg_expr, SymbolExpr) + # use the argument value without adding any connector + node_internals.append(arg_expr.value) # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) @@ -281,9 +287,14 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr: ) for connector, arg_expr in node_connections.items(): - self.state.add_edge( - arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() - ) + if isinstance(arg_expr, TaskletExpr): + self.state.add_edge( + arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() + ) + else: + self.input_connections.append( + (arg_expr.data, tasklet_node, connector, arg_expr.subset) + ) dtype = as_dace_type(node_type) return TaskletExpr(tasklet_node, "result", dtype) @@ -291,7 +302,15 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr: def visit_Lambda( self, node: itir.Lambda, args: list[IteratorExpr | SymbolExpr | ValueExpr] ) -> tuple[ - list[tuple[dace.nodes.AccessNode, tuple[dace.nodes.Tasklet, str, list[int]]]], TaskletExpr + list[ + tuple[ + dace.nodes.AccessNode, + dace.nodes.Tasklet, + str, + list[dace.symbolic.SymbolicType], + ] + ], + TaskletExpr, ]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 8bef7e9be5..e76f564526 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -43,7 +43,7 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: def connectivity_identifier(name: str) -> str: - return f"__connectivity_{name}" + return f"connectivity_{name}" def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 7417fd0d6f..3651eda82a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -457,11 +457,11 @@ def test_gtir_connectivity_shift(): sdfg( edges=e, vertices=v, - __connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - ____connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, - ____connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - ____connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - ____connectivity_V2E_stride_1=1, + __connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, + __connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_1=1, ) assert np.allclose(v, e[connectivity_V2E.table[:, 1]]) From e9455e3f42f004db718d1aae243dcf80dcaa7b17 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 13 May 2024 16:10:51 +0200 Subject: [PATCH 047/235] Changes in preparation for shift builtin --- .../gtir_builtin_field_operator.py | 93 ++++---- .../runners/dace_fieldview/gtir_to_tasklet.py | 200 +++++++++++------- 2 files changed, 178 insertions(+), 115 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 2191277cd2..09e2d4facb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -16,6 +16,7 @@ from typing import Callable, TypeAlias import dace +import dace.subsets as sbs from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir @@ -25,12 +26,19 @@ ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( GTIRToTasklet, - ValueExpr, + IteratorExpr, + MemletExpr, + SymbolExpr, + TaskletExpr, ) from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name from gt4py.next.type_system import type_specifications as ts +# Define type of variables used for field indexing +_INDEX_DTYPE = dace.int64 + + class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" @@ -75,11 +83,40 @@ def __init__( self.stencil_args = stencil_args def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: - # generate a tasklet node implementing the stencil function and represent - # the field operator as a mapped tasklet, which will range over the field domain + dimension_index_fmt = "i_{dim}" + # first visit the list of arguments and build a symbol map + stencil_args: list[IteratorExpr | MemletExpr] = [] + for arg in self.stencil_args: + arg_nodes = arg() + assert len(arg_nodes) == 1 + data_node, arg_type = arg_nodes[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(data_node, dace.nodes.AccessNode) + + if isinstance(arg_type, ts.ScalarType): + scalar_arg = MemletExpr(data_node, sbs.Indices([0])) + stencil_args.append(scalar_arg) + else: + assert isinstance(arg_type, ts.FieldType) + indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] = { + dim.value: SymbolExpr( + dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), + _INDEX_DTYPE, + ) + for dim in self.field_domain.keys() + } + iterator_arg = IteratorExpr( + data_node, + [dim.value for dim in arg_type.dims], + sbs.Indices([0] * len(arg_type.dims)), + indices, + ) + stencil_args.append(iterator_arg) + + # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr) - assert isinstance(output_expr, ValueExpr) + input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + assert isinstance(output_expr, TaskletExpr) # allocate local temporary storage for the result field field_shape = [ @@ -89,47 +126,27 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: ] field_node = self.add_local_storage(self.field_type, field_shape) - data_nodes: dict[str, dace.nodes.AccessNode] = {} - input_memlets: dict[str, dace.Memlet] = {} - for arg, param in zip(self.stencil_args, self.stencil_expr.params, strict=True): - arg_nodes = arg() - assert len(arg_nodes) == 1 - arg_node, arg_type = arg_nodes[0] - data = str(param.id) - # require (for now) all input nodes to be data access nodes - assert isinstance(arg_node, dace.nodes.AccessNode) - data_nodes[data] = arg_node - if isinstance(arg_type, ts.FieldType): - # support either single element access (general case) or full array shape - is_scalar = all(dim in self.field_domain for dim in arg_type.dims) - if is_scalar: - subset = ",".join(f"i_{dim.value}" for dim in arg_type.dims) - input_memlets[data] = dace.Memlet(data=arg_node.data, subset=subset, volume=1) - else: - memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) - # set volume to 1 because the stencil function always performs single element access - # TODO: check validity of this assumption - memlet.volume = 1 - input_memlets[data] = memlet - else: - input_memlets[data] = dace.Memlet(data=arg_node.data, subset="0") - # assume tasklet with single output - output_index = ",".join(f"i_{dim.value}" for dim in self.field_type.dims) + output_index = ",".join( + dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims + ) output_memlet = dace.Memlet(data=field_node.data, subset=output_index) # create map range corresponding to the field operator domain - map_ranges = {f"i_{dim.value}": f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items()} + map_ranges = { + dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" + for dim, (lb, ub) in self.field_domain.items() + } me, mx = self.head_state.add_map(unique_name("map"), map_ranges) - for input_expr, input_param in input_connections: - assert input_param in data_nodes + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset, volume=1) self.head_state.add_memlet_path( - data_nodes[input_param], + data_node, me, - input_expr.node, - dst_conn=input_expr.connector, - memlet=input_memlets[input_param], + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, ) self.head_state.add_memlet_path( output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index f538198eb9..dafb61bc34 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -14,8 +14,10 @@ from dataclasses import dataclass +from typing import TypeAlias import dace +import dace.subsets as sbs from gt4py import eve from gt4py.next.common import Connectivity, Dimension @@ -28,34 +30,55 @@ @dataclass(frozen=True) -class LiteralExpr: - """Any symbolic expression that can be evaluated at compile time.""" +class MemletExpr: + """Scalar or array data access thorugh a memlet.""" - value: dace.symbolic.SymbolicType + data: dace.nodes.AccessNode + subset: sbs.Indices | sbs.Range @dataclass(frozen=True) class SymbolExpr: - """The data access to a scalar or field through a symbolic reference.""" + """Any symbolic expression that is constant in the context of current SDFG.""" - data: str + value: dace.symbolic.SymExpr + dtype: dace.typeclass @dataclass(frozen=True) -class ValueExpr: - """The result of a computation provided by a tasklet node.""" +class TaskletExpr: + """Result of the computation provided by a tasklet node.""" node: dace.nodes.Tasklet connector: str +@dataclass(frozen=True) +class IteratorExpr: + """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" + + field: dace.nodes.AccessNode + dimensions: list[str] + offset: list[dace.symbolic.SymExpr] + indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] + + +InputConnection: TypeAlias = tuple[ + dace.nodes.AccessNode, + sbs.Range, + dace.nodes.Tasklet, + str, +] + + class GTIRToTasklet(eve.NodeVisitor): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" sdfg: dace.SDFG state: dace.SDFGState - input_connections: list[tuple[ValueExpr, str]] + input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] + symbol_map: dict[str, SymbolExpr | IteratorExpr | MemletExpr] def __init__( self, @@ -67,87 +90,110 @@ def __init__( self.state = state self.input_connections = [] self.offset_provider = offset_provider + self.symbol_map = {} - def _visit_shift(self, node: itir.FunCall) -> str: - assert len(node.args) == 2 - raise NotImplementedError + def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: + assert len(node.args) == 1 + it = self.visit(node.args[0]) + + if isinstance(it, SymbolExpr): + cast_sym = str(it.dtype) + cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] + deref_node = self.state.add_tasklet( + "deref_symbol", {}, {"val"}, code=f"val = {cast_fmt.format(it.value)}" + ) + return TaskletExpr(deref_node, "val") + + elif isinstance(it, IteratorExpr): + if all(isinstance(index, SymbolExpr) for index in it.indices.values()): + # use direct field access through memlet subset + data_index = sbs.Indices( + [ + it.indices[dim].value + off # type: ignore[union-attr] + for dim, off in zip(it.dimensions, it.offset, strict=True) + ] + ) + return MemletExpr(it.field, data_index) - def visit_FunCall(self, node: itir.FunCall) -> ValueExpr | SymbolExpr: - inp_tasklets = {} - inp_symbols = set() - inp_connectors = set() - node_internals = [] - for i, arg in enumerate(node.args): - arg_expr = self.visit(arg) - if isinstance(arg_expr, LiteralExpr): - # use the value without adding any connector - node_internals.append(arg_expr.value) else: - if isinstance(arg_expr, ValueExpr): - # the value is the result of a tasklet node - connector = f"__inp_{i}" - inp_tasklets[connector] = arg_expr - else: - # the value is the result of a tasklet node - assert isinstance(arg_expr, SymbolExpr) - connector = f"__inp_{arg_expr.data}" - inp_symbols.add((connector, arg_expr.data)) - inp_connectors.add(connector) - node_internals.append(connector) + raise NotImplementedError + + else: + assert isinstance(it, MemletExpr) + return it + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: + raise NotImplementedError + + def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | MemletExpr: if cpm.is_call_to(node, "deref"): - assert len(inp_tasklets) == 0 - assert len(inp_symbols) == 1 - _, data = inp_symbols.pop() - return SymbolExpr(data) + return self._visit_deref(node) elif cpm.is_call_to(node.fun, "shift"): - code = self._visit_shift(node.fun) - - elif isinstance(node.fun, itir.SymRef): - # create a tasklet node implementing the builtin function - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[builtin_name] - code = fmt.format(*node_internals) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - - out_connector = "__out" - tasklet_node = self.state.add_tasklet( - unique_name("tasklet"), - inp_connectors, - {out_connector}, - "{} = {}".format(out_connector, code), - ) + return self._visit_shift(node) else: - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + assert isinstance(node.fun, itir.SymRef) - for input_conn, inp_expr in inp_tasklets.items(): - self.state.add_edge( - inp_expr.node, inp_expr.connector, tasklet_node, input_conn, dace.Memlet() - ) - self.input_connections.extend( - (ValueExpr(tasklet_node, connector), data) for connector, data in inp_symbols + node_internals = [] + node_connections: dict[str, MemletExpr | TaskletExpr] = {} + for i, arg in enumerate(node.args): + arg_expr = self.visit(arg) + if isinstance(arg_expr, MemletExpr | TaskletExpr): + # the argument value is the result of a tasklet node or direct field access + connector = f"__inp_{i}" + node_connections[connector] = arg_expr + node_internals.append(connector) + else: + assert isinstance(arg_expr, SymbolExpr) + # use the argument value without adding any connector + node_internals.append(arg_expr.value) + + # create a tasklet node implementing the builtin function + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin_name] + code = fmt.format(*node_internals) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + + out_connector = "result" + tasklet_node = self.state.add_tasklet( + unique_name("tasklet"), + node_connections.keys(), + {out_connector}, + "{} = {}".format(out_connector, code), ) - return ValueExpr(tasklet_node, out_connector) - def visit_Lambda(self, node: itir.Lambda) -> tuple[list[tuple[ValueExpr, str]], ValueExpr]: - assert len(self.input_connections) == 0 + for connector, arg_expr in node_connections.items(): + if isinstance(arg_expr, TaskletExpr): + self.state.add_edge( + arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() + ) + else: + self.input_connections.append( + (arg_expr.data, arg_expr.subset, tasklet_node, connector) + ) + + return TaskletExpr(tasklet_node, "result") + + def visit_Lambda( + self, node: itir.Lambda, args: list[SymbolExpr | IteratorExpr | MemletExpr] + ) -> tuple[ + list[InputConnection], + TaskletExpr, + ]: + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg output_expr = self.visit(node.expr) - assert isinstance(output_expr, ValueExpr) + assert isinstance(output_expr, TaskletExpr) return self.input_connections, output_expr - def visit_Literal(self, node: itir.Literal) -> LiteralExpr: - cast_sym = str(as_dace_type(node.type)) - cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] - typed_value = cast_fmt.format(node.value) - return LiteralExpr(typed_value) - - def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr: - """ - Symbol references are mapped to tasklet connectors that access some kind of data. - """ - sym_name = str(node.id) - return SymbolExpr(sym_name) + def visit_Literal(self, node: itir.Literal) -> SymbolExpr: + dtype = as_dace_type(node.type) + return SymbolExpr(node.value, dtype) + + def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr | IteratorExpr | MemletExpr: + param = str(node.id) + assert param in self.symbol_map + return self.symbol_map[param] From cbf55deb90d49331ad5f82fd77e6bb1bd914e72a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 13 May 2024 15:06:24 +0200 Subject: [PATCH 048/235] Refactoring --- .../gtir_builtin_field_operator.py | 43 +++--- .../runners/dace_fieldview/gtir_to_tasklet.py | 133 +++++++++--------- 2 files changed, 80 insertions(+), 96 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 2805312037..09e2d4facb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -16,6 +16,7 @@ from typing import Callable, TypeAlias import dace +import dace.subsets as sbs from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir @@ -26,15 +27,11 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( GTIRToTasklet, IteratorExpr, + MemletExpr, SymbolExpr, TaskletExpr, - ValueExpr, -) -from gt4py.next.program_processors.runners.dace_fieldview.utility import ( - as_dace_type, - get_domain, - unique_name, ) +from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name from gt4py.next.type_system import type_specifications as ts @@ -88,31 +85,31 @@ def __init__( def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: dimension_index_fmt = "i_{dim}" # first visit the list of arguments and build a symbol map - stencil_args: list[IteratorExpr | ValueExpr] = [] + stencil_args: list[IteratorExpr | MemletExpr] = [] for arg in self.stencil_args: arg_nodes = arg() assert len(arg_nodes) == 1 - arg_node, arg_type = arg_nodes[0] + data_node, arg_type = arg_nodes[0] # require all argument nodes to be data access nodes (no symbols) - assert isinstance(arg_node, dace.nodes.AccessNode) + assert isinstance(data_node, dace.nodes.AccessNode) if isinstance(arg_type, ts.ScalarType): - dtype = as_dace_type(arg_type) - scalar_arg = ValueExpr(arg_node, [0], dtype) + scalar_arg = MemletExpr(data_node, sbs.Indices([0])) stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) - dtype = as_dace_type(arg_type.dtype) - indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] = { - dim.value: SymbolExpr(f"i_{dim.value}", _INDEX_DTYPE) + indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] = { + dim.value: SymbolExpr( + dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), + _INDEX_DTYPE, + ) for dim in self.field_domain.keys() } iterator_arg = IteratorExpr( - arg_node, + data_node, [dim.value for dim in arg_type.dims], - [0] * len(arg_type.dims), + sbs.Indices([0] * len(arg_type.dims)), indices, - dtype, ) stencil_args.append(iterator_arg) @@ -142,16 +139,10 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: } me, mx = self.head_state.add_map(unique_name("map"), map_ranges) - for arg_node, lambda_node, lambda_connector, data_index in input_connections: - if lambda_node.label == "deref_field_indirection" and lambda_connector == "field": - # indirection tasklet with explicit indexes besides the field argument - memlet = dace.Memlet.from_array(arg_node.data, arg_node.desc(self.sdfg)) - else: - # read one field element through memlet subset - data_subset = ",".join(str(index) for index in data_index) - memlet = dace.Memlet(data=arg_node.data, subset=data_subset, volume=1) + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset, volume=1) self.head_state.add_memlet_path( - arg_node, + data_node, me, lambda_node, dst_conn=lambda_connector, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 020a9ca8f3..62579ff721 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -15,9 +15,10 @@ import itertools from dataclasses import dataclass -from typing import Optional +from typing import Optional, TypeAlias import dace +import dace.subsets as sbs from gt4py import eve from gt4py.next.common import Connectivity, Dimension @@ -31,14 +32,21 @@ connectivity_identifier, unique_name, ) -from gt4py.next.type_system import type_specifications as ts, type_translation as tt + + +@dataclass(frozen=True) +class MemletExpr: + """Scalar or array data access thorugh a memlet.""" + + data: dace.nodes.AccessNode + subset: sbs.Indices | sbs.Range @dataclass(frozen=True) class SymbolExpr: """Any symbolic expression that is constant in the context of current SDFG.""" - value: dace.symbolic.SymbolicType + value: dace.symbolic.SymExpr dtype: dace.typeclass @@ -48,27 +56,24 @@ class TaskletExpr: node: dace.nodes.Tasklet connector: str - dtype: dace.typeclass - - -@dataclass(frozen=True) -class ValueExpr: - """Data provided by a scalar access node.""" - - data: dace.nodes.AccessNode - subset: list[dace.symbolic.SymbolicType] - dtype: dace.typeclass @dataclass(frozen=True) class IteratorExpr: - """Iterator to access the field provided by an array access node.""" + """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode dimensions: list[str] - offset: list[dace.symbolic.SymbolicType] - indices: dict[str, SymbolExpr | TaskletExpr | ValueExpr] - dtype: dace.typeclass + offset: list[dace.symbolic.SymExpr] + indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] + + +InputConnection: TypeAlias = tuple[ + dace.nodes.AccessNode, + sbs.Range, + dace.nodes.Tasklet, + str, +] class GTIRToTasklet(eve.NodeVisitor): @@ -76,11 +81,9 @@ class GTIRToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - input_connections: list[ - tuple[dace.nodes.AccessNode, dace.nodes.Tasklet, str, list[dace.symbolic.SymbolicType]] - ] + input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] - symbol_map: dict[str, IteratorExpr | SymbolExpr | ValueExpr] + symbol_map: dict[str, SymbolExpr | IteratorExpr | MemletExpr] def __init__( self, @@ -94,7 +97,7 @@ def __init__( self.offset_provider = offset_provider self.symbol_map = {} - def _visit_deref(self, node: itir.FunCall) -> TaskletExpr | ValueExpr: + def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) @@ -104,16 +107,18 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr | ValueExpr: deref_node = self.state.add_tasklet( "deref_symbol", {}, {"val"}, code=f"val = {cast_fmt.format(it.value)}" ) - return TaskletExpr(deref_node, "val", it.dtype) + return TaskletExpr(deref_node, "val") elif isinstance(it, IteratorExpr): if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # use direct field access through memlet subset - data_index = [ - dace.symbolic.SymExpr(it.indices[dim].value) + off # type: ignore[union-attr] - for dim, off in zip(it.dimensions, it.offset, strict=True) - ] - return ValueExpr(it.field, data_index, it.dtype) + data_index = sbs.Indices( + [ + it.indices[dim].value + off # type: ignore[union-attr] + for dim, off in zip(it.dimensions, it.offset, strict=True) + ] + ) + return MemletExpr(it.field, data_index) else: input_connector_fmt = "__inp_{dim}" @@ -137,34 +142,36 @@ def _visit_deref(self, node: itir.FunCall) -> TaskletExpr | ValueExpr: code=f"val = field[{index_internals}]", ) # add new termination point for this field parameter - self.input_connections.append((it.field, deref_node, "field", it.offset)) + field_desc = it.field.desc(self.sdfg) + field_fullset = sbs.Range.from_array(field_desc) + self.input_connections.append((it.field, field_fullset, deref_node, "field")) for dim, index_expr in it.indices.items(): deref_connector = input_connector_fmt.format(dim=dim) - if isinstance(index_expr, TaskletExpr): - self.state.add_edge( - index_expr.node, - index_expr.connector, - deref_node, - deref_connector, - dace.Memlet(), - ) - elif isinstance(index_expr, ValueExpr): + if isinstance(index_expr, MemletExpr): self.input_connections.append( ( index_expr.data, + index_expr.subset, deref_node, deref_connector, - index_expr.subset, ) ) + elif isinstance(index_expr, TaskletExpr): + self.state.add_edge( + index_expr.node, + index_expr.connector, + deref_node, + deref_connector, + dace.Memlet(), + ) else: assert isinstance(index_expr, SymbolExpr) - return TaskletExpr(deref_node, "val", it.dtype) + return TaskletExpr(deref_node, "val") else: - assert isinstance(it, ValueExpr) + assert isinstance(it, MemletExpr) return it def _split_shift_args( @@ -200,19 +207,18 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset_provider = self.offset_provider[offset] # second argument should be the offset value if isinstance(head[1], itir.OffsetLiteral): + assert isinstance(head[1].value, int) offset_value = head[1].value - assert isinstance(offset_value, int) else: raise NotImplementedError("Dynamic offset not supported.") if isinstance(offset_provider, Dimension): # cartesian offset along one dimension - dim_index = it.dimensions.index(offset_provider.value) new_offset = [ - prev_offset + offset_value if i == dim_index else prev_offset - for i, prev_offset in enumerate(it.offset) + prev_offset + offset_value if dim == offset_provider.value else prev_offset + for dim, prev_offset in zip(it.dimensions, it.offset, strict=True) ] - shifted_it = IteratorExpr(it.field, it.dimensions, new_offset, it.indices, it.dtype) + shifted_it = IteratorExpr(it.field, it.dimensions, new_offset, it.indices) else: # shift in unstructured domain by means of a neighbor table origin_dim = offset_provider.origin_axis.value @@ -222,8 +228,6 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: neighbor_dim = offset_provider.neighbor_axis.value assert neighbor_dim in it.dimensions offset_table = connectivity_identifier(offset) - index_scalar_type = ts.ScalarType(tt.get_scalar_kind(offset_provider.index_type)) - offset_dtype = as_dace_type(index_scalar_type) # initially, the storage for the connectivty tables is created as transient # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller @@ -234,16 +238,16 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: [origin_dim if dim == neighbor_dim else dim for dim in it.dimensions], it.offset, { - origin_dim: ValueExpr( - offset_table_node, [origin_index.value, offset_value], offset_dtype + origin_dim: MemletExpr( + offset_table_node, + sbs.Indices([origin_index.value, offset_value]), ) }, - it.dtype, ) return shifted_it - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | ValueExpr: + def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | MemletExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -254,10 +258,10 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Valu assert isinstance(node.fun, itir.SymRef) node_internals = [] - node_connections: dict[str, TaskletExpr | ValueExpr] = {} + node_connections: dict[str, MemletExpr | TaskletExpr] = {} for i, arg in enumerate(node.args): arg_expr = self.visit(arg) - if isinstance(arg_expr, TaskletExpr | ValueExpr): + if isinstance(arg_expr, MemletExpr | TaskletExpr): # the argument value is the result of a tasklet node or direct field access connector = f"__inp_{i}" node_connections[connector] = arg_expr @@ -267,9 +271,6 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Valu # use the argument value without adding any connector node_internals.append(arg_expr.value) - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - # create a tasklet node implementing the builtin function builtin_name = str(node.fun.id) if builtin_name in MATH_BUILTINS_MAPPING: @@ -293,23 +294,15 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Valu ) else: self.input_connections.append( - (arg_expr.data, tasklet_node, connector, arg_expr.subset) + (arg_expr.data, arg_expr.subset, tasklet_node, connector) ) - dtype = as_dace_type(node_type) - return TaskletExpr(tasklet_node, "result", dtype) + return TaskletExpr(tasklet_node, "result") def visit_Lambda( - self, node: itir.Lambda, args: list[IteratorExpr | SymbolExpr | ValueExpr] + self, node: itir.Lambda, args: list[SymbolExpr | IteratorExpr | MemletExpr] ) -> tuple[ - list[ - tuple[ - dace.nodes.AccessNode, - dace.nodes.Tasklet, - str, - list[dace.symbolic.SymbolicType], - ] - ], + list[InputConnection], TaskletExpr, ]: for p, arg in zip(node.params, args, strict=True): @@ -322,7 +315,7 @@ def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = as_dace_type(node.type) return SymbolExpr(node.value, dtype) - def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | SymbolExpr | ValueExpr: + def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr | IteratorExpr | MemletExpr: param = str(node.id) assert param in self.symbol_map return self.symbol_map[param] From c45c41773de31da8df8c31abc9f4676f699d0cd9 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 13 May 2024 16:59:07 +0200 Subject: [PATCH 049/235] Add support for programs without computation (pure memlets) --- .../runners/dace_fieldview/gtir_to_tasklet.py | 16 ++++++-- .../runners_tests/test_dace_fieldview.py | 37 +++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index dafb61bc34..591fc942a3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -33,7 +33,7 @@ class MemletExpr: """Scalar or array data access thorugh a memlet.""" - data: dace.nodes.AccessNode + source: dace.nodes.AccessNode subset: sbs.Indices | sbs.Range @@ -172,7 +172,7 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml ) else: self.input_connections.append( - (arg_expr.data, arg_expr.subset, tasklet_node, connector) + (arg_expr.source, arg_expr.subset, tasklet_node, connector) ) return TaskletExpr(tasklet_node, "result") @@ -186,8 +186,16 @@ def visit_Lambda( for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg output_expr = self.visit(node.expr) - assert isinstance(output_expr, TaskletExpr) - return self.input_connections, output_expr + if isinstance(output_expr, TaskletExpr): + return self.input_connections, output_expr + + # special case where the field operator is simply copying data from source to destination node + assert isinstance(output_expr, MemletExpr) + tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self.input_connections.append( + (output_expr.source, output_expr.subset, tasklet_node, "__inp") + ) + return self.input_connections, TaskletExpr(tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = as_dace_type(node.type) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index a044bdc4db..d1122e0cc1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -50,6 +50,43 @@ OFFSET_PROVIDERS: dict[str, Connectivity | Dimension] = {} +def test_gtir_copy(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=Dim.value), 0, "size") + ) + testee = itir.Program( + id="gtir_copy", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="size")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.deref("a")), + domain, + ) + )("x"), + domain=domain, + target=itir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.empty_like(a) + + sdfg_genenerator = FieldviewGtirToSDFG( + [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + + sdfg(x=a, y=b, **FSYMBOLS) + assert np.allclose(a, b) + + def test_gtir_sum2(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") From d67518adb4ede17cd8dd2e378a7cfbab90d96955 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 13 May 2024 17:30:36 +0200 Subject: [PATCH 050/235] Fix test --- .../runners_tests/test_dace_fieldview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index d1122e0cc1..428ea97614 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -52,7 +52,7 @@ def test_gtir_copy(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=Dim.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") ) testee = itir.Program( id="gtir_copy", From 783542f1c9a0ca580ff11e9849fa4ef28b9ace41 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 14 May 2024 09:46:02 +0200 Subject: [PATCH 051/235] Fix for chain of shift expressions shift(V2E(E2V(i_edge, x), y))(edges) --- .../gtir_builtin_field_operator.py | 8 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 85 +++++++++++++---- .../runners_tests/test_dace_fieldview.py | 93 +++++++++++++++++-- 3 files changed, 157 insertions(+), 29 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 09e2d4facb..d8f25148e0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -140,7 +140,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: me, mx = self.head_state.add_map(unique_name("map"), map_ranges) for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset, volume=1) + memlet = dace.Memlet(data=data_node.data, subset=data_subset) self.head_state.add_memlet_path( data_node, me, @@ -149,7 +149,11 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: memlet=memlet, ) self.head_state.add_memlet_path( - output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet + output_expr.node, + mx, + field_node, + src_conn=output_expr.connector, + memlet=output_memlet, ) return [(field_node, self.field_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 5c47db137e..dd5def0f64 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -75,6 +75,8 @@ class IteratorExpr: str, ] +INDEX_CONNECTOR_FMT = "__index_{dim}" + class GTIRToTasklet(eve.NodeVisitor): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" @@ -121,19 +123,18 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: return MemletExpr(it.field, data_index) else: - input_connector_fmt = "__inp_{dim}" assert all(dim in it.indices.keys() for dim in it.dimensions) index_connectors = [ - input_connector_fmt.format(dim=dim) + INDEX_CONNECTOR_FMT.format(dim=dim) for dim, index in it.indices.items() if not isinstance(index, SymbolExpr) ] - sorted_indices = [it.indices[dim] for dim in it.dimensions] + sorted_indices = [(dim, it.indices[dim]) for dim in it.dimensions] index_internals = ",".join( - index.value + str(index.value) if isinstance(index, SymbolExpr) - else input_connector_fmt.format(dim=dim) - for dim, index in zip(it.dimensions, sorted_indices) + else INDEX_CONNECTOR_FMT.format(dim=dim) + for dim, index in sorted_indices ) deref_node = self.state.add_tasklet( "deref_field_indirection", @@ -147,7 +148,7 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: self.input_connections.append((it.field, field_fullset, deref_node, "field")) for dim, index_expr in it.indices.items(): - deref_connector = input_connector_fmt.format(dim=dim) + deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim) if isinstance(index_expr, MemletExpr): self.input_connections.append( ( @@ -221,10 +222,6 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_it = IteratorExpr(it.field, it.dimensions, new_offset, it.indices) else: # shift in unstructured domain by means of a neighbor table - origin_dim = offset_provider.origin_axis.value - assert origin_dim in it.indices - origin_index = it.indices[origin_dim] - assert isinstance(origin_index, SymbolExpr) neighbor_dim = offset_provider.neighbor_axis.value assert neighbor_dim in it.dimensions offset_table = connectivity_identifier(offset) @@ -233,16 +230,70 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: # so the corresponding arrays are supposed to be allocated by the SDFG caller self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) + + origin_dim = offset_provider.origin_axis.value + if origin_dim in it.indices: + origin_index = it.indices[origin_dim] + assert isinstance(origin_index, SymbolExpr) + if neighbor_dim in it.indices: + neighbor_index = it.indices[neighbor_dim] + assert isinstance(neighbor_index, TaskletExpr) + self.input_connections.append( + ( + offset_table_node, + sbs.Indices([origin_index.value, offset_value]), + neighbor_index.node, + INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), + ) + ) + shifted_indices = { + dim: index + for dim, index in it.indices.items() + if dim != origin_dim and dim != neighbor_dim + } | { + origin_dim: TaskletExpr( + neighbor_index.node, + INDEX_CONNECTOR_FMT.format(dim=origin_dim), + ) + } + else: + shifted_indices = { + dim: index for dim, index in it.indices.items() if dim != origin_dim + } | { + origin_dim: MemletExpr( + offset_table_node, + sbs.Indices([origin_index.value, offset_value]), + ) + } + else: + origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) + neighbor_index_connector = INDEX_CONNECTOR_FMT.format(dim=neighbor_dim) + tasklet_node = self.state.add_tasklet( + "shift", + {"table", origin_index_connector}, + {neighbor_index_connector}, + f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_value}]", + ) + table_desc = offset_table_node.desc(self.sdfg) + self.input_connections.append( + ( + offset_table_node, + sbs.Range.from_array(table_desc), + tasklet_node, + "table", + ) + ) + shifted_indices = it.indices | { + origin_dim: TaskletExpr( + tasklet_node, + neighbor_index_connector, + ) + } shifted_it = IteratorExpr( it.field, [origin_dim if dim == neighbor_dim else dim for dim in it.dimensions], it.offset, - { - origin_dim: MemletExpr( - offset_table_node, - sbs.Indices([origin_index.value, offset_value]), - ) - }, + shifted_indices, ) return shifted_it diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index d77c252a26..fc366eb632 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -49,12 +49,6 @@ VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) SIMPLE_MESH: MeshDescriptor = simple_mesh() FSYMBOLS = dict( - __edges_size_0=SIMPLE_MESH.num_edges, - __edges_stride_0=1, - __vertices_size_0=SIMPLE_MESH.num_vertices, - __vertices_stride_0=1, - nedges=SIMPLE_MESH.num_edges, - nvertices=SIMPLE_MESH.num_vertices, __w_size_0=N, __w_stride_0=1, __x_size_0=N, @@ -65,6 +59,22 @@ __z_stride_0=1, size=N, ) +CSYMBOLS = dict( + nedges=SIMPLE_MESH.num_edges, + nvertices=SIMPLE_MESH.num_vertices, + __edges_size_0=SIMPLE_MESH.num_edges, + __edges_stride_0=1, + __vertices_size_0=SIMPLE_MESH.num_vertices, + __vertices_stride_0=1, + __connectivity_E2V_size_0=SIMPLE_MESH.num_edges, + __connectivity_E2V_size_1=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_stride_0=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_stride_1=1, + __connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, + __connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_1=1, +) def test_gtir_copy(): @@ -496,9 +506,72 @@ def test_gtir_connectivity_shift(): vertices=v, connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - __connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, - __connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_1=1, + **CSYMBOLS, ) assert np.allclose(v, e[connectivity_V2E.table[:, 1]]) + + +def test_gtir_connectivity_shift_chain(): + E2V_neighbor_idx = 1 + V2E_neighbor_idx = 2 + edge_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "nedges") + ) + testee = itir.Program( + id="connectivity_shift_chain", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="edges_out"), + itir.Sym(id="nedges"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.deref( + im.shift("E2V", E2V_neighbor_idx)( + im.shift("V2E", V2E_neighbor_idx)("it") + ) + ) + ), + edge_domain, + ) + )("edges"), + domain=edge_domain, + target=itir.SymRef(id="edges_out"), + ) + ], + ) + + e = np.random.rand(SIMPLE_MESH.num_edges) + e_out = np.empty_like(e) + + sdfg_genenerator = FieldviewGtirToSDFG( + [EFTYPE, EFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider=SIMPLE_MESH.offset_provider, + ) + sdfg = sdfg_genenerator.visit(testee) + + assert isinstance(sdfg, dace.SDFG) + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] + assert isinstance(connectivity_E2V, NeighborTable) + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, NeighborTable) + + sdfg( + edges=e, + edges_out=e_out, + connectivity_E2V=connectivity_E2V.table, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **CSYMBOLS, + __edges_out_size_0=CSYMBOLS["__edges_size_0"], + __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], + ) + assert np.allclose( + e_out, + e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]], + ) From 1fa9de4869fa789e47362ccfd74cee061df48d62 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 14 May 2024 11:46:32 +0200 Subject: [PATCH 052/235] Support for multi-dimensional shift --- .../runners/dace_fieldview/gtir_to_tasklet.py | 53 +++--- .../runners_tests/test_dace_fieldview.py | 154 +++++++++++++----- 2 files changed, 133 insertions(+), 74 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index dd5def0f64..d4d89bf66e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -124,17 +124,17 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: else: assert all(dim in it.indices.keys() for dim in it.dimensions) + field_indices = [(dim, it.indices[dim]) for dim in it.dimensions] index_connectors = [ INDEX_CONNECTOR_FMT.format(dim=dim) - for dim, index in it.indices.items() + for dim, index in field_indices if not isinstance(index, SymbolExpr) ] - sorted_indices = [(dim, it.indices[dim]) for dim in it.dimensions] index_internals = ",".join( str(index.value) if isinstance(index, SymbolExpr) else INDEX_CONNECTOR_FMT.format(dim=dim) - for dim, index in sorted_indices + for dim, index in field_indices ) deref_node = self.state.add_tasklet( "deref_field_indirection", @@ -147,7 +147,7 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: field_fullset = sbs.Range.from_array(field_desc) self.input_connections.append((it.field, field_fullset, deref_node, "field")) - for dim, index_expr in it.indices.items(): + for dim, index_expr in field_indices: deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim) if isinstance(index_expr, MemletExpr): self.input_connections.append( @@ -224,10 +224,11 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: # shift in unstructured domain by means of a neighbor table neighbor_dim = offset_provider.neighbor_axis.value assert neighbor_dim in it.dimensions - offset_table = connectivity_identifier(offset) + # initially, the storage for the connectivty tables is created as transient # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller + offset_table = connectivity_identifier(offset) self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) @@ -235,32 +236,23 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: if origin_dim in it.indices: origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) - if neighbor_dim in it.indices: - neighbor_index = it.indices[neighbor_dim] - assert isinstance(neighbor_index, TaskletExpr) + neighbor_expr = it.indices.get(neighbor_dim, None) + if neighbor_expr is not None: + assert isinstance(neighbor_expr, TaskletExpr) self.input_connections.append( ( offset_table_node, sbs.Indices([origin_index.value, offset_value]), - neighbor_index.node, + neighbor_expr.node, INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), ) ) shifted_indices = { - dim: index - for dim, index in it.indices.items() - if dim != origin_dim and dim != neighbor_dim - } | { - origin_dim: TaskletExpr( - neighbor_index.node, - INDEX_CONNECTOR_FMT.format(dim=origin_dim), - ) - } + dim: index for dim, index in it.indices.items() if dim != neighbor_dim + } | {origin_dim: it.indices[neighbor_dim]} else: - shifted_indices = { - dim: index for dim, index in it.indices.items() if dim != origin_dim - } | { - origin_dim: MemletExpr( + shifted_indices = it.indices | { + neighbor_dim: MemletExpr( offset_table_node, sbs.Indices([origin_index.value, offset_value]), ) @@ -274,6 +266,10 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: {neighbor_index_connector}, f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_value}]", ) + neighbor_expr = TaskletExpr( + tasklet_node, + neighbor_index_connector, + ) table_desc = offset_table_node.desc(self.sdfg) self.input_connections.append( ( @@ -283,15 +279,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: "table", ) ) - shifted_indices = it.indices | { - origin_dim: TaskletExpr( - tasklet_node, - neighbor_index_connector, - ) - } + shifted_indices = it.indices | {origin_dim: neighbor_expr} + shifted_it = IteratorExpr( it.field, - [origin_dim if dim == neighbor_dim else dim for dim in it.dimensions], + [ + origin_dim if neighbor_expr and dim == neighbor_dim else dim + for dim in it.dimensions + ], it.offset, shifted_indices, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index fc366eb632..4c635e917a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -17,21 +17,15 @@ Note: this test module covers the fieldview flavour of ITIR. """ -from typing import Union -from gt4py.next.common import Connectivity, Dimension, NeighborTable +from gt4py.next.common import NeighborTable from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( GTIRToSDFG as FieldviewGtirToSDFG, ) from gt4py.next.type_system import type_specifications as ts - -import numpy as np - -import pytest - from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - V2E, + Cell, Edge, IDim, MeshDescriptor, @@ -40,11 +34,15 @@ ) from next_tests.integration_tests.cases import EField, IFloatField, VField +import numpy as np +import pytest + dace = pytest.importorskip("dace") N = 10 IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +CFTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) SIMPLE_MESH: MeshDescriptor = simple_mesh() @@ -60,12 +58,23 @@ size=N, ) CSYMBOLS = dict( + ncells=SIMPLE_MESH.num_cells, nedges=SIMPLE_MESH.num_edges, nvertices=SIMPLE_MESH.num_vertices, + __cells_size_0=SIMPLE_MESH.num_cells, + __cells_stride_0=1, __edges_size_0=SIMPLE_MESH.num_edges, __edges_stride_0=1, __vertices_size_0=SIMPLE_MESH.num_vertices, __vertices_stride_0=1, + __connectivity_C2E_size_0=SIMPLE_MESH.num_cells, + __connectivity_C2E_size_1=SIMPLE_MESH.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_stride_0=SIMPLE_MESH.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_stride_1=1, + __connectivity_C2V_size_0=SIMPLE_MESH.num_cells, + __connectivity_C2V_size_1=SIMPLE_MESH.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_stride_0=SIMPLE_MESH.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_stride_1=1, __connectivity_E2V_size_0=SIMPLE_MESH.num_edges, __connectivity_E2V_size_1=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, __connectivity_E2V_stride_0=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, @@ -107,7 +116,6 @@ def test_gtir_copy(): [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) sdfg(x=a, y=b, **FSYMBOLS) @@ -145,7 +153,6 @@ def test_gtir_sum2(): [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) sdfg(x=a, y=b, z=c, **FSYMBOLS) @@ -183,7 +190,6 @@ def test_gtir_sum2_sym(): offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) sdfg(x=a, z=b, **FSYMBOLS) @@ -256,7 +262,6 @@ def test_gtir_sum3(): a = np.random.rand(N) b = np.random.rand(N) c = np.random.rand(N) - d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], @@ -267,6 +272,8 @@ def test_gtir_sum3(): sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) + d = np.empty_like(a) + sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + c)) @@ -324,7 +331,6 @@ def test_gtir_select(): a = np.random.rand(N) b = np.random.rand(N) c = np.random.rand(N) - d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( [ @@ -339,10 +345,10 @@ def test_gtir_select(): offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) for s in [False, True]: + d = np.empty_like(a) sdfg(cond=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) @@ -399,7 +405,6 @@ def test_gtir_select_nested(): ) a = np.random.rand(N) - b = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( [ @@ -412,11 +417,11 @@ def test_gtir_select_nested(): offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) for s1 in [False, True]: for s2 in [False, True]: + b = np.empty_like(a) sdfg(cond_1=np.bool_(s1), cond_2=np.bool_(s2), x=a, z=b, **FSYMBOLS) assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0)) @@ -451,7 +456,6 @@ def test_gtir_cartesian_shift(): [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={"IDim": IDim} ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) FSYMBOLS_tmp = FSYMBOLS.copy() @@ -461,54 +465,114 @@ def test_gtir_cartesian_shift(): def test_gtir_connectivity_shift(): - vertex_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices") + C2V_neighbor_idx = 1 + C2E_neighbor_idx = 2 + cell_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Cell.value), 0, "ncells"), ) - testee = itir.Program( - id="connectivity_shift", + # apply shift 2 times along different dimensions + testee1 = itir.Program( + id="connectivity_shift_1d", function_definitions=[], params=[ - itir.Sym(id="edges"), - itir.Sym(id="vertices"), - itir.Sym(id="nedges"), - itir.Sym(id="nvertices"), + itir.Sym(id="ve_field"), + itir.Sym(id="cells"), + itir.Sym(id="ncells"), ], declarations=[], body=[ itir.SetAt( expr=im.call( im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", 1)("it"))), - vertex_domain, + im.lambda_("it")( + im.deref( + im.shift("C2V", C2V_neighbor_idx)( + im.shift("C2E", C2E_neighbor_idx)("it") + ) + ) + ), + cell_domain, ) - )("edges"), - domain=vertex_domain, - target=itir.SymRef(id="vertices"), + )("ve_field"), + domain=cell_domain, + target=itir.SymRef(id="cells"), + ) + ], + ) + # multi-dimensional shift in one function call + testee2 = itir.Program( + id="connectivity_shift_2d", + function_definitions=[], + params=[ + itir.Sym(id="ve_field"), + itir.Sym(id="cells"), + itir.Sym(id="ncells"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.deref( + im.call( + im.call("shift")( + im.ensure_offset("C2V"), + im.ensure_offset(C2V_neighbor_idx), + im.ensure_offset("C2E"), + im.ensure_offset(C2E_neighbor_idx), + ) + )("it") + ) + ), + cell_domain, + ) + )("ve_field"), + domain=cell_domain, + target=itir.SymRef(id="cells"), ) ], ) - e = np.random.rand(SIMPLE_MESH.num_edges) - v = np.empty(SIMPLE_MESH.num_vertices) + ve = np.random.rand(SIMPLE_MESH.num_vertices, SIMPLE_MESH.num_edges) + VE_FTYPE = ts.FieldType(dims=[Vertex, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) sdfg_genenerator = FieldviewGtirToSDFG( - [EFTYPE, VFTYPE, ts.ScalarType(ts.ScalarKind.INT32), ts.ScalarType(ts.ScalarKind.INT32)], + [ + VE_FTYPE, + CFTYPE, + ts.ScalarType(ts.ScalarKind.INT32), + ], offset_provider=SIMPLE_MESH.offset_provider, ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) - connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] - assert isinstance(connectivity_V2E, NeighborTable) + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] + assert isinstance(connectivity_C2E, NeighborTable) + connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] + assert isinstance(connectivity_C2V, NeighborTable) + ref = ve[ + connectivity_C2V.table[:, C2V_neighbor_idx], connectivity_C2E.table[:, C2E_neighbor_idx] + ] - sdfg( - edges=e, - vertices=v, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **CSYMBOLS, - ) - assert np.allclose(v, e[connectivity_V2E.table[:, 1]]) + for testee in [testee1, testee2]: + sdfg = sdfg_genenerator.visit(testee) + assert isinstance(sdfg, dace.SDFG) + + c = np.empty(SIMPLE_MESH.num_cells) + + sdfg( + ve_field=ve, + cells=c, + connectivity_C2E=connectivity_C2E.table, + connectivity_C2V=connectivity_C2V.table, + **FSYMBOLS, + **CSYMBOLS, + __ve_field_size_0=SIMPLE_MESH.num_vertices, + __ve_field_size_1=SIMPLE_MESH.num_edges, + __ve_field_stride_0=SIMPLE_MESH.num_edges, + __ve_field_stride_1=1, + ) + assert np.allclose(c, ref) def test_gtir_connectivity_shift_chain(): From 96338c23acd51475809089928da4fc5471c297dd Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 14 May 2024 16:55:07 +0200 Subject: [PATCH 053/235] Fix typo --- .../runners_tests/test_dace_fieldview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 4c635e917a..77762f4876 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -431,7 +431,7 @@ def test_gtir_cartesian_shift(): im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( - id="caresian_shift", + id="cartesian_shift", function_definitions=[], params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="size")], declarations=[], From 57e369f25bd9188770c5dc4de3094fc55dd6c6cc Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 15 May 2024 11:53:58 +0200 Subject: [PATCH 054/235] Add support for cartesian shift with dynamic offset --- .../gtir_builtin_field_operator.py | 1 - .../runners/dace_fieldview/gtir_to_tasklet.py | 248 +++++++++++------- .../runners_tests/test_dace_fieldview.py | 73 +++++- 3 files changed, 216 insertions(+), 106 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index d8f25148e0..7bcc2e8997 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -108,7 +108,6 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: iterator_arg = IteratorExpr( data_node, [dim.value for dim in arg_type.dims], - sbs.Indices([0] * len(arg_type.dims)), indices, ) stencil_args.append(iterator_arg) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index d4d89bf66e..599b51ef17 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -58,14 +58,16 @@ class TaskletExpr: connector: str +IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | TaskletExpr + + @dataclass(frozen=True) class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode dimensions: list[str] - offset: list[dace.symbolic.SymExpr] - indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] + indices: dict[str, IteratorIndexExpr] InputConnection: TypeAlias = tuple[ @@ -99,6 +101,15 @@ def __init__( self.offset_provider = offset_provider self.symbol_map = {} + def _add_input_connection( + self, + src: dace.nodes.AccessNode, + subset: sbs.Range, + dst: dace.nodes.Tasklet, + dst_connector: str, + ) -> None: + self.input_connections.append((src, subset, dst, dst_connector)) + def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) @@ -114,12 +125,7 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: elif isinstance(it, IteratorExpr): if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # use direct field access through memlet subset - data_index = sbs.Indices( - [ - it.indices[dim].value + off # type: ignore[union-attr] - for dim, off in zip(it.dimensions, it.offset, strict=True) - ] - ) + data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] return MemletExpr(it.field, data_index) else: @@ -145,19 +151,18 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: # add new termination point for this field parameter field_desc = it.field.desc(self.sdfg) field_fullset = sbs.Range.from_array(field_desc) - self.input_connections.append((it.field, field_fullset, deref_node, "field")) + self._add_input_connection(it.field, field_fullset, deref_node, "field") for dim, index_expr in field_indices: deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim) if isinstance(index_expr, MemletExpr): - self.input_connections.append( - ( - index_expr.source, - index_expr.subset, - deref_node, - deref_connector, - ) + self._add_input_connection( + index_expr.source, + index_expr.subset, + deref_node, + deref_connector, ) + elif isinstance(index_expr, TaskletExpr): self.state.add_edge( index_expr.node, @@ -189,6 +194,127 @@ def _make_shift_for_rest(self, rest: list[itir.Expr], iterator: itir.Expr) -> it args=[iterator], ) + def _make_cartesian_shift( + self, it: IteratorExpr, offset_dim: Dimension, offset_expr: IteratorIndexExpr + ) -> IteratorExpr: + """Implements cartesian offset along one dimension.""" + assert offset_dim.value in it.dimensions + new_index: SymbolExpr | TaskletExpr + assert offset_dim.value in it.indices + index_expr = it.indices[offset_dim.value] + if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): + new_index = SymbolExpr(index_expr.value + offset_expr.value, index_expr.dtype) + else: + # the offset needs to be calculate by means of a tasklet + new_index_connector = "shifted_index" + if isinstance(index_expr, SymbolExpr): + shift_tasklet = self.state.add_tasklet( + "cartesian_shift", + {"offset"}, + {new_index_connector}, + f"{new_index_connector} = {index_expr.value} + offset", + ) + elif isinstance(offset_expr, SymbolExpr): + shift_tasklet = self.state.add_tasklet( + "cartesian_shift", + {"index"}, + {new_index_connector}, + f"{new_index_connector} = index + {offset_expr}", + ) + else: + shift_tasklet = self.state.add_tasklet( + "cartesian_shift", + {"index", "offset"}, + {new_index_connector}, + f"{new_index_connector} = index + offset", + ) + for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: + if isinstance(input_expr, MemletExpr): + self._add_input_connection( + input_expr.source, input_expr.subset, shift_tasklet, input_connector + ) + elif isinstance(input_expr, TaskletExpr): + self.state.add_edge( + input_expr.node, + input_expr.connector, + shift_tasklet, + input_connector, + dace.Memlet(), + ) + + new_index = TaskletExpr(shift_tasklet, new_index_connector) + + return IteratorExpr( + it.field, + it.dimensions, + { + dim: (new_index if dim == offset_dim.value else index) + for dim, index in it.indices.items() + }, + ) + + def _make_unstructured_shift( + self, + it: IteratorExpr, + connectivity: Connectivity, + offset_table_node: dace.nodes.AccessNode, + offset_value: IteratorIndexExpr, + ) -> IteratorExpr: + # shift in unstructured domain by means of a neighbor table + neighbor_dim = connectivity.neighbor_axis.value + assert neighbor_dim in it.dimensions + + origin_dim = connectivity.origin_axis.value + if origin_dim in it.indices: + origin_index = it.indices[origin_dim] + assert isinstance(origin_index, SymbolExpr) + neighbor_expr = it.indices.get(neighbor_dim, None) + if neighbor_expr is not None: + assert isinstance(neighbor_expr, TaskletExpr) + self._add_input_connection( + offset_table_node, + sbs.Indices([origin_index.value, offset_value]), + neighbor_expr.node, + INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), + ) + shifted_indices = { + dim: index for dim, index in it.indices.items() if dim != neighbor_dim + } | {origin_dim: it.indices[neighbor_dim]} + else: + shifted_indices = it.indices | { + neighbor_dim: MemletExpr( + offset_table_node, + sbs.Indices([origin_index.value, offset_value]), + ) + } + else: + origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) + neighbor_index_connector = INDEX_CONNECTOR_FMT.format(dim=neighbor_dim) + tasklet_node = self.state.add_tasklet( + "shift", + {"table", origin_index_connector}, + {neighbor_index_connector}, + f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_value}]", + ) + neighbor_expr = TaskletExpr( + tasklet_node, + neighbor_index_connector, + ) + table_desc = offset_table_node.desc(self.sdfg) + self._add_input_connection( + offset_table_node, + sbs.Range.from_array(table_desc), + tasklet_node, + "table", + ) + shifted_indices = it.indices | {origin_dim: neighbor_expr} + + return IteratorExpr( + it.field, + [origin_dim if neighbor_expr and dim == neighbor_dim else dim for dim in it.dimensions], + shifted_indices, + ) + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun assert isinstance(shift_node, itir.FunCall) @@ -206,25 +332,18 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset = head[0].value assert isinstance(offset, str) offset_provider = self.offset_provider[offset] - # second argument should be the offset value + # second argument should be the offset value, which could be a symbolic expression or a dynamic offset + offset_value: IteratorIndexExpr if isinstance(head[1], itir.OffsetLiteral): - assert isinstance(head[1].value, int) - offset_value = head[1].value + offset_value = SymbolExpr(head[1].value, dace.int32) else: - raise NotImplementedError("Dynamic offset not supported.") + dynamic_offset_expr = self.visit(head[1]) + assert isinstance(dynamic_offset_expr, MemletExpr | TaskletExpr) + offset_value = dynamic_offset_expr if isinstance(offset_provider, Dimension): - # cartesian offset along one dimension - new_offset = [ - prev_offset + offset_value if dim == offset_provider.value else prev_offset - for dim, prev_offset in zip(it.dimensions, it.offset, strict=True) - ] - shifted_it = IteratorExpr(it.field, it.dimensions, new_offset, it.indices) + return self._make_cartesian_shift(it, offset_provider, offset_value) else: - # shift in unstructured domain by means of a neighbor table - neighbor_dim = offset_provider.neighbor_axis.value - assert neighbor_dim in it.dimensions - # initially, the storage for the connectivty tables is created as transient # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller @@ -232,67 +351,10 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) - origin_dim = offset_provider.origin_axis.value - if origin_dim in it.indices: - origin_index = it.indices[origin_dim] - assert isinstance(origin_index, SymbolExpr) - neighbor_expr = it.indices.get(neighbor_dim, None) - if neighbor_expr is not None: - assert isinstance(neighbor_expr, TaskletExpr) - self.input_connections.append( - ( - offset_table_node, - sbs.Indices([origin_index.value, offset_value]), - neighbor_expr.node, - INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), - ) - ) - shifted_indices = { - dim: index for dim, index in it.indices.items() if dim != neighbor_dim - } | {origin_dim: it.indices[neighbor_dim]} - else: - shifted_indices = it.indices | { - neighbor_dim: MemletExpr( - offset_table_node, - sbs.Indices([origin_index.value, offset_value]), - ) - } - else: - origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) - neighbor_index_connector = INDEX_CONNECTOR_FMT.format(dim=neighbor_dim) - tasklet_node = self.state.add_tasklet( - "shift", - {"table", origin_index_connector}, - {neighbor_index_connector}, - f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_value}]", - ) - neighbor_expr = TaskletExpr( - tasklet_node, - neighbor_index_connector, - ) - table_desc = offset_table_node.desc(self.sdfg) - self.input_connections.append( - ( - offset_table_node, - sbs.Range.from_array(table_desc), - tasklet_node, - "table", - ) - ) - shifted_indices = it.indices | {origin_dim: neighbor_expr} - - shifted_it = IteratorExpr( - it.field, - [ - origin_dim if neighbor_expr and dim == neighbor_dim else dim - for dim in it.dimensions - ], - it.offset, - shifted_indices, + return self._make_unstructured_shift( + it, offset_provider, offset_table_node, offset_value ) - return shifted_it - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | MemletExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -339,8 +401,8 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() ) else: - self.input_connections.append( - (arg_expr.source, arg_expr.subset, tasklet_node, connector) + self._add_input_connection( + arg_expr.source, arg_expr.subset, tasklet_node, connector ) return TaskletExpr(tasklet_node, "result") @@ -360,9 +422,7 @@ def visit_Lambda( # special case where the field operator is simply copying data from source to destination node assert isinstance(output_expr, MemletExpr) tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self.input_connections.append( - (output_expr.source, output_expr.subset, tasklet_node, "__inp") - ) + self._add_input_connection(output_expr.source, output_expr.subset, tasklet_node, "__inp") return self.input_connections, TaskletExpr(tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 77762f4876..36cecd07de 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -427,19 +427,21 @@ def test_gtir_select_nested(): def test_gtir_cartesian_shift(): + DELTA = 3 + OFFSET = 1 domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( + testee1 = itir.Program( id="cartesian_shift", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="size")], + params=[itir.Sym(id="x"), itir.Sym(id="x_offset"), itir.Sym(id="y"), itir.Sym(id="size")], declarations=[], body=[ itir.SetAt( expr=im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", 1)("a")), 1)), + im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), domain, ) )("x"), @@ -448,20 +450,69 @@ def test_gtir_cartesian_shift(): ) ], ) + testee2 = itir.Program( + id="dynamic_offset", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="x_offset"), itir.Sym(id="y"), itir.Sym(id="size")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a", "off")( + im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA) + ), + domain, + ) + )("x", "x_offset"), + domain=domain, + target=itir.SymRef(id="y"), + ) + ], + ) + testee3 = itir.Program( + id="dynamic_offset", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="x_offset"), itir.Sym(id="y"), itir.Sym(id="size")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a", "off")( + im.plus( + im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA + ) + ), + domain, + ) + )("x", "x_offset"), + domain=domain, + target=itir.SymRef(id="y"), + ) + ], + ) - a = np.random.rand(N + 1) + a = np.random.rand(N + OFFSET) + a_offset = np.full(N, OFFSET, dtype=np.int32) b = np.empty(N) + INDEX_FTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={"IDim": IDim} + [IFTYPE, INDEX_FTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider={"IDim": IDim}, ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) - FSYMBOLS_tmp = FSYMBOLS.copy() - FSYMBOLS_tmp["__x_size_0"] = N + 1 - sdfg(x=a, y=b, **FSYMBOLS_tmp) - assert np.allclose(a[1:] + 1, b) + for testee in [testee1, testee2, testee3]: + sdfg = sdfg_genenerator.visit(testee) + assert isinstance(sdfg, dace.SDFG) + + FSYMBOLS_tmp = FSYMBOLS.copy() + FSYMBOLS_tmp["__x_size_0"] = N + OFFSET + FSYMBOLS_tmp["__x_offset_stride_0"] = 1 + sdfg(x=a, x_offset=a_offset, y=b, **FSYMBOLS_tmp) + assert np.allclose(a[OFFSET:] + DELTA, b) def test_gtir_connectivity_shift(): From ec4714c777934b9ec0a3c30c99826603040cb0b8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 15 May 2024 15:10:51 +0200 Subject: [PATCH 055/235] Add support for unstructured shift with dynamic offset --- .../runners/dace_fieldview/gtir_to_tasklet.py | 147 ++++++++++++++---- .../runners_tests/test_dace_fieldview.py | 56 ++++++- 2 files changed, 171 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 599b51ef17..c8090f2238 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -208,22 +208,22 @@ def _make_cartesian_shift( # the offset needs to be calculate by means of a tasklet new_index_connector = "shifted_index" if isinstance(index_expr, SymbolExpr): - shift_tasklet = self.state.add_tasklet( - "cartesian_shift", + dynamic_offset_tasklet = self.state.add_tasklet( + "dynamic_offset", {"offset"}, {new_index_connector}, f"{new_index_connector} = {index_expr.value} + offset", ) elif isinstance(offset_expr, SymbolExpr): - shift_tasklet = self.state.add_tasklet( - "cartesian_shift", + dynamic_offset_tasklet = self.state.add_tasklet( + "dynamic_offset", {"index"}, {new_index_connector}, f"{new_index_connector} = index + {offset_expr}", ) else: - shift_tasklet = self.state.add_tasklet( - "cartesian_shift", + dynamic_offset_tasklet = self.state.add_tasklet( + "dynamic_offset", {"index", "offset"}, {new_index_connector}, f"{new_index_connector} = index + offset", @@ -231,18 +231,21 @@ def _make_cartesian_shift( for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: if isinstance(input_expr, MemletExpr): self._add_input_connection( - input_expr.source, input_expr.subset, shift_tasklet, input_connector + input_expr.source, + input_expr.subset, + dynamic_offset_tasklet, + input_connector, ) elif isinstance(input_expr, TaskletExpr): self.state.add_edge( input_expr.node, input_expr.connector, - shift_tasklet, + dynamic_offset_tasklet, input_connector, dace.Memlet(), ) - new_index = TaskletExpr(shift_tasklet, new_index_connector) + new_index = TaskletExpr(dynamic_offset_tasklet, new_index_connector) return IteratorExpr( it.field, @@ -258,7 +261,7 @@ def _make_unstructured_shift( it: IteratorExpr, connectivity: Connectivity, offset_table_node: dace.nodes.AccessNode, - offset_value: IteratorIndexExpr, + offset_expr: IteratorIndexExpr, ) -> IteratorExpr: # shift in unstructured domain by means of a neighbor table neighbor_dim = connectivity.neighbor_axis.value @@ -271,31 +274,80 @@ def _make_unstructured_shift( neighbor_expr = it.indices.get(neighbor_dim, None) if neighbor_expr is not None: assert isinstance(neighbor_expr, TaskletExpr) - self._add_input_connection( - offset_table_node, - sbs.Indices([origin_index.value, offset_value]), - neighbor_expr.node, - INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), - ) + if isinstance(offset_expr, SymbolExpr): + # use memlet to retrieve the neighbor index and pass it to the index connector of tasklet for neighbor access + self._add_input_connection( + offset_table_node, + sbs.Indices([origin_index.value, offset_expr.value]), + neighbor_expr.node, + INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), + ) + else: + # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node + dynamic_offset_tasklet = self._make_dynamic_neighbor_offset( + offset_expr, offset_table_node, origin_index + ) + + # write result to the index connector of tasklet for neighbor access + self.state.add_edge( + dynamic_offset_tasklet.node, + dynamic_offset_tasklet.connector, + neighbor_expr.node, + INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), + ) + shifted_indices = { dim: index for dim, index in it.indices.items() if dim != neighbor_dim } | {origin_dim: it.indices[neighbor_dim]} - else: + + elif isinstance(offset_expr, SymbolExpr): + # use memlet to retrieve the neighbor index shifted_indices = it.indices | { neighbor_dim: MemletExpr( offset_table_node, - sbs.Indices([origin_index.value, offset_value]), + sbs.Indices([origin_index.value, offset_expr.value]), ) } + else: + # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node + dynamic_offset_tasklet = self._make_dynamic_neighbor_offset( + offset_expr, offset_table_node, origin_index + ) + + shifted_indices = it.indices | {neighbor_dim: dynamic_offset_tasklet} + else: origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) neighbor_index_connector = INDEX_CONNECTOR_FMT.format(dim=neighbor_dim) - tasklet_node = self.state.add_tasklet( - "shift", - {"table", origin_index_connector}, - {neighbor_index_connector}, - f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_value}]", - ) + if isinstance(offset_expr, SymbolExpr): + tasklet_node = self.state.add_tasklet( + "shift", + {"table", origin_index_connector}, + {neighbor_index_connector}, + f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_expr.value}]", + ) + else: + tasklet_node = self.state.add_tasklet( + "shift", + {"table", origin_index_connector, "offset"}, + {neighbor_index_connector}, + f"{neighbor_index_connector} = table[{origin_index_connector}, offset]", + ) + if isinstance(offset_expr, MemletExpr): + self._add_input_connection( + offset_expr.source, + offset_expr.subset, + tasklet_node, + "offset", + ) + else: + self.state.add_edge( + offset_expr.node, + offset_expr.connector, + tasklet_node, + "offset", + dace.Memlet(), + ) neighbor_expr = TaskletExpr( tasklet_node, neighbor_index_connector, @@ -315,6 +367,43 @@ def _make_unstructured_shift( shifted_indices, ) + def _make_dynamic_neighbor_offset( + self, + offset_expr: MemletExpr | TaskletExpr, + offset_table_node: dace.nodes.AccessNode, + origin_index: SymbolExpr, + ) -> TaskletExpr: + new_index_connector = "neighbor_index" + tasklet_node = self.state.add_tasklet( + "dynamic_neighbor_offset", + {"table", "offset"}, + {new_index_connector}, + f"{new_index_connector} = table[{origin_index.value}, offset]", + ) + self._add_input_connection( + offset_table_node, + sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + tasklet_node, + "table", + ) + if isinstance(offset_expr, MemletExpr): + self._add_input_connection( + offset_expr.source, + offset_expr.subset, + tasklet_node, + "offset", + ) + else: + self.state.add_edge( + offset_expr.node, + offset_expr.connector, + tasklet_node, + "offset", + dace.Memlet(), + ) + + return TaskletExpr(tasklet_node, new_index_connector) + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun assert isinstance(shift_node, itir.FunCall) @@ -333,16 +422,16 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: assert isinstance(offset, str) offset_provider = self.offset_provider[offset] # second argument should be the offset value, which could be a symbolic expression or a dynamic offset - offset_value: IteratorIndexExpr + offset_expr: IteratorIndexExpr if isinstance(head[1], itir.OffsetLiteral): - offset_value = SymbolExpr(head[1].value, dace.int32) + offset_expr = SymbolExpr(head[1].value, dace.int32) else: dynamic_offset_expr = self.visit(head[1]) assert isinstance(dynamic_offset_expr, MemletExpr | TaskletExpr) - offset_value = dynamic_offset_expr + offset_expr = dynamic_offset_expr if isinstance(offset_provider, Dimension): - return self._make_cartesian_shift(it, offset_provider, offset_value) + return self._make_cartesian_shift(it, offset_provider, offset_expr) else: # initially, the storage for the connectivty tables is created as transient # when the tables are used, the storage is changed to non-transient, @@ -352,7 +441,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset_table_node = self.state.add_access(offset_table) return self._make_unstructured_shift( - it, offset_provider, offset_table_node, offset_value + it, offset_provider, offset_table_node, offset_expr ) def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | MemletExpr: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 36cecd07de..85da5bbb11 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -432,6 +432,7 @@ def test_gtir_cartesian_shift(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) + # cartesian shift with literal integer offset testee1 = itir.Program( id="cartesian_shift", function_definitions=[], @@ -450,6 +451,7 @@ def test_gtir_cartesian_shift(): ) ], ) + # use dynamic offset retrieved from field testee2 = itir.Program( id="dynamic_offset", function_definitions=[], @@ -470,6 +472,7 @@ def test_gtir_cartesian_shift(): ) ], ) + # use the result of an arithmetic field operation as dynamic offset testee3 = itir.Program( id="dynamic_offset", function_definitions=[], @@ -497,10 +500,10 @@ def test_gtir_cartesian_shift(): a_offset = np.full(N, OFFSET, dtype=np.int32) b = np.empty(N) - INDEX_FTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + IOFFSET_FTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, INDEX_FTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + [IFTYPE, IOFFSET_FTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={"IDim": IDim}, ) @@ -527,6 +530,8 @@ def test_gtir_connectivity_shift(): function_definitions=[], params=[ itir.Sym(id="ve_field"), + itir.Sym(id="c2e_offset"), + itir.Sym(id="c2v_offset"), itir.Sym(id="cells"), itir.Sym(id="ncells"), ], @@ -556,6 +561,8 @@ def test_gtir_connectivity_shift(): function_definitions=[], params=[ itir.Sym(id="ve_field"), + itir.Sym(id="c2e_offset"), + itir.Sym(id="c2v_offset"), itir.Sym(id="cells"), itir.Sym(id="ncells"), ], @@ -584,13 +591,52 @@ def test_gtir_connectivity_shift(): ) ], ) + # again multi-dimensional shift in one function call, but this time with dynamic offset values + testee3 = itir.Program( + id="connectivity_shift_2d_dynamic_offset", + function_definitions=[], + params=[ + itir.Sym(id="ve_field"), + itir.Sym(id="c2e_offset"), + itir.Sym(id="c2v_offset"), + itir.Sym(id="cells"), + itir.Sym(id="ncells"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it", "c2e_off", "c2v_off")( + im.deref( + im.call( + im.call("shift")( + im.ensure_offset("C2V"), + im.deref("c2v_off"), + im.ensure_offset("C2E"), + im.plus(im.deref("c2e_off"), 0), + ) + )("it") + ) + ), + cell_domain, + ) + )("ve_field", "c2e_offset", "c2v_offset"), + domain=cell_domain, + target=itir.SymRef(id="cells"), + ) + ], + ) ve = np.random.rand(SIMPLE_MESH.num_vertices, SIMPLE_MESH.num_edges) VE_FTYPE = ts.FieldType(dims=[Vertex, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) sdfg_genenerator = FieldviewGtirToSDFG( [ VE_FTYPE, + CELL_OFFSET_FTYPE, + CELL_OFFSET_FTYPE, CFTYPE, ts.ScalarType(ts.ScalarKind.INT32), ], @@ -605,7 +651,7 @@ def test_gtir_connectivity_shift(): connectivity_C2V.table[:, C2V_neighbor_idx], connectivity_C2E.table[:, C2E_neighbor_idx] ] - for testee in [testee1, testee2]: + for testee in [testee1, testee2, testee3]: sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) @@ -613,6 +659,8 @@ def test_gtir_connectivity_shift(): sdfg( ve_field=ve, + c2e_offset=np.full(SIMPLE_MESH.num_cells, C2E_neighbor_idx, dtype=np.int32), + c2v_offset=np.full(SIMPLE_MESH.num_cells, C2V_neighbor_idx, dtype=np.int32), cells=c, connectivity_C2E=connectivity_C2E.table, connectivity_C2V=connectivity_C2V.table, @@ -622,6 +670,8 @@ def test_gtir_connectivity_shift(): __ve_field_size_1=SIMPLE_MESH.num_edges, __ve_field_stride_0=SIMPLE_MESH.num_edges, __ve_field_stride_1=1, + __c2e_offset_stride_0=1, + __c2v_offset_stride_0=1, ) assert np.allclose(c, ref) From 46cb6c6c878bee4941331e695d3b14b3b927a7fd Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 15 May 2024 15:33:02 +0200 Subject: [PATCH 056/235] Code refactoring in test file --- .../runners_tests/test_dace_fieldview.py | 372 +++++++----------- 1 file changed, 152 insertions(+), 220 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 85da5bbb11..5bc7bb73ff 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -200,64 +200,28 @@ def test_gtir_sum3(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) - testee_fieldview = itir.Program( - id="sum_3fields", - function_definitions=[], - params=[ - itir.Sym(id="x"), - itir.Sym(id="y"), - itir.Sym(id="w"), - itir.Sym(id="z"), - itir.Sym(id="size"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )( - "x", - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )("y", "w"), - ), - domain=domain, - target=itir.SymRef(id="z"), - ) - ], - ) - testee_inlined = itir.Program( - id="sum_3fields", - function_definitions=[], - params=[ - itir.Sym(id="x"), - itir.Sym(id="y"), - itir.Sym(id="w"), - itir.Sym(id="z"), - itir.Sym(id="size"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "b", "c")( - im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c"))) - ), - domain, - ) - )("x", "y", "w"), - domain=domain, - target=itir.SymRef(id="z"), + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + "x", + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, ) - ], - ) + )("y", "w"), + ) + stencil2 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b", "c")( + im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c"))) + ), + domain, + ) + )("x", "y", "w") a = np.random.rand(N) b = np.random.rand(N) @@ -268,7 +232,27 @@ def test_gtir_sum3(): offset_provider={}, ) - for testee in [testee_fieldview, testee_inlined]: + for i, stencil in enumerate([stencil1, stencil2]): + testee = itir.Program( + id=f"sum_3fields_{i}", + function_definitions=[], + params=[ + itir.Sym(id="x"), + itir.Sym(id="y"), + itir.Sym(id="w"), + itir.Sym(id="z"), + itir.Sym(id="size"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) @@ -432,69 +416,34 @@ def test_gtir_cartesian_shift(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) + # cartesian shift with literal integer offset - testee1 = itir.Program( - id="cartesian_shift", - function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="x_offset"), itir.Sym(id="y"), itir.Sym(id="size")], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), - domain, - ) - )("x"), - domain=domain, - target=itir.SymRef(id="y"), - ) - ], - ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + domain, + ) + )("x") + # use dynamic offset retrieved from field - testee2 = itir.Program( - id="dynamic_offset", - function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="x_offset"), itir.Sym(id="y"), itir.Sym(id="size")], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA) - ), - domain, - ) - )("x", "x_offset"), - domain=domain, - target=itir.SymRef(id="y"), - ) - ], - ) + stencil2 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "off")( + im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA) + ), + domain, + ) + )("x", "x_offset") + # use the result of an arithmetic field operation as dynamic offset - testee3 = itir.Program( - id="dynamic_offset", - function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="x_offset"), itir.Sym(id="y"), itir.Sym(id="size")], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "off")( - im.plus( - im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA - ) - ), - domain, - ) - )("x", "x_offset"), - domain=domain, - target=itir.SymRef(id="y"), - ) - ], - ) + stencil3 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "off")( + im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + ), + domain, + ) + )("x", "x_offset") a = np.random.rand(N + OFFSET) a_offset = np.full(N, OFFSET, dtype=np.int32) @@ -507,7 +456,26 @@ def test_gtir_cartesian_shift(): offset_provider={"IDim": IDim}, ) - for testee in [testee1, testee2, testee3]: + for i, stencil in enumerate([stencil1, stencil2, stencil3]): + testee = itir.Program( + id=f"dynamic_offset_{i}", + function_definitions=[], + params=[ + itir.Sym(id="x"), + itir.Sym(id="x_offset"), + itir.Sym(id="y"), + itir.Sym(id="size"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=domain, + target=itir.SymRef(id="y"), + ) + ], + ) + sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) @@ -525,108 +493,52 @@ def test_gtir_connectivity_shift(): im.call("named_range")(itir.AxisLiteral(value=Cell.value), 0, "ncells"), ) # apply shift 2 times along different dimensions - testee1 = itir.Program( - id="connectivity_shift_1d", - function_definitions=[], - params=[ - itir.Sym(id="ve_field"), - itir.Sym(id="c2e_offset"), - itir.Sym(id="c2v_offset"), - itir.Sym(id="cells"), - itir.Sym(id="ncells"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.deref( - im.shift("C2V", C2V_neighbor_idx)( - im.shift("C2E", C2E_neighbor_idx)("it") - ) - ) - ), - cell_domain, - ) - )("ve_field"), - domain=cell_domain, - target=itir.SymRef(id="cells"), - ) - ], - ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.deref(im.shift("C2V", C2V_neighbor_idx)(im.shift("C2E", C2E_neighbor_idx)("it"))) + ), + cell_domain, + ) + )("ve_field") + # multi-dimensional shift in one function call - testee2 = itir.Program( - id="connectivity_shift_2d", - function_definitions=[], - params=[ - itir.Sym(id="ve_field"), - itir.Sym(id="c2e_offset"), - itir.Sym(id="c2v_offset"), - itir.Sym(id="cells"), - itir.Sym(id="ncells"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.deref( - im.call( - im.call("shift")( - im.ensure_offset("C2V"), - im.ensure_offset(C2V_neighbor_idx), - im.ensure_offset("C2E"), - im.ensure_offset(C2E_neighbor_idx), - ) - )("it") - ) - ), - cell_domain, - ) - )("ve_field"), - domain=cell_domain, - target=itir.SymRef(id="cells"), - ) - ], - ) + stencil2 = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.deref( + im.call( + im.call("shift")( + im.ensure_offset("C2V"), + im.ensure_offset(C2V_neighbor_idx), + im.ensure_offset("C2E"), + im.ensure_offset(C2E_neighbor_idx), + ) + )("it") + ) + ), + cell_domain, + ) + )("ve_field") + # again multi-dimensional shift in one function call, but this time with dynamic offset values - testee3 = itir.Program( - id="connectivity_shift_2d_dynamic_offset", - function_definitions=[], - params=[ - itir.Sym(id="ve_field"), - itir.Sym(id="c2e_offset"), - itir.Sym(id="c2v_offset"), - itir.Sym(id="cells"), - itir.Sym(id="ncells"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it", "c2e_off", "c2v_off")( - im.deref( - im.call( - im.call("shift")( - im.ensure_offset("C2V"), - im.deref("c2v_off"), - im.ensure_offset("C2E"), - im.plus(im.deref("c2e_off"), 0), - ) - )("it") - ) - ), - cell_domain, - ) - )("ve_field", "c2e_offset", "c2v_offset"), - domain=cell_domain, - target=itir.SymRef(id="cells"), - ) - ], - ) + stencil3 = im.call( + im.call("as_fieldop")( + im.lambda_("it", "c2e_off", "c2v_off")( + im.deref( + im.call( + im.call("shift")( + im.ensure_offset("C2V"), + im.deref("c2v_off"), + im.ensure_offset("C2E"), + im.plus(im.deref("c2e_off"), 0), + ) + )("it") + ) + ), + cell_domain, + ) + )("ve_field", "c2e_offset", "c2v_offset") ve = np.random.rand(SIMPLE_MESH.num_vertices, SIMPLE_MESH.num_edges) VE_FTYPE = ts.FieldType(dims=[Vertex, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) @@ -651,7 +563,27 @@ def test_gtir_connectivity_shift(): connectivity_C2V.table[:, C2V_neighbor_idx], connectivity_C2E.table[:, C2E_neighbor_idx] ] - for testee in [testee1, testee2, testee3]: + for i, stencil in enumerate([stencil1, stencil2, stencil3]): + testee = itir.Program( + id=f"connectivity_shift_2d_{i}", + function_definitions=[], + params=[ + itir.Sym(id="ve_field"), + itir.Sym(id="c2e_offset"), + itir.Sym(id="c2v_offset"), + itir.Sym(id="cells"), + itir.Sym(id="ncells"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=cell_domain, + target=itir.SymRef(id="cells"), + ) + ], + ) + sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) From c20a94d0124795964f8783e705f73240baf03c7a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 15 May 2024 16:00:55 +0200 Subject: [PATCH 057/235] Typo --- .../runners_tests/test_dace_fieldview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 5bc7bb73ff..9a8ad7c8b8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -458,7 +458,7 @@ def test_gtir_cartesian_shift(): for i, stencil in enumerate([stencil1, stencil2, stencil3]): testee = itir.Program( - id=f"dynamic_offset_{i}", + id=f"cartesian_shift_{i}", function_definitions=[], params=[ itir.Sym(id="x"), From d1f74324981e2bae29f151ed6862d81f89093d4c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 15 May 2024 18:07:23 +0200 Subject: [PATCH 058/235] Code cleanup --- .../runners/dace_fieldview/gtir_to_tasklet.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index c8090f2238..cd60b17238 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -114,15 +114,7 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) - if isinstance(it, SymbolExpr): - cast_sym = str(it.dtype) - cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] - deref_node = self.state.add_tasklet( - "deref_symbol", {}, {"val"}, code=f"val = {cast_fmt.format(it.value)}" - ) - return TaskletExpr(deref_node, "val") - - elif isinstance(it, IteratorExpr): + if isinstance(it, IteratorExpr): if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # use direct field access through memlet subset data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] From ed16fd4605b428c7ad3ecf98b0a670548b9a90ba Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 16 May 2024 10:46:24 +0200 Subject: [PATCH 059/235] Import updates from branch dace-fieldview-shifts --- .../gtir_builtin_field_operator.py | 9 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 40 ++--- .../runners_tests/test_dace_fieldview.py | 156 ++++++++---------- 3 files changed, 91 insertions(+), 114 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 09e2d4facb..7bcc2e8997 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -108,7 +108,6 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: iterator_arg = IteratorExpr( data_node, [dim.value for dim in arg_type.dims], - sbs.Indices([0] * len(arg_type.dims)), indices, ) stencil_args.append(iterator_arg) @@ -140,7 +139,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: me, mx = self.head_state.add_map(unique_name("map"), map_ranges) for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset, volume=1) + memlet = dace.Memlet(data=data_node.data, subset=data_subset) self.head_state.add_memlet_path( data_node, me, @@ -149,7 +148,11 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: memlet=memlet, ) self.head_state.add_memlet_path( - output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet + output_expr.node, + mx, + field_node, + src_conn=output_expr.connector, + memlet=output_memlet, ) return [(field_node, self.field_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 591fc942a3..e6ad94449e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -53,14 +53,16 @@ class TaskletExpr: connector: str +IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | TaskletExpr + + @dataclass(frozen=True) class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode dimensions: list[str] - offset: list[dace.symbolic.SymExpr] - indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] + indices: dict[str, IteratorIndexExpr] InputConnection: TypeAlias = tuple[ @@ -92,27 +94,23 @@ def __init__( self.offset_provider = offset_provider self.symbol_map = {} + def _add_input_connection( + self, + src: dace.nodes.AccessNode, + subset: sbs.Range, + dst: dace.nodes.Tasklet, + dst_connector: str, + ) -> None: + self.input_connections.append((src, subset, dst, dst_connector)) + def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) - if isinstance(it, SymbolExpr): - cast_sym = str(it.dtype) - cast_fmt = MATH_BUILTINS_MAPPING[cast_sym] - deref_node = self.state.add_tasklet( - "deref_symbol", {}, {"val"}, code=f"val = {cast_fmt.format(it.value)}" - ) - return TaskletExpr(deref_node, "val") - - elif isinstance(it, IteratorExpr): + if isinstance(it, IteratorExpr): if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # use direct field access through memlet subset - data_index = sbs.Indices( - [ - it.indices[dim].value + off # type: ignore[union-attr] - for dim, off in zip(it.dimensions, it.offset, strict=True) - ] - ) + data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] return MemletExpr(it.field, data_index) else: @@ -171,8 +169,8 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() ) else: - self.input_connections.append( - (arg_expr.source, arg_expr.subset, tasklet_node, connector) + self._add_input_connection( + arg_expr.source, arg_expr.subset, tasklet_node, connector ) return TaskletExpr(tasklet_node, "result") @@ -192,9 +190,7 @@ def visit_Lambda( # special case where the field operator is simply copying data from source to destination node assert isinstance(output_expr, MemletExpr) tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self.input_connections.append( - (output_expr.source, output_expr.subset, tasklet_node, "__inp") - ) + self._add_input_connection(output_expr.source, output_expr.subset, tasklet_node, "__inp") return self.input_connections, TaskletExpr(tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 428ea97614..ce654c853d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -17,25 +17,22 @@ Note: this test module covers the fieldview flavour of ITIR. """ -from typing import Union -from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( GTIRToSDFG as FieldviewGtirToSDFG, ) from gt4py.next.type_system import type_specifications as ts +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import IDim import numpy as np - import pytest dace = pytest.importorskip("dace") N = 10 -DIM = Dimension("D") -FTYPE = ts.FieldType(dims=[DIM], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) FSYMBOLS = dict( __w_size_0=N, __w_stride_0=1, @@ -47,12 +44,11 @@ __z_stride_0=1, size=N, ) -OFFSET_PROVIDERS: dict[str, Connectivity | Dimension] = {} def test_gtir_copy(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="gtir_copy", @@ -77,10 +73,9 @@ def test_gtir_copy(): b = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} + [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) sdfg(x=a, y=b, **FSYMBOLS) @@ -89,7 +84,7 @@ def test_gtir_copy(): def test_gtir_sum2(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="sum_2fields", @@ -115,10 +110,9 @@ def test_gtir_sum2(): c = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) sdfg(x=a, y=b, z=c, **FSYMBOLS) @@ -127,7 +121,7 @@ def test_gtir_sum2(): def test_gtir_sum2_sym(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="sum_2fields", @@ -152,10 +146,10 @@ def test_gtir_sum2_sym(): b = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) sdfg(x=a, z=b, **FSYMBOLS) @@ -164,87 +158,73 @@ def test_gtir_sum2_sym(): def test_gtir_sum3(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) - testee_fieldview = itir.Program( - id="sum_3fields", - function_definitions=[], - params=[ - itir.Sym(id="x"), - itir.Sym(id="y"), - itir.Sym(id="w"), - itir.Sym(id="z"), - itir.Sym(id="size"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )( - "x", - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )("y", "w"), - ), - domain=domain, - target=itir.SymRef(id="z"), + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + "x", + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, ) - ], - ) - testee_inlined = itir.Program( - id="sum_3fields", - function_definitions=[], - params=[ - itir.Sym(id="x"), - itir.Sym(id="y"), - itir.Sym(id="w"), - itir.Sym(id="z"), - itir.Sym(id="size"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "b", "c")( - im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c"))) - ), - domain, - ) - )("x", "y", "w"), - domain=domain, - target=itir.SymRef(id="z"), - ) - ], + )("y", "w"), ) + stencil2 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b", "c")( + im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c"))) + ), + domain, + ) + )("x", "y", "w") a = np.random.rand(N) b = np.random.rand(N) c = np.random.rand(N) - d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( - [FTYPE, FTYPE, FTYPE, FTYPE, ts.ScalarType(ts.ScalarKind.INT32)], OFFSET_PROVIDERS + [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], + offset_provider={}, ) - for testee in [testee_fieldview, testee_inlined]: + for i, stencil in enumerate([stencil1, stencil2]): + testee = itir.Program( + id=f"sum_3fields_{i}", + function_definitions=[], + params=[ + itir.Sym(id="x"), + itir.Sym(id="y"), + itir.Sym(id="w"), + itir.Sym(id="z"), + itir.Sym(id="size"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + sdfg = sdfg_genenerator.visit(testee) assert isinstance(sdfg, dace.SDFG) + d = np.empty_like(a) + sdfg(x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + c)) def test_gtir_select(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="select_2sums", @@ -295,32 +275,31 @@ def test_gtir_select(): a = np.random.rand(N) b = np.random.rand(N) c = np.random.rand(N) - d = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( [ - FTYPE, - FTYPE, - FTYPE, - FTYPE, + IFTYPE, + IFTYPE, + IFTYPE, + IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.FLOAT64), ts.ScalarType(ts.ScalarKind.INT32), ], - OFFSET_PROVIDERS, + offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) for s in [False, True]: + d = np.empty_like(a) sdfg(cond=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) def test_gtir_select_nested(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=DIM.value), 0, "size") + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( id="select_nested", @@ -370,23 +349,22 @@ def test_gtir_select_nested(): ) a = np.random.rand(N) - b = np.empty_like(a) sdfg_genenerator = FieldviewGtirToSDFG( [ - FTYPE, - FTYPE, + IFTYPE, + IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.INT32), ], - OFFSET_PROVIDERS, + offset_provider={}, ) sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) for s1 in [False, True]: for s2 in [False, True]: + b = np.empty_like(a) sdfg(cond_1=np.bool_(s1), cond_2=np.bool_(s2), x=a, z=b, **FSYMBOLS) assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0)) From 9f7176f79f918eb96f3922666e8731f30edce917 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 16 May 2024 14:10:32 +0200 Subject: [PATCH 060/235] Review comments --- .../gtir_builtin_field_operator.py | 16 +++-- .../gtir_builtins/gtir_builtin_select.py | 15 ++--- .../gtir_builtins/gtir_builtin_symbol_ref.py | 47 +++++++------- .../gtir_builtins/gtir_builtin_translator.py | 27 ++++---- .../runners/dace_fieldview/gtir_to_sdfg.py | 65 +++++++++---------- .../runners/dace_fieldview/gtir_to_tasklet.py | 6 +- .../runners/dace_fieldview/utility.py | 9 --- 7 files changed, 90 insertions(+), 95 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 7bcc2e8997..0564b15098 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable, TypeAlias +from typing import TypeAlias import dace import dace.subsets as sbs @@ -21,8 +21,11 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, + SDFGField, + SDFGFieldBuilder, ) from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( GTIRToTasklet, @@ -31,7 +34,6 @@ SymbolExpr, TaskletExpr, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name from gt4py.next.type_system import type_specifications as ts @@ -45,7 +47,7 @@ class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] stencil_expr: itir.Lambda - stencil_args: list[Callable] + stencil_args: list[SDFGFieldBuilder] field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] field_type: ts.FieldType offset_provider: dict[str, Connectivity | Dimension] @@ -55,7 +57,7 @@ def __init__( sdfg: dace.SDFG, state: dace.SDFGState, node: itir.FunCall, - stencil_args: list[Callable], + stencil_args: list[SDFGFieldBuilder], offset_provider: dict[str, Connectivity | Dimension], ): super().__init__(sdfg, state) @@ -69,7 +71,7 @@ def __init__( # the domain of the field operator is passed as second argument assert isinstance(domain_expr, itir.FunCall) - domain = get_domain(domain_expr) + domain = dace_fieldview_util.get_domain(domain_expr) # define field domain with all dimensions in alphabetical order sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) @@ -82,7 +84,7 @@ def __init__( self.stencil_expr = stencil_expr self.stencil_args = stencil_args - def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def build(self) -> list[SDFGField]: dimension_index_fmt = "i_{dim}" # first visit the list of arguments and build a symbol map stencil_args: list[IteratorExpr | MemletExpr] = [] @@ -136,7 +138,7 @@ def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items() } - me, mx = self.head_state.add_map(unique_name("map"), map_ranges) + me, mx = self.head_state.add_map("field_op", map_ranges) for data_node, data_subset, lambda_node, lambda_connector in input_connections: memlet = dace.Memlet(data=data_node.data, subset=data_subset) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py index 7f7f9bacbe..05ed7bd74f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py @@ -13,25 +13,24 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable - import dace from gt4py import eve from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, + SDFGField, + SDFGFieldBuilder, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import get_symbolic_expr -from gt4py.next.type_system import type_specifications as ts class GTIRBuiltinSelect(GTIRBuiltinTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" - true_br_builder: Callable - false_br_builder: Callable + true_br_builder: SDFGFieldBuilder + false_br_builder: SDFGFieldBuilder def __init__( self, @@ -47,7 +46,7 @@ def __init__( cond_expr, true_expr, false_expr = node.fun.args # expect condition as first argument - cond = get_symbolic_expr(cond_expr) + cond = dace_fieldview_util.get_symbolic_expr(cond_expr) # use current head state to terminate the dataflow, and add a entry state # to connect the true/false branch states as follows: @@ -81,7 +80,7 @@ def __init__( false_expr, sdfg=sdfg, head_state=false_state ) - def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def build(self) -> list[SDFGField]: # retrieve true/false states as predecessors of head state branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) assert len(branch_states) == 2 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py index 35ca173369..d0a3afa497 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py @@ -19,6 +19,7 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( GTIRBuiltinTranslator, + SDFGField, ) from gt4py.next.type_system import type_specifications as ts @@ -52,29 +53,31 @@ def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: assert len(access_nodes) == 1 return access_nodes[0] - def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def build(self) -> list[SDFGField]: + # check if access node is already present in current state sym_node = self._get_access_node() - if sym_node: - # if already present in current state, share access node - pass + if sym_node is None: + if isinstance(self.sym_type, ts.FieldType): + # add access node to current state + sym_node = self.head_state.add_access(self.sym_name) - elif isinstance(self.sym_type, ts.FieldType): - # add access node to current state - sym_node = self.head_state.add_access(self.sym_name) - - else: - # scalar symbols are passed to the SDFG as symbols: build tasklet node - # to write the symbol to a scalar access node - assert self.sym_name in self.sdfg.symbols - tasklet_node = self.head_state.add_tasklet( - f"get_{self.sym_name}", - {}, - {"__out"}, - f"__out = {self.sym_name}", - ) - sym_node = self.add_local_storage(self.sym_type, shape=[]) - self.head_state.add_edge( - tasklet_node, "__out", sym_node, None, dace.Memlet(data=sym_node.data, subset="0") - ) + else: + # scalar symbols are passed to the SDFG as symbols: build tasklet node + # to write the symbol to a scalar access node + assert self.sym_name in self.sdfg.symbols + tasklet_node = self.head_state.add_tasklet( + f"get_{self.sym_name}", + {}, + {"__out"}, + f"__out = {self.sym_name}", + ) + sym_node = self.add_local_storage(self.sym_type, shape=[]) + self.head_state.add_edge( + tasklet_node, + "__out", + sym_node, + None, + dace.Memlet(data=sym_node.data, subset="0"), + ) return [(sym_node, self.sym_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 4063fe80be..996ff94377 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -13,25 +13,27 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from abc import abstractmethod +import abc from dataclasses import dataclass -from typing import final +from typing import Callable, TypeAlias, final import dace -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util from gt4py.next.type_system import type_specifications as ts +SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] +SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] + + @dataclass(frozen=True) -class GTIRBuiltinTranslator: +class GTIRBuiltinTranslator(abc.ABC): sdfg: dace.SDFG head_state: dace.SDFGState @final - def __call__( - self, - ) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + def __call__(self) -> list[SDFGField]: """The callable interface is used to build the dataflow graph. It allows to build the dataflow graph inside a given state starting @@ -45,19 +47,18 @@ def add_local_storage( self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] ) -> dace.nodes.AccessNode: """Allocates temporary storage to be used in the local scope for intermediate results.""" - name = unique_name("var") if isinstance(data_type, ts.FieldType): assert len(data_type.dims) == len(shape) - dtype = as_dace_type(data_type.dtype) - name, _ = self.sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True) + dtype = dace_fieldview_util.as_dace_type(data_type.dtype) + name, _ = self.sdfg.add_array("var", shape, dtype, find_new_name=True, transient=True) else: assert len(shape) == 0 - dtype = as_dace_type(data_type) + dtype = dace_fieldview_util.as_dace_type(data_type) name, _ = self.sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) return self.head_state.add_access(name) - @abstractmethod - def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: + @abc.abstractmethod + def build(self) -> list[SDFGField]: """Creates the dataflow subgraph representing a given GTIR builtin. This method is used by derived classes of `GTIRBuiltinTranslator`, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 79625a15ee..f493861f86 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -17,7 +17,7 @@ Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Callable, Mapping, Sequence +from typing import Any, Sequence import dace @@ -25,11 +25,12 @@ from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import gtir_builtins -from gt4py.next.program_processors.runners.dace_fieldview.utility import ( - as_dace_type, - filter_connectivities, - get_domain, +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_builtins, + utility as dace_fieldview_util, +) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( + SDFGFieldBuilder, ) from gt4py.next.type_system import type_specifications as ts @@ -47,22 +48,22 @@ class GTIRToSDFG(eve.NodeVisitor): from where to continue building the SDFG. """ - data_types: dict[str, ts.FieldType | ts.ScalarType] param_types: list[ts.DataType] offset_provider: dict[str, Connectivity | Dimension] + symbol_types: dict[str, ts.FieldType | ts.ScalarType] def __init__( self, param_types: list[ts.DataType], offset_provider: dict[str, Connectivity | Dimension], ): - self.data_types = {} self.param_types = param_types self.offset_provider = offset_provider + self.symbol_types = {} def _make_array_shape_and_strides( self, name: str, dims: Sequence[Dimension] - ) -> tuple[Sequence[dace.symbol], Sequence[dace.symbol]]: + ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -73,7 +74,7 @@ def _make_array_shape_and_strides( Two list of symbols, one for the shape and another for the strides of the array. """ dtype = dace.int32 - neighbor_tables = filter_connectivities(self.offset_provider) + neighbor_tables = dace_fieldview_util.filter_connectivities(self.offset_provider) shape = [ ( neighbor_tables[dim.value].max_neighbors @@ -86,30 +87,30 @@ def _make_array_shape_and_strides( strides = [dace.symbol(f"__{name}_stride_{i}", dtype) for i in range(len(dims))] return shape, strides - def _add_storage(self, sdfg: dace.SDFG, name: str, data_type: ts.DataType) -> None: + def _add_storage(self, sdfg: dace.SDFG, name: str, symbol_type: ts.DataType) -> None: """ Add external storage (aka non-transient) for data containers passed as arguments to the SDFG. For fields, it allocates dace arrays, while scalars are stored as SDFG symbols. """ - if isinstance(data_type, ts.FieldType): - dtype = as_dace_type(data_type.dtype) + if isinstance(symbol_type, ts.FieldType): + dtype = dace_fieldview_util.as_dace_type(symbol_type.dtype) # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, data_type.dims) + sym_shape, sym_strides = self._make_array_shape_and_strides(name, symbol_type.dims) sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=False) - elif isinstance(data_type, ts.ScalarType): - dtype = as_dace_type(data_type) + elif isinstance(symbol_type, ts.ScalarType): + dtype = dace_fieldview_util.as_dace_type(symbol_type) # scalar arguments passed to the program are represented as symbols in DaCe SDFG sdfg.add_symbol(name, dtype) else: - raise RuntimeError(f"Data type '{type(data_type)}' not supported.") + raise RuntimeError(f"Data type '{type(symbol_type)}' not supported.") # TODO: unclear why mypy complains about incompatible types - assert isinstance(data_type, (ts.FieldType, ts.ScalarType)) - self.data_types[name] = data_type + assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) + self.symbol_types[name] = symbol_type - def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> Mapping[str, str]: + def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> dict[str, str]: """ Add temporary storage (aka transient) for data containers used as GTIR temporaries. @@ -129,7 +130,6 @@ def _visit_expression( TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? """ expr_builder = self.visit(node, sdfg=sdfg, head_state=head_state) - assert callable(expr_builder) results = expr_builder() expressions_nodes = [] @@ -199,16 +199,16 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) target_nodes = self._visit_expression(stmt.target, sdfg, state) # convert domain expression to dictionary to ease access to dimension boundaries - domain = get_domain(stmt.domain) + domain = dace_fieldview_util.get_domain(stmt.domain) for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): target_array = sdfg.arrays[target_node.data] assert not target_array.transient - target_field_type = self.data_types[target_node.data] + target_symbol_type = self.symbol_types[target_node.data] - if isinstance(target_field_type, ts.FieldType): + if isinstance(target_symbol_type, ts.FieldType): subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_field_type.dims + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_symbol_type.dims ) else: assert len(domain) == 0 @@ -222,12 +222,11 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) def visit_FunCall( self, node: itir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> Callable: + ) -> SDFGFieldBuilder: # first visit the argument nodes - arg_builders: list[Callable] = [] + arg_builders: list[SDFGFieldBuilder] = [] for arg in node.args: arg_builder = self.visit(arg, sdfg=sdfg, head_state=head_state) - assert callable(arg_builder) arg_builders.append(arg_builder) if cpm.is_call_to(node.fun, "as_fieldop"): @@ -252,8 +251,8 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> Callable: - sym_name = str(node.id) - assert sym_name in self.data_types - sym_type = self.data_types[sym_name] - return gtir_builtins.SymbolRef(sdfg, head_state, sym_name, sym_type) + ) -> SDFGFieldBuilder: + symbol_name = str(node.id) + assert symbol_name in self.symbol_types + symbol_type = self.symbol_types[symbol_name] + return gtir_builtins.SymbolRef(sdfg, head_state, symbol_name, symbol_type) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index e6ad94449e..7f32c82037 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -23,10 +23,10 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( MATH_BUILTINS_MAPPING, ) -from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name @dataclass(frozen=True) @@ -157,7 +157,7 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml out_connector = "result" tasklet_node = self.state.add_tasklet( - unique_name("tasklet"), + builtin_name, node_connections.keys(), {out_connector}, "{} = {}".format(out_connector, code), @@ -194,7 +194,7 @@ def visit_Lambda( return self.input_connections, TaskletExpr(tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: - dtype = as_dace_type(node.type) + dtype = dace_fieldview_util.as_dace_type(node.type) return SymbolExpr(node.value, dtype) def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr | IteratorExpr | MemletExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 2e8a205ca8..8d4b99ffaa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -85,12 +85,3 @@ def get_domain( def get_symbolic_expr(node: itir.Expr) -> str: return GTIRPythonCodegen().visit(node) - - -def unique_name(prefix: str) -> str: - """Generate a string containing a unique integer id, which is updated incrementally.""" - - unique_id = getattr(unique_name, "_unique_id", 0) # static variable - setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] - - return f"{prefix}_{unique_id}" From 932db7c3d85680c1ef887098c791edfe55055002 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 16 May 2024 17:35:59 +0200 Subject: [PATCH 061/235] Avoid tasklet-to-tasklet edge connections --- .../gtir_builtin_field_operator.py | 17 ++- .../gtir_builtins/gtir_builtin_translator.py | 2 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 4 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 134 +++++++++++------- 4 files changed, 100 insertions(+), 57 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 0564b15098..fd7f01bfe7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -32,7 +32,7 @@ IteratorExpr, MemletExpr, SymbolExpr, - TaskletExpr, + ValueExpr, ) from gt4py.next.type_system import type_specifications as ts @@ -100,7 +100,7 @@ def build(self) -> list[SDFGField]: stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) - indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] = { + indices: dict[str, MemletExpr | SymbolExpr | ValueExpr] = { dim.value: SymbolExpr( dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), _INDEX_DTYPE, @@ -117,7 +117,14 @@ def build(self) -> list[SDFGField]: # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) - assert isinstance(output_expr, TaskletExpr) + assert isinstance(output_expr, ValueExpr) + + # retrieve the tasklet node which writes the result + output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src + output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn + + # the last transient node can be deleted + self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field field_shape = [ @@ -150,10 +157,10 @@ def build(self) -> list[SDFGField]: memlet=memlet, ) self.head_state.add_memlet_path( - output_expr.node, + output_tasklet_node, mx, field_node, - src_conn=output_expr.connector, + src_conn=output_tasklet_connector, memlet=output_memlet, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 996ff94377..41ab2d6ca9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -54,7 +54,7 @@ def add_local_storage( else: assert len(shape) == 0 dtype = dace_fieldview_util.as_dace_type(data_type) - name, _ = self.sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) return self.head_state.add_access(name) @abc.abstractmethod diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 9c56a258f6..5967fe51e3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -131,7 +131,7 @@ def _visit_expression( TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? """ - expr_builder = self.visit(node, sdfg=sdfg, head_state=head_state) + expr_builder: SDFGFieldBuilder = self.visit(node, sdfg=sdfg, head_state=head_state) results = expr_builder() expressions_nodes = [] @@ -245,7 +245,7 @@ def visit_FunCall( # first visit the argument nodes arg_builders: list[SDFGFieldBuilder] = [] for arg in node.args: - arg_builder = self.visit(arg, sdfg=sdfg, head_state=head_state) + arg_builder: SDFGFieldBuilder = self.visit(arg, sdfg=sdfg, head_state=head_state) arg_builders.append(arg_builder) if cpm.is_call_to(node.fun, "as_fieldop"): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index afb9e4514a..aae0e7746b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -28,6 +28,7 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( MATH_BUILTINS_MAPPING, ) +from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) @@ -47,14 +48,13 @@ class SymbolExpr: @dataclass(frozen=True) -class TaskletExpr: +class ValueExpr: """Result of the computation provided by a tasklet node.""" - node: dace.nodes.Tasklet - connector: str + node: dace.nodes.AccessNode -IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | TaskletExpr +IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr @dataclass(frozen=True) @@ -83,7 +83,7 @@ class GTIRToTasklet(eve.NodeVisitor): state: dace.SDFGState input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] - symbol_map: dict[str, SymbolExpr | IteratorExpr | MemletExpr] + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] def __init__( self, @@ -106,7 +106,21 @@ def _add_input_connection( ) -> None: self.input_connections.append((src, subset, dst, dst_connector)) - def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: + def _get_tasklet_result( + self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str + ) -> ValueExpr: + scalar_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) + scalar_node = self.state.add_access(scalar_name) + self.state.add_edge( + src_node, + src_connector, + scalar_node, + None, + dace.Memlet(data=scalar_node.data, subset="0"), + ) + return ValueExpr(scalar_node) + + def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) @@ -151,18 +165,19 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: deref_connector, ) - elif isinstance(index_expr, TaskletExpr): + elif isinstance(index_expr, ValueExpr): self.state.add_edge( index_expr.node, - index_expr.connector, + None, deref_node, deref_connector, - dace.Memlet(), + dace.Memlet(data=index_expr.node.data, subset="0"), ) else: assert isinstance(index_expr, SymbolExpr) - return TaskletExpr(deref_node, "val") + dtype = it.field.desc(self.sdfg).dtype + return self._get_tasklet_result(dtype, deref_node, "val") else: assert isinstance(it, MemletExpr) @@ -187,7 +202,7 @@ def _make_cartesian_shift( ) -> IteratorExpr: """Implements cartesian offset along one dimension.""" assert offset_dim.value in it.dimensions - new_index: SymbolExpr | TaskletExpr + new_index: SymbolExpr | ValueExpr assert offset_dim.value in it.indices index_expr = it.indices[offset_dim.value] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): @@ -218,22 +233,30 @@ def _make_cartesian_shift( ) for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: if isinstance(input_expr, MemletExpr): + if input_connector == "index": + dtype = input_expr.source.desc(self.sdfg).dtype self._add_input_connection( input_expr.source, input_expr.subset, dynamic_offset_tasklet, input_connector, ) - elif isinstance(input_expr, TaskletExpr): + elif isinstance(input_expr, ValueExpr): + if input_connector == "index": + dtype = input_expr.node.desc(self.sdfg).dtype self.state.add_edge( input_expr.node, - input_expr.connector, + None, dynamic_offset_tasklet, input_connector, - dace.Memlet(), + dace.Memlet(data=input_expr.node.data, subset="0"), ) + else: + assert isinstance(input_expr, SymbolExpr) + if input_connector == "index": + dtype = input_expr.dtype - new_index = TaskletExpr(dynamic_offset_tasklet, new_index_connector) + new_index = self._get_tasklet_result(dtype, dynamic_offset_tasklet, new_index_connector) return IteratorExpr( it.field, @@ -261,27 +284,30 @@ def _make_unstructured_shift( assert isinstance(origin_index, SymbolExpr) neighbor_expr = it.indices.get(neighbor_dim, None) if neighbor_expr is not None: - assert isinstance(neighbor_expr, TaskletExpr) + assert isinstance(neighbor_expr, ValueExpr) + # retrieve the tasklet that perform the neighbor table access + neighbor_tasklet_node = self.state.in_edges(neighbor_expr.node)[0].src if isinstance(offset_expr, SymbolExpr): # use memlet to retrieve the neighbor index and pass it to the index connector of tasklet for neighbor access self._add_input_connection( offset_table_node, sbs.Indices([origin_index.value, offset_expr.value]), - neighbor_expr.node, + neighbor_tasklet_node, INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node - dynamic_offset_tasklet = self._make_dynamic_neighbor_offset( + dynamic_offset_value = self._make_dynamic_neighbor_offset( offset_expr, offset_table_node, origin_index ) # write result to the index connector of tasklet for neighbor access self.state.add_edge( - dynamic_offset_tasklet.node, - dynamic_offset_tasklet.connector, - neighbor_expr.node, + dynamic_offset_value.node, + None, + neighbor_tasklet_node, INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), + memlet=dace.Memlet(data=dynamic_offset_value.node.data, subset="0"), ) shifted_indices = { @@ -298,11 +324,11 @@ def _make_unstructured_shift( } else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node - dynamic_offset_tasklet = self._make_dynamic_neighbor_offset( + dynamic_offset_value = self._make_dynamic_neighbor_offset( offset_expr, offset_table_node, origin_index ) - shifted_indices = it.indices | {neighbor_dim: dynamic_offset_tasklet} + shifted_indices = it.indices | {neighbor_dim: dynamic_offset_value} else: origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) @@ -331,16 +357,17 @@ def _make_unstructured_shift( else: self.state.add_edge( offset_expr.node, - offset_expr.connector, + None, tasklet_node, "offset", - dace.Memlet(), + dace.Memlet(data=offset_expr.node.data, subset="0"), ) - neighbor_expr = TaskletExpr( + table_desc = offset_table_node.desc(self.sdfg) + neighbor_expr = self._get_tasklet_result( + table_desc.dtype, tasklet_node, neighbor_index_connector, ) - table_desc = offset_table_node.desc(self.sdfg) self._add_input_connection( offset_table_node, sbs.Range.from_array(table_desc), @@ -357,10 +384,10 @@ def _make_unstructured_shift( def _make_dynamic_neighbor_offset( self, - offset_expr: MemletExpr | TaskletExpr, + offset_expr: MemletExpr | ValueExpr, offset_table_node: dace.nodes.AccessNode, origin_index: SymbolExpr, - ) -> TaskletExpr: + ) -> ValueExpr: new_index_connector = "neighbor_index" tasklet_node = self.state.add_tasklet( "dynamic_neighbor_offset", @@ -384,13 +411,14 @@ def _make_dynamic_neighbor_offset( else: self.state.add_edge( offset_expr.node, - offset_expr.connector, + None, tasklet_node, "offset", - dace.Memlet(), + dace.Memlet(data=offset_expr.node.data, subset="0"), ) - return TaskletExpr(tasklet_node, new_index_connector) + dtype = offset_table_node.desc(self.sdfg).dtype + return self._get_tasklet_result(dtype, tasklet_node, new_index_connector) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun @@ -415,7 +443,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset_expr = SymbolExpr(head[1].value, dace.int32) else: dynamic_offset_expr = self.visit(head[1]) - assert isinstance(dynamic_offset_expr, MemletExpr | TaskletExpr) + assert isinstance(dynamic_offset_expr, MemletExpr | ValueExpr) offset_expr = dynamic_offset_expr if isinstance(offset_provider, Dimension): @@ -432,7 +460,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: it, offset_provider, offset_table_node, offset_expr ) - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | MemletExpr: + def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -443,10 +471,10 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml assert isinstance(node.fun, itir.SymRef) node_internals = [] - node_connections: dict[str, MemletExpr | TaskletExpr] = {} + node_connections: dict[str, MemletExpr | ValueExpr] = {} for i, arg in enumerate(node.args): arg_expr = self.visit(arg) - if isinstance(arg_expr, MemletExpr | TaskletExpr): + if isinstance(arg_expr, MemletExpr | ValueExpr): # the argument value is the result of a tasklet node or direct field access connector = f"__inp_{i}" node_connections[connector] = arg_expr @@ -473,40 +501,48 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml ) for connector, arg_expr in node_connections.items(): - if isinstance(arg_expr, TaskletExpr): + if isinstance(arg_expr, ValueExpr): self.state.add_edge( - arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() + arg_expr.node, + None, + tasklet_node, + connector, + dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: self._add_input_connection( arg_expr.source, arg_expr.subset, tasklet_node, connector ) - return TaskletExpr(tasklet_node, "result") + # TODO: use type inference to determine the result type + if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): + dtype = node_connections["__inp_0"].source.desc(self.sdfg).dtype + else: + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + dtype = dace_fieldview_util.as_dace_type(node_type) + + return self._get_tasklet_result(dtype, tasklet_node, "result") def visit_Lambda( - self, node: itir.Lambda, args: list[SymbolExpr | IteratorExpr | MemletExpr] - ) -> tuple[ - list[InputConnection], - TaskletExpr, - ]: + self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] + ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg - output_expr = self.visit(node.expr) - if isinstance(output_expr, TaskletExpr): + output_expr: MemletExpr | ValueExpr = self.visit(node.expr) + if isinstance(output_expr, ValueExpr): return self.input_connections, output_expr # special case where the field operator is simply copying data from source to destination node - assert isinstance(output_expr, MemletExpr) + output_dtype = output_expr.source.desc(self.sdfg).dtype tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_connection(output_expr.source, output_expr.subset, tasklet_node, "__inp") - return self.input_connections, TaskletExpr(tasklet_node, "__out") + return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) return SymbolExpr(node.value, dtype) - def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr | IteratorExpr | MemletExpr: + def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: param = str(node.id) assert param in self.symbol_map return self.symbol_map[param] From 46febb0a66756f63a0845d1dcc6722a28a394dab Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 16 May 2024 17:41:43 +0200 Subject: [PATCH 062/235] Avoid tasklet-to-tasklet edge connections --- .../gtir_builtin_field_operator.py | 17 +++-- .../gtir_builtins/gtir_builtin_translator.py | 2 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 4 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 66 ++++++++++++------- 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py index 0564b15098..fd7f01bfe7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py @@ -32,7 +32,7 @@ IteratorExpr, MemletExpr, SymbolExpr, - TaskletExpr, + ValueExpr, ) from gt4py.next.type_system import type_specifications as ts @@ -100,7 +100,7 @@ def build(self) -> list[SDFGField]: stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) - indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] = { + indices: dict[str, MemletExpr | SymbolExpr | ValueExpr] = { dim.value: SymbolExpr( dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), _INDEX_DTYPE, @@ -117,7 +117,14 @@ def build(self) -> list[SDFGField]: # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) - assert isinstance(output_expr, TaskletExpr) + assert isinstance(output_expr, ValueExpr) + + # retrieve the tasklet node which writes the result + output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src + output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn + + # the last transient node can be deleted + self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field field_shape = [ @@ -150,10 +157,10 @@ def build(self) -> list[SDFGField]: memlet=memlet, ) self.head_state.add_memlet_path( - output_expr.node, + output_tasklet_node, mx, field_node, - src_conn=output_expr.connector, + src_conn=output_tasklet_connector, memlet=output_memlet, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py index 996ff94377..41ab2d6ca9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py @@ -54,7 +54,7 @@ def add_local_storage( else: assert len(shape) == 0 dtype = dace_fieldview_util.as_dace_type(data_type) - name, _ = self.sdfg.add_scalar(name, dtype, find_new_name=True, transient=True) + name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) return self.head_state.add_access(name) @abc.abstractmethod diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index f493861f86..7ae52d0fcc 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -129,7 +129,7 @@ def _visit_expression( TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? """ - expr_builder = self.visit(node, sdfg=sdfg, head_state=head_state) + expr_builder: SDFGFieldBuilder = self.visit(node, sdfg=sdfg, head_state=head_state) results = expr_builder() expressions_nodes = [] @@ -226,7 +226,7 @@ def visit_FunCall( # first visit the argument nodes arg_builders: list[SDFGFieldBuilder] = [] for arg in node.args: - arg_builder = self.visit(arg, sdfg=sdfg, head_state=head_state) + arg_builder: SDFGFieldBuilder = self.visit(arg, sdfg=sdfg, head_state=head_state) arg_builders.append(arg_builder) if cpm.is_call_to(node.fun, "as_fieldop"): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 7f32c82037..306e44a2a9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -27,6 +27,7 @@ from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( MATH_BUILTINS_MAPPING, ) +from gt4py.next.type_system import type_specifications as ts @dataclass(frozen=True) @@ -46,14 +47,13 @@ class SymbolExpr: @dataclass(frozen=True) -class TaskletExpr: +class ValueExpr: """Result of the computation provided by a tasklet node.""" - node: dace.nodes.Tasklet - connector: str + node: dace.nodes.AccessNode -IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | TaskletExpr +IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr @dataclass(frozen=True) @@ -80,7 +80,7 @@ class GTIRToTasklet(eve.NodeVisitor): state: dace.SDFGState input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] - symbol_map: dict[str, SymbolExpr | IteratorExpr | MemletExpr] + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] def __init__( self, @@ -103,7 +103,21 @@ def _add_input_connection( ) -> None: self.input_connections.append((src, subset, dst, dst_connector)) - def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: + def _get_tasklet_result( + self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str + ) -> ValueExpr: + scalar_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) + scalar_node = self.state.add_access(scalar_name) + self.state.add_edge( + src_node, + src_connector, + scalar_node, + None, + dace.Memlet(data=scalar_node.data, subset="0"), + ) + return ValueExpr(scalar_node) + + def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) @@ -123,7 +137,7 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | TaskletExpr: def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: raise NotImplementedError - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | MemletExpr: + def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -134,10 +148,10 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml assert isinstance(node.fun, itir.SymRef) node_internals = [] - node_connections: dict[str, MemletExpr | TaskletExpr] = {} + node_connections: dict[str, MemletExpr | ValueExpr] = {} for i, arg in enumerate(node.args): arg_expr = self.visit(arg) - if isinstance(arg_expr, MemletExpr | TaskletExpr): + if isinstance(arg_expr, MemletExpr | ValueExpr): # the argument value is the result of a tasklet node or direct field access connector = f"__inp_{i}" node_connections[connector] = arg_expr @@ -164,40 +178,48 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | TaskletExpr | Meml ) for connector, arg_expr in node_connections.items(): - if isinstance(arg_expr, TaskletExpr): + if isinstance(arg_expr, ValueExpr): self.state.add_edge( - arg_expr.node, arg_expr.connector, tasklet_node, connector, dace.Memlet() + arg_expr.node, + None, + tasklet_node, + connector, + dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: self._add_input_connection( arg_expr.source, arg_expr.subset, tasklet_node, connector ) - return TaskletExpr(tasklet_node, "result") + # TODO: use type inference to determine the result type + if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): + dtype = node_connections["__inp_0"].source.desc(self.sdfg).dtype + else: + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + dtype = dace_fieldview_util.as_dace_type(node_type) + + return self._get_tasklet_result(dtype, tasklet_node, "result") def visit_Lambda( - self, node: itir.Lambda, args: list[SymbolExpr | IteratorExpr | MemletExpr] - ) -> tuple[ - list[InputConnection], - TaskletExpr, - ]: + self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] + ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg - output_expr = self.visit(node.expr) - if isinstance(output_expr, TaskletExpr): + output_expr: MemletExpr | ValueExpr = self.visit(node.expr) + if isinstance(output_expr, ValueExpr): return self.input_connections, output_expr # special case where the field operator is simply copying data from source to destination node - assert isinstance(output_expr, MemletExpr) + output_dtype = output_expr.source.desc(self.sdfg).dtype tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_connection(output_expr.source, output_expr.subset, tasklet_node, "__inp") - return self.input_connections, TaskletExpr(tasklet_node, "__out") + return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) return SymbolExpr(node.value, dtype) - def visit_SymRef(self, node: itir.SymRef) -> SymbolExpr | IteratorExpr | MemletExpr: + def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: param = str(node.id) assert param in self.symbol_map return self.symbol_map[param] From 949bad780c699d9a76d4b9515cd9e2a9826f2ba0 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 16 May 2024 17:53:33 +0200 Subject: [PATCH 063/235] Add support for in-out field parameters --- .../gtir_builtins/gtir_builtin_symbol_ref.py | 59 +++++++------------ .../runners_tests/test_dace_fieldview.py | 36 +++++++++++ 2 files changed, 57 insertions(+), 38 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py index d0a3afa497..fa073e81fb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py @@ -13,8 +13,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Optional - import dace from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( @@ -41,43 +39,28 @@ def __init__( self.sym_name = sym_name self.sym_type = sym_type - def _get_access_node(self) -> Optional[dace.nodes.AccessNode]: - """Returns, if present, the access node in current state for the data symbol.""" - access_nodes = [ - node - for node in self.head_state.nodes() - if isinstance(node, dace.nodes.AccessNode) and node.data == self.sym_name - ] - if len(access_nodes) == 0: - return None - assert len(access_nodes) == 1 - return access_nodes[0] - def build(self) -> list[SDFGField]: - # check if access node is already present in current state - sym_node = self._get_access_node() - if sym_node is None: - if isinstance(self.sym_type, ts.FieldType): - # add access node to current state - sym_node = self.head_state.add_access(self.sym_name) + if isinstance(self.sym_type, ts.FieldType): + # add access node to current state + sym_node = self.head_state.add_access(self.sym_name) - else: - # scalar symbols are passed to the SDFG as symbols: build tasklet node - # to write the symbol to a scalar access node - assert self.sym_name in self.sdfg.symbols - tasklet_node = self.head_state.add_tasklet( - f"get_{self.sym_name}", - {}, - {"__out"}, - f"__out = {self.sym_name}", - ) - sym_node = self.add_local_storage(self.sym_type, shape=[]) - self.head_state.add_edge( - tasklet_node, - "__out", - sym_node, - None, - dace.Memlet(data=sym_node.data, subset="0"), - ) + else: + # scalar symbols are passed to the SDFG as symbols: build tasklet node + # to write the symbol to a scalar access node + assert self.sym_name in self.sdfg.symbols + tasklet_node = self.head_state.add_tasklet( + f"get_{self.sym_name}", + {}, + {"__out"}, + f"__out = {self.sym_name}", + ) + sym_node = self.add_local_storage(self.sym_type, shape=[]) + self.head_state.add_edge( + tasklet_node, + "__out", + sym_node, + None, + dace.Memlet(data=sym_node.data, subset="0"), + ) return [(sym_node, self.sym_type)] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index ce654c853d..c5c45fac3d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -82,6 +82,42 @@ def test_gtir_copy(): assert np.allclose(a, b) +def test_gtir_update(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + ) + testee = itir.Program( + id="gtir_copy", + function_definitions=[], + params=[itir.Sym(id="x"), itir.Sym(id="size")], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1.0)), + domain, + ) + )("x"), + domain=domain, + target=itir.SymRef(id="x"), + ) + ], + ) + + a = np.random.rand(N) + ref = a + 1.0 + + sdfg_genenerator = FieldviewGtirToSDFG( + [IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} + ) + sdfg = sdfg_genenerator.visit(testee) + assert isinstance(sdfg, dace.SDFG) + + sdfg(x=a, **FSYMBOLS) + assert np.allclose(a, ref) + + def test_gtir_sum2(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") From 8890f95a00b6569bdc74f304f3a94193ee00d04e Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 09:10:30 +0200 Subject: [PATCH 064/235] Refactoring: import modules, not symbols --- .../__init__.py | 14 ++- .../gtir_builtin.py} | 9 +- .../gtir_builtin_field_operator.py | 44 +++----- .../gtir_builtin_select.py | 14 +-- .../gtir_builtin_symbol_ref.py | 9 +- .../dace_fieldview/gtir_python_codegen.py | 104 ------------------ .../runners/dace_fieldview/gtir_to_sdfg.py | 25 +++-- .../runners/dace_fieldview/gtir_to_tasklet.py | 99 ++++++++++++++++- .../runners/dace_fieldview/utility.py | 6 +- 9 files changed, 154 insertions(+), 170 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtins => gtir_builtin_translators}/__init__.py (70%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtins/gtir_builtin_translator.py => gtir_builtin_translators/gtir_builtin.py} (95%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtins => gtir_builtin_translators}/gtir_builtin_field_operator.py (84%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtins => gtir_builtin_translators}/gtir_builtin_select.py (94%) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_builtins => gtir_builtin_translators}/gtir_builtin_symbol_ref.py (92%) delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py similarity index 70% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py index c99b418eae..7b4c9c1e63 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py @@ -12,19 +12,27 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin import ( + GTIRPrimitiveTranslator as PrimitiveTranslator, + SDFGField, + SDFGFieldBuilder, +) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin_field_operator import ( GTIRBuiltinAsFieldOp as AsFieldOp, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_select import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin_select import ( GTIRBuiltinSelect as Select, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_symbol_ref import ( +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin_symbol_ref import ( GTIRBuiltinSymbolRef as SymbolRef, ) # export short names of translation classes for GTIR builtin functions __all__ = [ + "PrimitiveTranslator", + "SDFGField", + "SDFGFieldBuilder", "AsFieldOp", "Select", "SymbolRef", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py similarity index 95% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py index 41ab2d6ca9..50242385e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py @@ -19,16 +19,19 @@ import dace -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util +from gt4py.next.program_processors.runners.dace_fieldview import ( + utility as dace_fieldview_util, +) from gt4py.next.type_system import type_specifications as ts +# Define aliases for return types SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] @dataclass(frozen=True) -class GTIRBuiltinTranslator(abc.ABC): +class GTIRPrimitiveTranslator(abc.ABC): sdfg: dace.SDFG head_state: dace.SDFGState @@ -69,4 +72,4 @@ def build(self) -> list[SDFGField]: The GT4Py data type is useful in the case of fields, because it provides information on the field domain (e.g. order of dimensions, types of dimensions). - """ + """ \ No newline at end of file diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py similarity index 84% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py index fd7f01bfe7..aed987d375 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py @@ -21,33 +21,23 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GTIRBuiltinTranslator, - SDFGField, - SDFGFieldBuilder, +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_to_tasklet, + utility as dace_fieldview_util, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( - GTIRToTasklet, - IteratorExpr, - MemletExpr, - SymbolExpr, - ValueExpr, +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators import ( + gtir_builtin, ) from gt4py.next.type_system import type_specifications as ts -# Define type of variables used for field indexing -_INDEX_DTYPE = dace.int64 - - -class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): +class GTIRBuiltinAsFieldOp(gtir_builtin.GTIRPrimitiveTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] stencil_expr: itir.Lambda - stencil_args: list[SDFGFieldBuilder] + stencil_args: list[gtir_builtin.SDFGFieldBuilder] field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] field_type: ts.FieldType offset_provider: dict[str, Connectivity | Dimension] @@ -57,7 +47,7 @@ def __init__( sdfg: dace.SDFG, state: dace.SDFGState, node: itir.FunCall, - stencil_args: list[SDFGFieldBuilder], + stencil_args: list[gtir_builtin.SDFGFieldBuilder], offset_provider: dict[str, Connectivity | Dimension], ): super().__init__(sdfg, state) @@ -84,10 +74,10 @@ def __init__( self.stencil_expr = stencil_expr self.stencil_args = stencil_args - def build(self) -> list[SDFGField]: + def build(self) -> list[gtir_builtin.SDFGField]: dimension_index_fmt = "i_{dim}" # first visit the list of arguments and build a symbol map - stencil_args: list[IteratorExpr | MemletExpr] = [] + stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] for arg in self.stencil_args: arg_nodes = arg() assert len(arg_nodes) == 1 @@ -96,18 +86,18 @@ def build(self) -> list[SDFGField]: assert isinstance(data_node, dace.nodes.AccessNode) if isinstance(arg_type, ts.ScalarType): - scalar_arg = MemletExpr(data_node, sbs.Indices([0])) + scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) - indices: dict[str, MemletExpr | SymbolExpr | ValueExpr] = { - dim.value: SymbolExpr( + indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { + dim.value: gtir_to_tasklet.SymbolExpr( dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), - _INDEX_DTYPE, + gtir_to_tasklet.INDEX_DTYPE, ) for dim in self.field_domain.keys() } - iterator_arg = IteratorExpr( + iterator_arg = gtir_to_tasklet.IteratorExpr( data_node, [dim.value for dim in arg_type.dims], indices, @@ -115,9 +105,9 @@ def build(self) -> list[SDFGField]: stencil_args.append(iterator_arg) # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) + taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) - assert isinstance(output_expr, ValueExpr) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) # retrieve the tasklet node which writes the result output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py similarity index 94% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py index 05ed7bd74f..b0786cd6d7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py @@ -19,18 +19,16 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GTIRBuiltinTranslator, - SDFGField, - SDFGFieldBuilder, +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators import ( + gtir_builtin, ) -class GTIRBuiltinSelect(GTIRBuiltinTranslator): +class GTIRBuiltinSelect(gtir_builtin.GTIRPrimitiveTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" - true_br_builder: SDFGFieldBuilder - false_br_builder: SDFGFieldBuilder + true_br_builder: gtir_builtin.SDFGFieldBuilder + false_br_builder: gtir_builtin.SDFGFieldBuilder def __init__( self, @@ -80,7 +78,7 @@ def __init__( false_expr, sdfg=sdfg, head_state=false_state ) - def build(self) -> list[SDFGField]: + def build(self) -> list[gtir_builtin.SDFGField]: # retrieve true/false states as predecessors of head state branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) assert len(branch_states) == 2 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py similarity index 92% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py index fa073e81fb..e7ae631cfd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_symbol_ref.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py @@ -15,14 +15,13 @@ import dace -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - GTIRBuiltinTranslator, - SDFGField, +from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators import ( + gtir_builtin, ) from gt4py.next.type_system import type_specifications as ts -class GTIRBuiltinSymbolRef(GTIRBuiltinTranslator): +class GTIRBuiltinSymbolRef(gtir_builtin.GTIRPrimitiveTranslator): """Generates the dataflow subgraph for a `itir.SymRef` node.""" sym_name: str @@ -39,7 +38,7 @@ def __init__( self.sym_name = sym_name self.sym_type = sym_type - def build(self) -> list[SDFGField]: + def build(self) -> list[gtir_builtin.SDFGField]: if isinstance(self.sym_type, ts.FieldType): # add access node to current state sym_node = self.head_state.add_access(self.sym_name) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py deleted file mode 100644 index 478b5d3af8..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ /dev/null @@ -1,104 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import numpy as np - -from gt4py.eve import codegen -from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -class GTIRPythonCodegen(codegen.TemplatedGenerator): - SymRef = as_fmt("{id}") - Literal = as_fmt("{value}") - - def _visit_deref(self, node: itir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: - assert isinstance(node.fun, itir.SymRef) - fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - return fmt.format(*args) - - def visit_FunCall(self, node: itir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - elif isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 7ae52d0fcc..663fdc2513 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -26,12 +26,9 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import ( - gtir_builtins, + gtir_builtin_translators, utility as dace_fieldview_util, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( - SDFGFieldBuilder, -) from gt4py.next.type_system import type_specifications as ts @@ -129,7 +126,9 @@ def _visit_expression( TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? """ - expr_builder: SDFGFieldBuilder = self.visit(node, sdfg=sdfg, head_state=head_state) + expr_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( + node, sdfg=sdfg, head_state=head_state + ) results = expr_builder() expressions_nodes = [] @@ -222,21 +221,23 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) def visit_FunCall( self, node: itir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> SDFGFieldBuilder: + ) -> gtir_builtin_translators.SDFGFieldBuilder: # first visit the argument nodes - arg_builders: list[SDFGFieldBuilder] = [] + arg_builders = [] for arg in node.args: - arg_builder: SDFGFieldBuilder = self.visit(arg, sdfg=sdfg, head_state=head_state) + arg_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( + arg, sdfg=sdfg, head_state=head_state + ) arg_builders.append(arg_builder) if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtins.AsFieldOp( + return gtir_builtin_translators.AsFieldOp( sdfg, head_state, node, arg_builders, self.offset_provider ) elif cpm.is_call_to(node.fun, "select"): assert len(arg_builders) == 0 - return gtir_builtins.Select(sdfg, head_state, self, node) + return gtir_builtin_translators.Select(sdfg, head_state, self, node) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -251,8 +252,8 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> SDFGFieldBuilder: + ) -> gtir_builtin_translators.SDFGFieldBuilder: symbol_name = str(node.id) assert symbol_name in self.symbol_types symbol_type = self.symbol_types[symbol_name] - return gtir_builtins.SymbolRef(sdfg, head_state, symbol_name, symbol_type) + return gtir_builtin_translators.SymbolRef(sdfg, head_state, symbol_name, symbol_type) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 306e44a2a9..074e69e884 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -23,12 +23,16 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util -from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( - MATH_BUILTINS_MAPPING, +from gt4py.next.program_processors.runners.dace_fieldview import ( + utility as dace_fieldview_util, ) from gt4py.next.type_system import type_specifications as ts +import numpy as np +from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @dataclass(frozen=True) class MemletExpr: @@ -73,7 +77,66 @@ class IteratorExpr: ] -class GTIRToTasklet(eve.NodeVisitor): +# Define type of variables used for field indexing +INDEX_DTYPE = dace.int32 + + +MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + +class LambdaToTasklet(eve.NodeVisitor): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" sdfg: dace.SDFG @@ -223,3 +286,31 @@ def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | MemletExpr | SymbolE param = str(node.id) assert param in self.symbol_map return self.symbol_map[param] + + +class PythonCodegen(codegen.TemplatedGenerator): + SymRef = as_fmt("{id}") + Literal = as_fmt("{value}") + + def _visit_deref(self, node: itir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], itir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: itir.FunCall) -> str: + assert isinstance(node.fun, itir.SymRef) + fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) + + def visit_FunCall(self, node: itir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, itir.SymRef): + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + return self._visit_numeric_builtin(node) + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") \ No newline at end of file diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 8d4b99ffaa..99c46b2f1b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -19,9 +19,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview.gtir_python_codegen import ( - GTIRPythonCodegen, -) +from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet from gt4py.next.type_system import type_specifications as ts @@ -84,4 +82,4 @@ def get_domain( def get_symbolic_expr(node: itir.Expr) -> str: - return GTIRPythonCodegen().visit(node) + return gtir_to_tasklet.PythonCodegen().visit(node) From 87b71a67f38f0e351b23785fb834bd89c6b75e2f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 09:28:33 +0200 Subject: [PATCH 065/235] Minor edit --- .../gtir_builtin_translators/gtir_builtin.py | 6 ++---- .../gtir_builtin_field_operator.py | 4 +++- .../runners/dace_fieldview/gtir_to_tasklet.py | 18 +++++------------- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py index 50242385e5..0c887dcd9d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py @@ -19,9 +19,7 @@ import dace -from gt4py.next.program_processors.runners.dace_fieldview import ( - utility as dace_fieldview_util, -) +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util from gt4py.next.type_system import type_specifications as ts @@ -72,4 +70,4 @@ def build(self) -> list[SDFGField]: The GT4Py data type is useful in the case of fields, because it provides information on the field domain (e.g. order of dimensions, types of dimensions). - """ \ No newline at end of file + """ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py index aed987d375..572c531a46 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py @@ -76,6 +76,8 @@ def __init__( def build(self) -> list[gtir_builtin.SDFGField]: dimension_index_fmt = "i_{dim}" + # type of variables used for field indexing + index_dtype = dace.int32 # first visit the list of arguments and build a symbol map stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] for arg in self.stencil_args: @@ -93,7 +95,7 @@ def build(self) -> list[gtir_builtin.SDFGField]: indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { dim.value: gtir_to_tasklet.SymbolExpr( dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), - gtir_to_tasklet.INDEX_DTYPE, + index_dtype, ) for dim in self.field_domain.keys() } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 074e69e884..5c06694315 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -18,21 +18,17 @@ import dace import dace.subsets as sbs +import numpy as np from gt4py import eve +from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import ( - utility as dace_fieldview_util, -) +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util from gt4py.next.type_system import type_specifications as ts -import numpy as np -from gt4py.eve import codegen -from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @dataclass(frozen=True) class MemletExpr: @@ -77,10 +73,6 @@ class IteratorExpr: ] -# Define type of variables used for field indexing -INDEX_DTYPE = dace.int32 - - MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -313,4 +305,4 @@ def visit_FunCall(self, node: itir.FunCall) -> str: return self._visit_numeric_builtin(node) else: raise NotImplementedError(f"'{builtin_name}' not implemented.") - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") \ No newline at end of file + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") From 665a6098b257cf6e8544868d39ed2adc14acb218 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 10:49:37 +0200 Subject: [PATCH 066/235] Remove internal package for builtin translators --- .../gtir_builtin_translators.py | 354 ++++++++++++++++++ .../gtir_builtin_translators/__init__.py | 39 -- .../gtir_builtin_translators/gtir_builtin.py | 73 ---- .../gtir_builtin_field_operator.py | 159 -------- .../gtir_builtin_select.py | 122 ------ .../gtir_builtin_symbol_ref.py | 65 ---- .../runners/dace_fieldview/gtir_to_sdfg.py | 25 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 16 +- .../runners/dace_fieldview/utility.py | 8 +- 9 files changed, 388 insertions(+), 473 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py new file mode 100644 index 0000000000..e71a7e5606 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -0,0 +1,354 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import abc +from dataclasses import dataclass +from typing import Callable, TypeAlias, final + +import dace +import dace.subsets as sbs + +from gt4py import eve +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_to_tasklet, + utility as dace_fieldview_util, +) +from gt4py.next.type_system import type_specifications as ts + + +# Define aliases for return types +SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] +SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] + + +@dataclass(frozen=True) +class PrimitiveTranslator(abc.ABC): + sdfg: dace.SDFG + head_state: dace.SDFGState + + @final + def __call__(self) -> list[SDFGField]: + """The callable interface is used to build the dataflow graph. + + It allows to build the dataflow graph inside a given state starting + from the innermost nodes, by propagating the intermediate results + as access nodes to temporary local storage. + """ + return self.build() + + @final + def add_local_storage( + self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] + ) -> dace.nodes.AccessNode: + """ + Allocates temporary storage to be used in the local scope for intermediate results. + + The storage is allocate with unique names to enable SSA optimization in the compilation phase. + """ + if isinstance(data_type, ts.FieldType): + assert len(data_type.dims) == len(shape) + dtype = dace_fieldview_util.as_dace_type(data_type.dtype) + name, _ = self.sdfg.add_array("var", shape, dtype, find_new_name=True, transient=True) + else: + assert len(shape) == 0 + dtype = dace_fieldview_util.as_dace_type(data_type) + name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) + return self.head_state.add_access(name) + + @abc.abstractmethod + def build(self) -> list[SDFGField]: + """Creates the dataflow subgraph representing a GTIR builtin function. + + This method is used by derived classes to build a specialized subgraph + for a specific builtin function. + + Returns a list of SDFG nodes and the associated GT4Py data type. + + The GT4Py data type is useful in the case of fields, because it provides + information on the field domain (e.g. order of dimensions, types of dimensions). + """ + + +class AsFieldOp(PrimitiveTranslator): + """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + + TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] + + stencil_expr: itir.Lambda + stencil_args: list[SDFGFieldBuilder] + field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] + field_type: ts.FieldType + offset_provider: dict[str, Connectivity | Dimension] + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + node: itir.FunCall, + stencil_args: list[SDFGFieldBuilder], + offset_provider: dict[str, Connectivity | Dimension], + ): + super().__init__(sdfg, state) + self.offset_provider = offset_provider + + assert cpm.is_call_to(node.fun, "as_fieldop") + assert len(node.fun.args) == 2 + stencil_expr, domain_expr = node.fun.args + # expect stencil (represented as a lambda function) as first argument + assert isinstance(stencil_expr, itir.Lambda) + # the domain of the field operator is passed as second argument + assert isinstance(domain_expr, itir.FunCall) + + domain = dace_fieldview_util.get_domain(domain_expr) + # define field domain with all dimensions in alphabetical order + sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) + + # add local storage to compute the field operator over the given domain + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + + self.field_domain = domain + self.field_type = ts.FieldType(sorted_domain_dims, node_type) + self.stencil_expr = stencil_expr + self.stencil_args = stencil_args + + def build(self) -> list[SDFGField]: + dimension_index_fmt = "i_{dim}" + # type of variables used for field indexing + index_dtype = dace.int32 + # first visit the list of arguments and build a symbol map + stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] + for arg in self.stencil_args: + arg_nodes = arg() + assert len(arg_nodes) == 1 + data_node, arg_type = arg_nodes[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(data_node, dace.nodes.AccessNode) + + if isinstance(arg_type, ts.ScalarType): + scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) + stencil_args.append(scalar_arg) + else: + assert isinstance(arg_type, ts.FieldType) + indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { + dim.value: gtir_to_tasklet.SymbolExpr( + dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), + index_dtype, + ) + for dim in self.field_domain.keys() + } + iterator_arg = gtir_to_tasklet.IteratorExpr( + data_node, + [dim.value for dim in arg_type.dims], + indices, + ) + stencil_args.append(iterator_arg) + + # represent the field operator as a mapped tasklet graph, which will range over the field domain + taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) + input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + + # retrieve the tasklet node which writes the result + output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src + output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn + + # the last transient node can be deleted + self.head_state.remove_node(output_expr.node) + + # allocate local temporary storage for the result field + field_shape = [ + # diff between upper and lower bound + self.field_domain[dim][1] - self.field_domain[dim][0] + for dim in self.field_type.dims + ] + field_node = self.add_local_storage(self.field_type, field_shape) + + # assume tasklet with single output + output_index = ",".join( + dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims + ) + output_memlet = dace.Memlet(data=field_node.data, subset=output_index) + + # create map range corresponding to the field operator domain + map_ranges = { + dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" + for dim, (lb, ub) in self.field_domain.items() + } + me, mx = self.head_state.add_map("field_op", map_ranges) + + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + self.head_state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, + ) + self.head_state.add_memlet_path( + output_tasklet_node, + mx, + field_node, + src_conn=output_tasklet_connector, + memlet=output_memlet, + ) + + return [(field_node, self.field_type)] + + +class Select(PrimitiveTranslator): + """Generates the dataflow subgraph for the `select` builtin function.""" + + true_br_builder: SDFGFieldBuilder + false_br_builder: SDFGFieldBuilder + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + dataflow_builder: eve.NodeVisitor, + node: itir.FunCall, + ): + super().__init__(sdfg, state) + + assert cpm.is_call_to(node.fun, "select") + assert len(node.fun.args) == 3 + cond_expr, true_expr, false_expr = node.fun.args + + # expect condition as first argument + cond = dace_fieldview_util.get_symbolic_expr(cond_expr) + + # use current head state to terminate the dataflow, and add a entry state + # to connect the true/false branch states as follows: + # + # ------------ + # === | select | === + # || ------------ || + # \/ \/ + # ------------ ------------- + # | true | | false | + # ------------ ------------- + # || || + # || ------------ || + # ==> | head | <== + # ------------ + # + select_state = sdfg.add_state_before(state, state.label + "_select") + sdfg.remove_edge(sdfg.out_edges(select_state)[0]) + + # expect true branch as second argument + true_state = sdfg.add_state(state.label + "_true_branch") + sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) + sdfg.add_edge(true_state, state, dace.InterstateEdge()) + self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) + + # and false branch as third argument + false_state = sdfg.add_state(state.label + "_false_branch") + sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=(f"not {cond}"))) + sdfg.add_edge(false_state, state, dace.InterstateEdge()) + self.false_br_builder = dataflow_builder.visit( + false_expr, sdfg=sdfg, head_state=false_state + ) + + def build(self) -> list[SDFGField]: + # retrieve true/false states as predecessors of head state + branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) + assert len(branch_states) == 2 + if branch_states[0].label.endswith("_true_branch"): + true_state, false_state = branch_states + else: + false_state, true_state = branch_states + + true_br_args = self.true_br_builder() + false_br_args = self.false_br_builder() + + output_nodes = [] + for true_br, false_br in zip(true_br_args, false_br_args, strict=True): + true_br_node, true_br_type = true_br + assert isinstance(true_br_node, dace.nodes.AccessNode) + false_br_node, false_br_type = false_br + assert isinstance(false_br_node, dace.nodes.AccessNode) + assert true_br_type == false_br_type + array_type = self.sdfg.arrays[true_br_node.data] + access_node = self.add_local_storage(true_br_type, array_type.shape) + output_nodes.append((access_node, true_br_type)) + + data_name = access_node.data + true_br_output_node = true_state.add_access(data_name) + true_state.add_nedge( + true_br_node, + true_br_output_node, + dace.Memlet.from_array( + true_br_output_node.data, true_br_output_node.desc(self.sdfg) + ), + ) + + false_br_output_node = false_state.add_access(data_name) + false_state.add_nedge( + false_br_node, + false_br_output_node, + dace.Memlet.from_array( + false_br_output_node.data, false_br_output_node.desc(self.sdfg) + ), + ) + return output_nodes + + +class SymbolRef(PrimitiveTranslator): + """Generates the dataflow subgraph for a `ir.SymRef` node.""" + + sym_name: str + sym_type: ts.FieldType | ts.ScalarType + + def __init__( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + sym_name: str, + sym_type: ts.FieldType | ts.ScalarType, + ): + super().__init__(sdfg, state) + self.sym_name = sym_name + self.sym_type = sym_type + + def build(self) -> list[SDFGField]: + if isinstance(self.sym_type, ts.FieldType): + # add access node to current state + sym_node = self.head_state.add_access(self.sym_name) + + else: + # scalar symbols are passed to the SDFG as symbols: build tasklet node + # to write the symbol to a scalar access node + assert self.sym_name in self.sdfg.symbols + tasklet_node = self.head_state.add_tasklet( + f"get_{self.sym_name}", + {}, + {"__out"}, + f"__out = {self.sym_name}", + ) + sym_node = self.add_local_storage(self.sym_type, shape=[]) + self.head_state.add_edge( + tasklet_node, + "__out", + sym_node, + None, + dace.Memlet(data=sym_node.data, subset="0"), + ) + + return [(sym_node, self.sym_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py deleted file mode 100644 index 7b4c9c1e63..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin import ( - GTIRPrimitiveTranslator as PrimitiveTranslator, - SDFGField, - SDFGFieldBuilder, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin_field_operator import ( - GTIRBuiltinAsFieldOp as AsFieldOp, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin_select import ( - GTIRBuiltinSelect as Select, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators.gtir_builtin_symbol_ref import ( - GTIRBuiltinSymbolRef as SymbolRef, -) - - -# export short names of translation classes for GTIR builtin functions -__all__ = [ - "PrimitiveTranslator", - "SDFGField", - "SDFGFieldBuilder", - "AsFieldOp", - "Select", - "SymbolRef", -] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py deleted file mode 100644 index 0c887dcd9d..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin.py +++ /dev/null @@ -1,73 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import abc -from dataclasses import dataclass -from typing import Callable, TypeAlias, final - -import dace - -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util -from gt4py.next.type_system import type_specifications as ts - - -# Define aliases for return types -SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] -SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] - - -@dataclass(frozen=True) -class GTIRPrimitiveTranslator(abc.ABC): - sdfg: dace.SDFG - head_state: dace.SDFGState - - @final - def __call__(self) -> list[SDFGField]: - """The callable interface is used to build the dataflow graph. - - It allows to build the dataflow graph inside a given state starting - from the innermost nodes, by propagating the intermediate results - as access nodes to temporary local storage. - """ - return self.build() - - @final - def add_local_storage( - self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] - ) -> dace.nodes.AccessNode: - """Allocates temporary storage to be used in the local scope for intermediate results.""" - if isinstance(data_type, ts.FieldType): - assert len(data_type.dims) == len(shape) - dtype = dace_fieldview_util.as_dace_type(data_type.dtype) - name, _ = self.sdfg.add_array("var", shape, dtype, find_new_name=True, transient=True) - else: - assert len(shape) == 0 - dtype = dace_fieldview_util.as_dace_type(data_type) - name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) - return self.head_state.add_access(name) - - @abc.abstractmethod - def build(self) -> list[SDFGField]: - """Creates the dataflow subgraph representing a given GTIR builtin. - - This method is used by derived classes of `GTIRBuiltinTranslator`, - which build a specialized subgraph for a certain GTIR builtin. - - Returns a list of SDFG nodes and the associated GT4Py data type: - tuple(node, data_type) - - The GT4Py data type is useful in the case of fields, because it provides - information on the field domain (e.g. order of dimensions, types of dimensions). - """ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py deleted file mode 100644 index 572c531a46..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_field_operator.py +++ /dev/null @@ -1,159 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -from typing import TypeAlias - -import dace -import dace.subsets as sbs - -from gt4py.next.common import Connectivity, Dimension -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import ( - gtir_to_tasklet, - utility as dace_fieldview_util, -) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators import ( - gtir_builtin, -) -from gt4py.next.type_system import type_specifications as ts - - -class GTIRBuiltinAsFieldOp(gtir_builtin.GTIRPrimitiveTranslator): - """Generates the dataflow subgraph for the `as_field_op` builtin function.""" - - TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] - - stencil_expr: itir.Lambda - stencil_args: list[gtir_builtin.SDFGFieldBuilder] - field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] - field_type: ts.FieldType - offset_provider: dict[str, Connectivity | Dimension] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - node: itir.FunCall, - stencil_args: list[gtir_builtin.SDFGFieldBuilder], - offset_provider: dict[str, Connectivity | Dimension], - ): - super().__init__(sdfg, state) - self.offset_provider = offset_provider - - assert cpm.is_call_to(node.fun, "as_fieldop") - assert len(node.fun.args) == 2 - stencil_expr, domain_expr = node.fun.args - # expect stencil (represented as a lambda function) as first argument - assert isinstance(stencil_expr, itir.Lambda) - # the domain of the field operator is passed as second argument - assert isinstance(domain_expr, itir.FunCall) - - domain = dace_fieldview_util.get_domain(domain_expr) - # define field domain with all dimensions in alphabetical order - sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) - - # add local storage to compute the field operator over the given domain - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - - self.field_domain = domain - self.field_type = ts.FieldType(sorted_domain_dims, node_type) - self.stencil_expr = stencil_expr - self.stencil_args = stencil_args - - def build(self) -> list[gtir_builtin.SDFGField]: - dimension_index_fmt = "i_{dim}" - # type of variables used for field indexing - index_dtype = dace.int32 - # first visit the list of arguments and build a symbol map - stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in self.stencil_args: - arg_nodes = arg() - assert len(arg_nodes) == 1 - data_node, arg_type = arg_nodes[0] - # require all argument nodes to be data access nodes (no symbols) - assert isinstance(data_node, dace.nodes.AccessNode) - - if isinstance(arg_type, ts.ScalarType): - scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) - stencil_args.append(scalar_arg) - else: - assert isinstance(arg_type, ts.FieldType) - indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { - dim.value: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), - index_dtype, - ) - for dim in self.field_domain.keys() - } - iterator_arg = gtir_to_tasklet.IteratorExpr( - data_node, - [dim.value for dim in arg_type.dims], - indices, - ) - stencil_args.append(iterator_arg) - - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) - assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) - - # retrieve the tasklet node which writes the result - output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src - output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn - - # the last transient node can be deleted - self.head_state.remove_node(output_expr.node) - - # allocate local temporary storage for the result field - field_shape = [ - # diff between upper and lower bound - self.field_domain[dim][1] - self.field_domain[dim][0] - for dim in self.field_type.dims - ] - field_node = self.add_local_storage(self.field_type, field_shape) - - # assume tasklet with single output - output_index = ",".join( - dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims - ) - output_memlet = dace.Memlet(data=field_node.data, subset=output_index) - - # create map range corresponding to the field operator domain - map_ranges = { - dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" - for dim, (lb, ub) in self.field_domain.items() - } - me, mx = self.head_state.add_map("field_op", map_ranges) - - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - self.head_state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) - self.head_state.add_memlet_path( - output_tasklet_node, - mx, - field_node, - src_conn=output_tasklet_connector, - memlet=output_memlet, - ) - - return [(field_node, self.field_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py deleted file mode 100644 index b0786cd6d7..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_select.py +++ /dev/null @@ -1,122 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import dace - -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators import ( - gtir_builtin, -) - - -class GTIRBuiltinSelect(gtir_builtin.GTIRPrimitiveTranslator): - """Generates the dataflow subgraph for the `select` builtin function.""" - - true_br_builder: gtir_builtin.SDFGFieldBuilder - false_br_builder: gtir_builtin.SDFGFieldBuilder - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - dataflow_builder: eve.NodeVisitor, - node: itir.FunCall, - ): - super().__init__(sdfg, state) - - assert cpm.is_call_to(node.fun, "select") - assert len(node.fun.args) == 3 - cond_expr, true_expr, false_expr = node.fun.args - - # expect condition as first argument - cond = dace_fieldview_util.get_symbolic_expr(cond_expr) - - # use current head state to terminate the dataflow, and add a entry state - # to connect the true/false branch states as follows: - # - # ------------ - # === | select | === - # || ------------ || - # \/ \/ - # ------------ ------------- - # | true | | false | - # ------------ ------------- - # || || - # || ------------ || - # ==> | head | <== - # ------------ - # - select_state = sdfg.add_state_before(state, state.label + "_select") - sdfg.remove_edge(sdfg.out_edges(select_state)[0]) - - # expect true branch as second argument - true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) - sdfg.add_edge(true_state, state, dace.InterstateEdge()) - self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) - - # and false branch as third argument - false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=(f"not {cond}"))) - sdfg.add_edge(false_state, state, dace.InterstateEdge()) - self.false_br_builder = dataflow_builder.visit( - false_expr, sdfg=sdfg, head_state=false_state - ) - - def build(self) -> list[gtir_builtin.SDFGField]: - # retrieve true/false states as predecessors of head state - branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) - assert len(branch_states) == 2 - if branch_states[0].label.endswith("_true_branch"): - true_state, false_state = branch_states - else: - false_state, true_state = branch_states - - true_br_args = self.true_br_builder() - false_br_args = self.false_br_builder() - - output_nodes = [] - for true_br, false_br in zip(true_br_args, false_br_args, strict=True): - true_br_node, true_br_type = true_br - assert isinstance(true_br_node, dace.nodes.AccessNode) - false_br_node, false_br_type = false_br - assert isinstance(false_br_node, dace.nodes.AccessNode) - assert true_br_type == false_br_type - array_type = self.sdfg.arrays[true_br_node.data] - access_node = self.add_local_storage(true_br_type, array_type.shape) - output_nodes.append((access_node, true_br_type)) - - data_name = access_node.data - true_br_output_node = true_state.add_access(data_name) - true_state.add_nedge( - true_br_node, - true_br_output_node, - dace.Memlet.from_array( - true_br_output_node.data, true_br_output_node.desc(self.sdfg) - ), - ) - - false_br_output_node = false_state.add_access(data_name) - false_state.add_nedge( - false_br_node, - false_br_output_node, - dace.Memlet.from_array( - false_br_output_node.data, false_br_output_node.desc(self.sdfg) - ), - ) - return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py deleted file mode 100644 index e7ae631cfd..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators/gtir_builtin_symbol_ref.py +++ /dev/null @@ -1,65 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import dace - -from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtin_translators import ( - gtir_builtin, -) -from gt4py.next.type_system import type_specifications as ts - - -class GTIRBuiltinSymbolRef(gtir_builtin.GTIRPrimitiveTranslator): - """Generates the dataflow subgraph for a `itir.SymRef` node.""" - - sym_name: str - sym_type: ts.FieldType | ts.ScalarType - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - sym_name: str, - sym_type: ts.FieldType | ts.ScalarType, - ): - super().__init__(sdfg, state) - self.sym_name = sym_name - self.sym_type = sym_type - - def build(self) -> list[gtir_builtin.SDFGField]: - if isinstance(self.sym_type, ts.FieldType): - # add access node to current state - sym_node = self.head_state.add_access(self.sym_name) - - else: - # scalar symbols are passed to the SDFG as symbols: build tasklet node - # to write the symbol to a scalar access node - assert self.sym_name in self.sdfg.symbols - tasklet_node = self.head_state.add_tasklet( - f"get_{self.sym_name}", - {}, - {"__out"}, - f"__out = {self.sym_name}", - ) - sym_node = self.add_local_storage(self.sym_type, shape=[]) - self.head_state.add_edge( - tasklet_node, - "__out", - sym_node, - None, - dace.Memlet(data=sym_node.data, subset="0"), - ) - - return [(sym_node, self.sym_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 663fdc2513..a0f3f06f40 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -21,7 +21,7 @@ import dace -import gt4py.eve as eve +from gt4py import eve from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -68,7 +68,7 @@ def _make_array_shape_and_strides( the corresponding array shape dimension is set to an integer literal value. Returns: - Two list of symbols, one for the shape and another for the strides of the array. + Two lists of symbols, one for the shape and the other for the strides of the array. """ dtype = dace.int32 neighbor_tables = dace_fieldview_util.filter_connectivities(self.offset_provider) @@ -121,20 +121,22 @@ def _visit_expression( """ Specialized visit method for fieldview expressions. - This method represents the entry point to visit 'Stmt' expressions. + This method represents the entry point to visit `ir.Stmt` expressions. As such, it must preserve the property of single exit state in the SDFG. + Returns a list of array nodes containing the result fields. + TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? """ - expr_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( + field_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( node, sdfg=sdfg, head_state=head_state ) - results = expr_builder() + results = field_builder() - expressions_nodes = [] + field_nodes = [] for node, _ in results: assert isinstance(node, dace.nodes.AccessNode) - expressions_nodes.append(node) + field_nodes.append(node) # sanity check: each statement should preserve the property of single exit state (aka head state), # i.e. eventually only introduce internal branches, and keep the same head state @@ -142,7 +144,7 @@ def _visit_expression( assert len(sink_states) == 1 assert sink_states[0] == head_state - return expressions_nodes + return field_nodes def visit_Program(self, node: itir.Program) -> dace.SDFG: """Translates `ir.Program` to `dace.SDFG`. @@ -166,7 +168,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: temp_symbols |= self._add_storage_for_temporary(decl) # define symbols for shape and offsets of temporary arrays as interstate edge symbols - # TODO(edopao): use new `add_state_after` function in next dace release + # TODO(edopao): use new `add_state_after` function available in next dace release head_state = sdfg.add_state_after(entry_state, "init_temps") sdfg.edges_between(entry_state, head_state)[0].assignments = temp_symbols else: @@ -188,7 +190,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. - The translation of `SetAt` ensures that the result is written back to some global storage. + The translation of `SetAt` ensures that the result is written back to the target external storage. """ expr_nodes = self._visit_expression(stmt.expr, sdfg, state) @@ -230,15 +232,14 @@ def visit_FunCall( ) arg_builders.append(arg_builder) + # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node.fun, "as_fieldop"): return gtir_builtin_translators.AsFieldOp( sdfg, head_state, node, arg_builders, self.offset_provider ) - elif cpm.is_call_to(node.fun, "select"): assert len(arg_builders) == 0 return gtir_builtin_translators.Select(sdfg, head_state, self, node) - else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 5c06694315..ae38dfa5b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -48,7 +48,7 @@ class SymbolExpr: @dataclass(frozen=True) class ValueExpr: - """Result of the computation provided by a tasklet node.""" + """Result of the computation implemented by a tasklet node.""" node: dace.nodes.AccessNode @@ -65,6 +65,7 @@ class IteratorExpr: indices: dict[str, IteratorIndexExpr] +# Define alias for the elements needed to setup input connections to a map scope InputConnection: TypeAlias = tuple[ dace.nodes.AccessNode, sbs.Range, @@ -129,7 +130,12 @@ class IteratorExpr: class LambdaToTasklet(eve.NodeVisitor): - """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + """Translates an `ir.Lambda` expression to a dataflow graph. + + Lambda functions should only be encountered as argument to the `as_field_op` + builtin function, therefore the dataflow graph generated here typically + represents the stencil function of a field operator. + """ sdfg: dace.SDFG state: dace.SDFGState @@ -281,6 +287,12 @@ def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | MemletExpr | SymbolE class PythonCodegen(codegen.TemplatedGenerator): + """Helper class to visit a symbolic expression and translate it to Python code. + + The generated Python code can be use either as the body of a tasklet node or, + as in the case of field domain definitions, for sybolic array shape and map range. + """ + SymRef = as_fmt("{id}") Literal = as_fmt("{value}") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 99c46b2f1b..fbf2f16aa8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -60,7 +60,7 @@ def get_domain( """ Specialized visit method for domain expressions. - Returns a list of dimensions and the corresponding range. + Returns for each domain dimension the corresponding range. """ assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) @@ -82,4 +82,10 @@ def get_domain( def get_symbolic_expr(node: itir.Expr) -> str: + """ + Specialized visit method for symbolic expressions. + + Returns a string containong the corresponding Python code, which as tasklet body + or symbolic array shape. + """ return gtir_to_tasklet.PythonCodegen().visit(node) From 82fdf64c2a2741a7ffad4be6220a39c339dbc219 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 11:09:52 +0200 Subject: [PATCH 067/235] Add wrapper function to build SDFG --- .../runners/dace_fieldview/__init__.py | 10 +++ .../dace_fieldview/gtir_dace_backend.py | 35 ++++++++ .../runners_tests/test_dace_fieldview.py | 85 +++++++------------ 3 files changed, 74 insertions(+), 56 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py index 6c43e2f12a..18a753a17c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -11,3 +11,13 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + + +from gt4py.next.program_processors.runners.dace_fieldview.gtir_dace_backend import ( + build_sdfg_from_gtir, +) + + +__all__ = [ + "build_sdfg_from_gtir", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py new file mode 100644 index 0000000000..c8c798292a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py @@ -0,0 +1,35 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dace + +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_to_sdfg as gtir_dace_translator, +) +from gt4py.next.type_system import type_specifications as ts + + +def build_sdfg_from_gtir( + program: itir.Program, + arg_types: list[ts.DataType], + offset_provider: dict[str, Connectivity | Dimension], +) -> dace.SDFG: + sdfg_genenerator = gtir_dace_translator.GTIRToSDFG(arg_types, offset_provider) + sdfg = sdfg_genenerator.visit(program) + assert isinstance(sdfg, dace.SDFG) + + sdfg.simplify() + return sdfg diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index c5c45fac3d..c0f3960ee8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -19,9 +19,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import ( - GTIRToSDFG as FieldviewGtirToSDFG, -) +from gt4py.next.program_processors.runners import dace_fieldview as dace_backend from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import IDim @@ -72,11 +70,8 @@ def test_gtir_copy(): a = np.random.rand(N) b = np.empty_like(a) - sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} - ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + arg_types = [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, y=b, **FSYMBOLS) assert np.allclose(a, b) @@ -108,11 +103,8 @@ def test_gtir_update(): a = np.random.rand(N) ref = a + 1.0 - sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} - ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + arg_types = [IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, **FSYMBOLS) assert np.allclose(a, ref) @@ -145,11 +137,8 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], offset_provider={} - ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + arg_types = [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, y=b, z=c, **FSYMBOLS) assert np.allclose(c, (a + b)) @@ -181,12 +170,8 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], - offset_provider={}, - ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + arg_types = [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, z=b, **FSYMBOLS) assert np.allclose(b, (a + a)) @@ -223,10 +208,7 @@ def test_gtir_sum3(): b = np.random.rand(N) c = np.random.rand(N) - sdfg_genenerator = FieldviewGtirToSDFG( - [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)], - offset_provider={}, - ) + arg_types = [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] for i, stencil in enumerate([stencil1, stencil2]): testee = itir.Program( @@ -249,8 +231,7 @@ def test_gtir_sum3(): ], ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) d = np.empty_like(a) @@ -312,20 +293,16 @@ def test_gtir_select(): b = np.random.rand(N) c = np.random.rand(N) - sdfg_genenerator = FieldviewGtirToSDFG( - [ - IFTYPE, - IFTYPE, - IFTYPE, - IFTYPE, - ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.FLOAT64), - ts.ScalarType(ts.ScalarKind.INT32), - ], - offset_provider={}, - ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + arg_types = [ + IFTYPE, + IFTYPE, + IFTYPE, + IFTYPE, + ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.FLOAT64), + ts.ScalarType(ts.ScalarKind.INT32), + ] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) for s in [False, True]: d = np.empty_like(a) @@ -386,18 +363,14 @@ def test_gtir_select_nested(): a = np.random.rand(N) - sdfg_genenerator = FieldviewGtirToSDFG( - [ - IFTYPE, - IFTYPE, - ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.INT32), - ], - offset_provider={}, - ) - sdfg = sdfg_genenerator.visit(testee) - assert isinstance(sdfg, dace.SDFG) + arg_types = [ + IFTYPE, + IFTYPE, + ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.BOOL), + ts.ScalarType(ts.ScalarKind.INT32), + ] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) for s1 in [False, True]: for s2 in [False, True]: From 51aaf0f84155ea44c87cacb64eb6a07b9dff35df Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 14:23:09 +0200 Subject: [PATCH 068/235] Add fieldview flavor of all test cases --- .../gtir_builtin_translators.py | 30 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 5 + .../runners/dace_fieldview/gtir_to_tasklet.py | 19 +- .../runners/dace_fieldview/utility.py | 9 + .../runners_tests/test_dace_fieldview.py | 276 +++++++++++++----- 5 files changed, 258 insertions(+), 81 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e71a7e5606..657b2b9c43 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -169,6 +169,8 @@ def build(self) -> list[SDFGField]: output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn # the last transient node can be deleted + # TODO: not needed to store the node `dtype` after type inference is in place + dtype = output_expr.node.desc(self.sdfg).dtype self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field @@ -177,7 +179,11 @@ def build(self) -> list[SDFGField]: self.field_domain[dim][1] - self.field_domain[dim][0] for dim in self.field_type.dims ] - field_node = self.add_local_storage(self.field_type, field_shape) + # TODO: use `self.field_type` without overriding `dtype` when type inference is in place + field_dtype = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) + field_node = self.add_local_storage( + ts.FieldType(self.field_type.dims, field_dtype), field_shape + ) # assume tasklet with single output output_index = ",".join( @@ -192,15 +198,19 @@ def build(self) -> list[SDFGField]: } me, mx = self.head_state.add_map("field_op", map_ranges) - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - self.head_state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + self.head_state.add_nedge(me, output_tasklet_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + self.head_state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, + ) self.head_state.add_memlet_path( output_tasklet_node, mx, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 929cc6acfd..f3527eadd3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -160,6 +160,11 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + if len(node.params) != len(self.param_types): + raise RuntimeError( + "The provided list of parameter types has different length than SDFG parameter list." + ) + sdfg = dace.SDFG(node.id) entry_state = sdfg.add_state("program_entry", is_start_block=True) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 17eced22ad..44a9801ebd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -589,14 +589,23 @@ def visit_Lambda( ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg - output_expr: MemletExpr | ValueExpr = self.visit(node.expr) + output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr) if isinstance(output_expr, ValueExpr): return self.input_connections, output_expr - # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.source.desc(self.sdfg).dtype - tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_connection(output_expr.source, output_expr.subset, tasklet_node, "__inp") + if isinstance(output_expr, MemletExpr): + # special case where the field operator is simply copying data from source to destination node + output_dtype = output_expr.source.desc(self.sdfg).dtype + tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self._add_input_connection( + output_expr.source, output_expr.subset, tasklet_node, "__inp" + ) + else: + # even simpler case, where a constant value is written to destination node + output_dtype = output_expr.dtype + tasklet_node = self.state.add_tasklet( + "write", {}, {"__out"}, f"__out = {output_expr.value}" + ) return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 5701aac395..8777ac3eda 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -40,6 +40,15 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: raise ValueError(f"Scalar type '{type_}' not supported.") +def as_scalar_type(typestr: str) -> ts.ScalarType: + """Obtain GT4Py scalar type from generic numpy string representation.""" + try: + kind = getattr(ts.ScalarKind, typestr.upper()) + except AttributeError as ex: + raise ValueError(f"Data type {typestr} not supported.") from ex + return ts.ScalarType(kind) + + def connectivity_identifier(name: str) -> str: return f"connectivity_{name}" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index f20f9263e7..ee70a8eab1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -43,6 +43,7 @@ CFTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) SIMPLE_MESH: MeshDescriptor = simple_mesh() FSYMBOLS = dict( __w_size_0=N, @@ -110,7 +111,7 @@ def test_gtir_copy(): a = np.random.rand(N) b = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, y=b, **FSYMBOLS) @@ -143,7 +144,7 @@ def test_gtir_update(): a = np.random.rand(N) ref = a + 1.0 - arg_types = [IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, **FSYMBOLS) @@ -177,7 +178,7 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, y=b, z=c, **FSYMBOLS) @@ -210,7 +211,7 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, z=b, **FSYMBOLS) @@ -248,7 +249,7 @@ def test_gtir_sum3(): b = np.random.rand(N) c = np.random.rand(N) - arg_types = [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, IFTYPE, IFTYPE, SIZE_TYPE] for i, stencil in enumerate([stencil1, stencil2]): testee = itir.Program( @@ -340,7 +341,7 @@ def test_gtir_select(): IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.FLOAT64), - ts.ScalarType(ts.ScalarKind.INT32), + SIZE_TYPE, ] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) @@ -408,7 +409,7 @@ def test_gtir_select_nested(): IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.INT32), + SIZE_TYPE, ] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) @@ -427,15 +428,35 @@ def test_gtir_cartesian_shift(): ) # cartesian shift with literal integer offset - stencil1 = im.call( + stencil1_inlined = im.call( im.call("as_fieldop")( im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), domain, ) )("x") + # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value + stencil1_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + domain, + ) + )("x"), + im.call( + im.call("as_fieldop")( + im.lambda_()(DELTA), + domain, + ) + )(), + ) # use dynamic offset retrieved from field - stencil2 = im.call( + stencil2_inlined = im.call( im.call("as_fieldop")( im.lambda_("a", "off")( im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA) @@ -443,9 +464,29 @@ def test_gtir_cartesian_shift(): domain, ) )("x", "x_offset") + # fieldview flavor of same stencil + stencil2_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + domain, + ) + )("x", "x_offset"), + im.call( + im.call("as_fieldop")( + im.lambda_()(DELTA), + domain, + ) + )(), + ) # use the result of an arithmetic field operation as dynamic offset - stencil3 = im.call( + stencil3_inlined = im.call( im.call("as_fieldop")( im.lambda_("a", "off")( im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) @@ -453,6 +494,34 @@ def test_gtir_cartesian_shift(): domain, ) )("x", "x_offset") + # fieldview flavor of same stencil + stencil3_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + domain, + ) + )( + "x", + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.plus(im.deref("it"), 0)), + domain, + ) + )("x_offset"), + ), + im.call( + im.call("as_fieldop")( + im.lambda_()(DELTA), + domain, + ) + )(), + ) a = np.random.rand(N + OFFSET) a_offset = np.full(N, OFFSET, dtype=np.int32) @@ -460,10 +529,19 @@ def test_gtir_cartesian_shift(): IOFFSET_FTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) - arg_types = [IFTYPE, IOFFSET_FTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IOFFSET_FTYPE, IFTYPE, SIZE_TYPE] offset_provider = {"IDim": IDim} - for i, stencil in enumerate([stencil1, stencil2, stencil3]): + for i, stencil in enumerate( + [ + stencil1_inlined, + stencil1_fieldview, + stencil2_inlined, + stencil2_fieldview, + stencil3_inlined, + stencil3_fieldview, + ] + ): testee = itir.Program( id=f"cartesian_shift_{i}", function_definitions=[], @@ -498,8 +576,12 @@ def test_gtir_connectivity_shift(): cell_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Cell.value), 0, "ncells"), ) + temp_cv_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Cell.value), 0, "ncells"), + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) # apply shift 2 times along different dimensions - stencil1 = im.call( + stencil1_inlined = im.call( im.call("as_fieldop")( im.lambda_("it")( im.deref(im.shift("C2V", C2V_neighbor_idx)(im.shift("C2E", C2E_neighbor_idx)("it"))) @@ -507,6 +589,20 @@ def test_gtir_connectivity_shift(): cell_domain, ) )("ve_field") + # fieldview flavor of the same stncil: create an intermediate temporary field + stencil1_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("C2V", C2V_neighbor_idx)("it"))), + cell_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("C2E", C2E_neighbor_idx)("it"))), + temp_cv_domain, + ) + )("ve_field") + ) # multi-dimensional shift in one function call stencil2 = im.call( @@ -528,16 +624,16 @@ def test_gtir_connectivity_shift(): )("ve_field") # again multi-dimensional shift in one function call, but this time with dynamic offset values - stencil3 = im.call( + stencil3_inlined = im.call( im.call("as_fieldop")( im.lambda_("it", "c2e_off", "c2v_off")( im.deref( im.call( im.call("shift")( im.ensure_offset("C2V"), - im.deref("c2v_off"), + im.plus(im.deref("c2v_off"), 0), im.ensure_offset("C2E"), - im.plus(im.deref("c2e_off"), 0), + im.deref("c2e_off"), ) )("it") ) @@ -545,28 +641,59 @@ def test_gtir_connectivity_shift(): cell_domain, ) )("ve_field", "c2e_offset", "c2v_offset") + # fieldview flavor of same stencil with dynamic offset + stencil3_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it", "c2e_off", "c2v_off")( + im.deref( + im.call( + im.call("shift")( + im.ensure_offset("C2V"), + im.deref("c2v_off"), + im.ensure_offset("C2E"), + im.deref("c2e_off"), + ) + )("it") + ) + ), + cell_domain, + ) + )( + "ve_field", + "c2e_offset", + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.plus(im.deref("it"), 0)), + cell_domain, + ) + )("c2v_offset"), + ) - ve = np.random.rand(SIMPLE_MESH.num_vertices, SIMPLE_MESH.num_edges) VE_FTYPE = ts.FieldType(dims=[Vertex, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) arg_types = [ VE_FTYPE, CELL_OFFSET_FTYPE, CELL_OFFSET_FTYPE, CFTYPE, - ts.ScalarType(ts.ScalarKind.INT32), + SIZE_TYPE, + SIZE_TYPE, ] connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, NeighborTable) connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] assert isinstance(connectivity_C2V, NeighborTable) + + ve = np.random.rand(SIMPLE_MESH.num_vertices, SIMPLE_MESH.num_edges) ref = ve[ connectivity_C2V.table[:, C2V_neighbor_idx], connectivity_C2E.table[:, C2E_neighbor_idx] ] - for i, stencil in enumerate([stencil1, stencil2, stencil3]): + for i, stencil in enumerate( + [stencil1_inlined, stencil1_fieldview, stencil2, stencil3_inlined, stencil3_fieldview] + ): testee = itir.Program( id=f"connectivity_shift_2d_{i}", function_definitions=[], @@ -576,6 +703,7 @@ def test_gtir_connectivity_shift(): itir.Sym(id="c2v_offset"), itir.Sym(id="cells"), itir.Sym(id="ncells"), + itir.Sym(id="nvertices"), ], declarations=[], body=[ @@ -616,58 +744,74 @@ def test_gtir_connectivity_shift_chain(): edge_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "nedges") ) - testee = itir.Program( - id="connectivity_shift_chain", - function_definitions=[], - params=[ - itir.Sym(id="edges"), - itir.Sym(id="edges_out"), - itir.Sym(id="nedges"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.deref( - im.shift("E2V", E2V_neighbor_idx)( - im.shift("V2E", V2E_neighbor_idx)("it") - ) - ) - ), - edge_domain, - ) - )("edges"), - domain=edge_domain, - target=itir.SymRef(id="edges_out"), + temp_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nedges") + ) + # double indirection using fieldview representation + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("E2V", E2V_neighbor_idx)("it"))), + edge_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", V2E_neighbor_idx)("it"))), + temp_domain, ) - ], + )("edges") ) + # iterator flavor of same stencil + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.deref(im.shift("E2V", E2V_neighbor_idx)(im.shift("V2E", V2E_neighbor_idx)("it"))) + ), + edge_domain, + ) + )("edges") - e = np.random.rand(SIMPLE_MESH.num_edges) - e_out = np.empty_like(e) - - arg_types = [EFTYPE, EFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + arg_types = [EFTYPE, EFTYPE, SIZE_TYPE] - assert isinstance(sdfg, dace.SDFG) connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, NeighborTable) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, NeighborTable) - sdfg( - edges=e, - edges_out=e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **CSYMBOLS, - __edges_out_size_0=CSYMBOLS["__edges_size_0"], - __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], - ) - assert np.allclose( - e_out, - e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]], - ) + e = np.random.rand(SIMPLE_MESH.num_edges) + ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] + + for stencil in [stencil_fieldview, stencil_inlined]: + testee = itir.Program( + id="connectivity_shift_chain", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="edges_out"), + itir.Sym(id="nedges"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=edge_domain, + target=itir.SymRef(id="edges_out"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + + # new empty output field + e_out = np.empty_like(e) + + sdfg( + edges=e, + edges_out=e_out, + connectivity_E2V=connectivity_E2V.table, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **CSYMBOLS, + __edges_out_size_0=CSYMBOLS["__edges_size_0"], + __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], + ) + assert np.allclose(e_out, ref) From 6ccecf174b6e35e6e5796bc9350566762f2ab124 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 14:27:25 +0200 Subject: [PATCH 069/235] Code changes imported from branch dace-fieldview-shifts --- .../gtir_builtin_translators.py | 30 ++++++++++++------- .../runners/dace_fieldview/gtir_to_sdfg.py | 5 ++++ .../runners/dace_fieldview/gtir_to_tasklet.py | 19 ++++++++---- .../runners/dace_fieldview/utility.py | 9 ++++++ .../runners_tests/test_dace_fieldview.py | 15 +++++----- 5 files changed, 56 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e71a7e5606..657b2b9c43 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -169,6 +169,8 @@ def build(self) -> list[SDFGField]: output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn # the last transient node can be deleted + # TODO: not needed to store the node `dtype` after type inference is in place + dtype = output_expr.node.desc(self.sdfg).dtype self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field @@ -177,7 +179,11 @@ def build(self) -> list[SDFGField]: self.field_domain[dim][1] - self.field_domain[dim][0] for dim in self.field_type.dims ] - field_node = self.add_local_storage(self.field_type, field_shape) + # TODO: use `self.field_type` without overriding `dtype` when type inference is in place + field_dtype = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) + field_node = self.add_local_storage( + ts.FieldType(self.field_type.dims, field_dtype), field_shape + ) # assume tasklet with single output output_index = ",".join( @@ -192,15 +198,19 @@ def build(self) -> list[SDFGField]: } me, mx = self.head_state.add_map("field_op", map_ranges) - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - self.head_state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + self.head_state.add_nedge(me, output_tasklet_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + self.head_state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, + ) self.head_state.add_memlet_path( output_tasklet_node, mx, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index a0f3f06f40..43c4fce1a3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -158,6 +158,11 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + if len(node.params) != len(self.param_types): + raise RuntimeError( + "The provided list of parameter types has different length than SDFG parameter list." + ) + sdfg = dace.SDFG(node.id) entry_state = sdfg.add_state("program_entry", is_start_block=True) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index ae38dfa5b0..6758c8b497 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -266,14 +266,23 @@ def visit_Lambda( ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg - output_expr: MemletExpr | ValueExpr = self.visit(node.expr) + output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr) if isinstance(output_expr, ValueExpr): return self.input_connections, output_expr - # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.source.desc(self.sdfg).dtype - tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_connection(output_expr.source, output_expr.subset, tasklet_node, "__inp") + if isinstance(output_expr, MemletExpr): + # special case where the field operator is simply copying data from source to destination node + output_dtype = output_expr.source.desc(self.sdfg).dtype + tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self._add_input_connection( + output_expr.source, output_expr.subset, tasklet_node, "__inp" + ) + else: + # even simpler case, where a constant value is written to destination node + output_dtype = output_expr.dtype + tasklet_node = self.state.add_tasklet( + "write", {}, {"__out"}, f"__out = {output_expr.value}" + ) return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index fbf2f16aa8..366b45e199 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -40,6 +40,15 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: raise ValueError(f"Scalar type '{type_}' not supported.") +def as_scalar_type(typestr: str) -> ts.ScalarType: + """Obtain GT4Py scalar type from generic numpy string representation.""" + try: + kind = getattr(ts.ScalarKind, typestr.upper()) + except AttributeError as ex: + raise ValueError(f"Data type {typestr} not supported.") from ex + return ts.ScalarType(kind) + + def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: """ Filter offset providers of type `Connectivity`. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index c0f3960ee8..3344dc5bf5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -31,6 +31,7 @@ N = 10 IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, __w_stride_0=1, @@ -70,7 +71,7 @@ def test_gtir_copy(): a = np.random.rand(N) b = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, y=b, **FSYMBOLS) @@ -103,7 +104,7 @@ def test_gtir_update(): a = np.random.rand(N) ref = a + 1.0 - arg_types = [IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, **FSYMBOLS) @@ -137,7 +138,7 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, y=b, z=c, **FSYMBOLS) @@ -170,7 +171,7 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, SIZE_TYPE] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) sdfg(x=a, z=b, **FSYMBOLS) @@ -208,7 +209,7 @@ def test_gtir_sum3(): b = np.random.rand(N) c = np.random.rand(N) - arg_types = [IFTYPE, IFTYPE, IFTYPE, IFTYPE, ts.ScalarType(ts.ScalarKind.INT32)] + arg_types = [IFTYPE, IFTYPE, IFTYPE, IFTYPE, SIZE_TYPE] for i, stencil in enumerate([stencil1, stencil2]): testee = itir.Program( @@ -300,7 +301,7 @@ def test_gtir_select(): IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.FLOAT64), - ts.ScalarType(ts.ScalarKind.INT32), + SIZE_TYPE, ] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) @@ -368,7 +369,7 @@ def test_gtir_select_nested(): IFTYPE, ts.ScalarType(ts.ScalarKind.BOOL), ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.INT32), + SIZE_TYPE, ] sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) From e66b9605c38f3478465e74b978733d3eb23c7255 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 17 May 2024 15:13:43 +0200 Subject: [PATCH 070/235] Code comments updated --- .../runners/dace_fieldview/gtir_to_tasklet.py | 114 +++++++++++------- 1 file changed, 70 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 44a9801ebd..4c0cfd403f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -187,11 +187,12 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: if isinstance(it, IteratorExpr): if all(isinstance(index, SymbolExpr) for index in it.indices.values()): - # use direct field access through memlet subset + # when all indices are symblic expressions, we can perform direct field access through a memlet data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] return MemletExpr(it.field, data_index) else: + # we use a tasklet to perform dereferencing of a generic iterator assert all(dim in it.indices.keys() for dim in it.dimensions) field_indices = [(dim, it.indices[dim]) for dim in it.dimensions] index_connectors = [ @@ -247,12 +248,17 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: def _split_shift_args( self, args: list[itir.Expr] ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: + """ + Splits the arguments to `shift` builtin function as pairs, each pair containing + the offset provider and the offset value in one dimension. + """ pairs = [args[i : i + 2] for i in range(0, len(args), 2)] assert len(pairs) >= 1 assert all(len(pair) == 2 for pair in pairs) return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None def _make_shift_for_rest(self, rest: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: + """Transforms a multi-dimensional shift into recursive shift calls, each in a single dimension.""" return itir.FunCall( fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator], @@ -261,12 +267,13 @@ def _make_shift_for_rest(self, rest: list[itir.Expr], iterator: itir.Expr) -> it def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: Dimension, offset_expr: IteratorIndexExpr ) -> IteratorExpr: - """Implements cartesian offset along one dimension.""" + """Implements cartesian shift along one dimension.""" assert offset_dim.value in it.dimensions new_index: SymbolExpr | ValueExpr assert offset_dim.value in it.indices index_expr = it.indices[offset_dim.value] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): + # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr(index_expr.value + offset_expr.value, index_expr.dtype) else: # the offset needs to be calculate by means of a tasklet @@ -319,6 +326,7 @@ def _make_cartesian_shift( new_index = self._get_tasklet_result(dtype, dynamic_offset_tasklet, new_index_connector) + # a new iterator with a shifted index along one dimension return IteratorExpr( it.field, it.dimensions, @@ -328,6 +336,50 @@ def _make_cartesian_shift( }, ) + def _make_dynamic_neighbor_offset( + self, + offset_expr: MemletExpr | ValueExpr, + offset_table_node: dace.nodes.AccessNode, + origin_index: SymbolExpr, + ) -> ValueExpr: + """ + Implements access to neighbor connectivity table by means of a tasklet node. + + It requires a dynamic offset value, either obtained from a field (`MemletExpr`) + or computed byanother tasklet (`ValueExpr`). + """ + new_index_connector = "neighbor_index" + tasklet_node = self.state.add_tasklet( + "dynamic_neighbor_offset", + {"table", "offset"}, + {new_index_connector}, + f"{new_index_connector} = table[{origin_index.value}, offset]", + ) + self._add_input_connection( + offset_table_node, + sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + tasklet_node, + "table", + ) + if isinstance(offset_expr, MemletExpr): + self._add_input_connection( + offset_expr.source, + offset_expr.subset, + tasklet_node, + "offset", + ) + else: + self.state.add_edge( + offset_expr.node, + None, + tasklet_node, + "offset", + dace.Memlet(data=offset_expr.node.data, subset="0"), + ) + + dtype = offset_table_node.desc(self.sdfg).dtype + return self._get_tasklet_result(dtype, tasklet_node, new_index_connector) + def _make_unstructured_shift( self, it: IteratorExpr, @@ -335,18 +387,23 @@ def _make_unstructured_shift( offset_table_node: dace.nodes.AccessNode, offset_expr: IteratorIndexExpr, ) -> IteratorExpr: - # shift in unstructured domain by means of a neighbor table + """Implements shift in unstructured domain by means of a neighbor table.""" neighbor_dim = connectivity.neighbor_axis.value assert neighbor_dim in it.dimensions origin_dim = connectivity.origin_axis.value if origin_dim in it.indices: + # this is the regular case, where the index in the origin dimension is known origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) neighbor_expr = it.indices.get(neighbor_dim, None) if neighbor_expr is not None: + # This branch should be executed for chained shift, like `as_fieldop(λ(it) → ·⟪E2Vₒ, 1ₒ⟫(⟪V2Eₒ, 2ₒ⟫(it)))(edges)` + # More specifically, here we are visiting the E2V shift of this example. We have already built the tasklet node to perform + # V2E neighbor table access but we are missing the value in the origin dimension (the `Vertex` dimension in this example). + # TODO: This branch can be deleted in pure fieldview GTIR assert isinstance(neighbor_expr, ValueExpr) - # retrieve the tasklet that perform the neighbor table access + # retrieve the tasklet that performs the neighbor table access neighbor_tasklet_node = self.state.in_edges(neighbor_expr.node)[0].src if isinstance(offset_expr, SymbolExpr): # use memlet to retrieve the neighbor index and pass it to the index connector of tasklet for neighbor access @@ -392,6 +449,11 @@ def _make_unstructured_shift( shifted_indices = it.indices | {neighbor_dim: dynamic_offset_value} else: + # Here the index in the origin dimension is not known: this case should only be encountered with chianed indirection, + # for example `as_fieldop(λ(it) → ·⟪E2Vₒ, 1ₒ⟫(⟪V2Eₒ, 2ₒ⟫(it)))(edges)` + # More precisely, we are visiting the shift expression for V2E offset: we build a tasklet to compute the neighbor index + # but we leave `origin_index_connector` pending (for `Vertex` dimension in this example). + # TODO: This branch can be deleted in pure fieldview GTIR origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) neighbor_index_connector = INDEX_CONNECTOR_FMT.format(dim=neighbor_dim) if isinstance(offset_expr, SymbolExpr): @@ -443,53 +505,17 @@ def _make_unstructured_shift( shifted_indices, ) - def _make_dynamic_neighbor_offset( - self, - offset_expr: MemletExpr | ValueExpr, - offset_table_node: dace.nodes.AccessNode, - origin_index: SymbolExpr, - ) -> ValueExpr: - new_index_connector = "neighbor_index" - tasklet_node = self.state.add_tasklet( - "dynamic_neighbor_offset", - {"table", "offset"}, - {new_index_connector}, - f"{new_index_connector} = table[{origin_index.value}, offset]", - ) - self._add_input_connection( - offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), - tasklet_node, - "table", - ) - if isinstance(offset_expr, MemletExpr): - self._add_input_connection( - offset_expr.source, - offset_expr.subset, - tasklet_node, - "offset", - ) - else: - self.state.add_edge( - offset_expr.node, - None, - tasklet_node, - "offset", - dace.Memlet(data=offset_expr.node.data, subset="0"), - ) - - dtype = offset_table_node.desc(self.sdfg).dtype - return self._get_tasklet_result(dtype, tasklet_node, new_index_connector) - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun assert isinstance(shift_node, itir.FunCall) - # the iterator to be shifted is the argument to the function node + # here we check the arguments to the `shift` builtin function: the offset provider and the offset value head, tail = self._split_shift_args(shift_node.args) if tail: + # we visit a multi-dimensional shift as recursive shift function calls, each returning a new iterator it = self.visit(self._make_shift_for_rest(tail, node.args[0])) else: + # the iterator to be shifted is the argument to the function node it = self.visit(node.args[0]) assert isinstance(it, IteratorExpr) @@ -510,7 +536,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: if isinstance(offset_provider, Dimension): return self._make_cartesian_shift(it, offset_provider, offset_expr) else: - # initially, the storage for the connectivty tables is created as transient + # initially, the storage for the connectivty tables is created as transient; # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller offset_table = dace_fieldview_util.connectivity_identifier(offset) From 4c190bdf391083be1e74e288d009a88f07f45a50 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 21 May 2024 08:24:33 +0200 Subject: [PATCH 071/235] Remove support for inlined chained shift --- .../runners/dace_fieldview/gtir_to_tasklet.py | 126 +++--------------- .../runners_tests/test_dace_fieldview.py | 104 +++++++-------- 2 files changed, 64 insertions(+), 166 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 4c0cfd403f..4da021a6ad 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -390,120 +390,30 @@ def _make_unstructured_shift( """Implements shift in unstructured domain by means of a neighbor table.""" neighbor_dim = connectivity.neighbor_axis.value assert neighbor_dim in it.dimensions + assert neighbor_dim not in it.indices origin_dim = connectivity.origin_axis.value - if origin_dim in it.indices: - # this is the regular case, where the index in the origin dimension is known - origin_index = it.indices[origin_dim] - assert isinstance(origin_index, SymbolExpr) - neighbor_expr = it.indices.get(neighbor_dim, None) - if neighbor_expr is not None: - # This branch should be executed for chained shift, like `as_fieldop(λ(it) → ·⟪E2Vₒ, 1ₒ⟫(⟪V2Eₒ, 2ₒ⟫(it)))(edges)` - # More specifically, here we are visiting the E2V shift of this example. We have already built the tasklet node to perform - # V2E neighbor table access but we are missing the value in the origin dimension (the `Vertex` dimension in this example). - # TODO: This branch can be deleted in pure fieldview GTIR - assert isinstance(neighbor_expr, ValueExpr) - # retrieve the tasklet that performs the neighbor table access - neighbor_tasklet_node = self.state.in_edges(neighbor_expr.node)[0].src - if isinstance(offset_expr, SymbolExpr): - # use memlet to retrieve the neighbor index and pass it to the index connector of tasklet for neighbor access - self._add_input_connection( - offset_table_node, - sbs.Indices([origin_index.value, offset_expr.value]), - neighbor_tasklet_node, - INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), - ) - else: - # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node - dynamic_offset_value = self._make_dynamic_neighbor_offset( - offset_expr, offset_table_node, origin_index - ) - - # write result to the index connector of tasklet for neighbor access - self.state.add_edge( - dynamic_offset_value.node, - None, - neighbor_tasklet_node, - INDEX_CONNECTOR_FMT.format(dim=neighbor_dim), - memlet=dace.Memlet(data=dynamic_offset_value.node.data, subset="0"), - ) - - shifted_indices = { - dim: index for dim, index in it.indices.items() if dim != neighbor_dim - } | {origin_dim: it.indices[neighbor_dim]} - - elif isinstance(offset_expr, SymbolExpr): - # use memlet to retrieve the neighbor index - shifted_indices = it.indices | { - neighbor_dim: MemletExpr( - offset_table_node, - sbs.Indices([origin_index.value, offset_expr.value]), - ) - } - else: - # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node - dynamic_offset_value = self._make_dynamic_neighbor_offset( - offset_expr, offset_table_node, origin_index + assert origin_dim in it.indices + origin_index = it.indices[origin_dim] + assert isinstance(origin_index, SymbolExpr) + + if isinstance(offset_expr, SymbolExpr): + # use memlet to retrieve the neighbor index + shifted_indices = it.indices | { + neighbor_dim: MemletExpr( + offset_table_node, + sbs.Indices([origin_index.value, offset_expr.value]), ) - - shifted_indices = it.indices | {neighbor_dim: dynamic_offset_value} - + } else: - # Here the index in the origin dimension is not known: this case should only be encountered with chianed indirection, - # for example `as_fieldop(λ(it) → ·⟪E2Vₒ, 1ₒ⟫(⟪V2Eₒ, 2ₒ⟫(it)))(edges)` - # More precisely, we are visiting the shift expression for V2E offset: we build a tasklet to compute the neighbor index - # but we leave `origin_index_connector` pending (for `Vertex` dimension in this example). - # TODO: This branch can be deleted in pure fieldview GTIR - origin_index_connector = INDEX_CONNECTOR_FMT.format(dim=origin_dim) - neighbor_index_connector = INDEX_CONNECTOR_FMT.format(dim=neighbor_dim) - if isinstance(offset_expr, SymbolExpr): - tasklet_node = self.state.add_tasklet( - "shift", - {"table", origin_index_connector}, - {neighbor_index_connector}, - f"{neighbor_index_connector} = table[{origin_index_connector}, {offset_expr.value}]", - ) - else: - tasklet_node = self.state.add_tasklet( - "shift", - {"table", origin_index_connector, "offset"}, - {neighbor_index_connector}, - f"{neighbor_index_connector} = table[{origin_index_connector}, offset]", - ) - if isinstance(offset_expr, MemletExpr): - self._add_input_connection( - offset_expr.source, - offset_expr.subset, - tasklet_node, - "offset", - ) - else: - self.state.add_edge( - offset_expr.node, - None, - tasklet_node, - "offset", - dace.Memlet(data=offset_expr.node.data, subset="0"), - ) - table_desc = offset_table_node.desc(self.sdfg) - neighbor_expr = self._get_tasklet_result( - table_desc.dtype, - tasklet_node, - neighbor_index_connector, - ) - self._add_input_connection( - offset_table_node, - sbs.Range.from_array(table_desc), - tasklet_node, - "table", + # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node + dynamic_offset_value = self._make_dynamic_neighbor_offset( + offset_expr, offset_table_node, origin_index ) - shifted_indices = it.indices | {origin_dim: neighbor_expr} - return IteratorExpr( - it.field, - [origin_dim if neighbor_expr and dim == neighbor_dim else dim for dim in it.dimensions], - shifted_indices, - ) + shifted_indices = it.indices | {neighbor_dim: dynamic_offset_value} + + return IteratorExpr(it.field, it.dimensions, shifted_indices) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 9897569e97..7c1cc5dd7f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -744,34 +744,42 @@ def test_gtir_connectivity_shift_chain(): edge_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "nedges") ) - temp_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nedges") + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices") ) - # double indirection using fieldview representation - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("E2V", E2V_neighbor_idx)("it"))), - edge_domain, - ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", V2E_neighbor_idx)("it"))), - temp_domain, + testee = itir.Program( + id="connectivity_shift_chain", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="edges_out"), + itir.Sym(id="nedges"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("E2V", E2V_neighbor_idx)("it"))), + edge_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", V2E_neighbor_idx)("it"))), + vertex_domain, + ) + )("edges") + ), + domain=edge_domain, + target=itir.SymRef(id="edges_out"), ) - )("edges") + ], ) - # iterator flavor of same stencil - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.deref(im.shift("E2V", E2V_neighbor_idx)(im.shift("V2E", V2E_neighbor_idx)("it"))) - ), - edge_domain, - ) - )("edges") - arg_types = [EFTYPE, EFTYPE, SIZE_TYPE] + arg_types = [EFTYPE, EFTYPE, SIZE_TYPE, SIZE_TYPE] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, NeighborTable) @@ -781,37 +789,17 @@ def test_gtir_connectivity_shift_chain(): e = np.random.rand(SIMPLE_MESH.num_edges) ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] - for stencil in [stencil_fieldview, stencil_inlined]: - testee = itir.Program( - id="connectivity_shift_chain", - function_definitions=[], - params=[ - itir.Sym(id="edges"), - itir.Sym(id="edges_out"), - itir.Sym(id="nedges"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=stencil, - domain=edge_domain, - target=itir.SymRef(id="edges_out"), - ) - ], - ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) - - # new empty output field - e_out = np.empty_like(e) - - sdfg( - edges=e, - edges_out=e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **CSYMBOLS, - __edges_out_size_0=CSYMBOLS["__edges_size_0"], - __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], - ) - assert np.allclose(e_out, ref) + # new empty output field + e_out = np.empty_like(e) + + sdfg( + edges=e, + edges_out=e_out, + connectivity_E2V=connectivity_E2V.table, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **CSYMBOLS, + __edges_out_size_0=CSYMBOLS["__edges_size_0"], + __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], + ) + assert np.allclose(e_out, ref) From 6052de2ed5a9a2636a6894c06631812efa83aec7 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 21 May 2024 14:49:06 +0200 Subject: [PATCH 072/235] Add support for neighbors builtin --- .../gtir_builtin_translators.py | 36 ++-- .../runners/dace_fieldview/gtir_to_sdfg.py | 3 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 170 ++++++++++++++++-- .../runners/dace_fieldview/utility.py | 7 +- .../runners_tests/test_dace_fieldview.py | 61 +++++++ 5 files changed, 243 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 657b2b9c43..46ece624a2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -91,7 +91,7 @@ class AsFieldOp(PrimitiveTranslator): stencil_expr: itir.Lambda stencil_args: list[SDFGFieldBuilder] - field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] + field_domain: dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] field_type: ts.FieldType offset_provider: dict[str, Connectivity | Dimension] @@ -116,7 +116,9 @@ def __init__( domain = dace_fieldview_util.get_domain(domain_expr) # define field domain with all dimensions in alphabetical order - sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) + sorted_domain_dims = sorted( + [Dimension(dim) for dim in domain.keys()], key=lambda x: x.value + ) # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type @@ -146,8 +148,8 @@ def build(self) -> list[SDFGField]: else: assert isinstance(arg_type, ts.FieldType) indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { - dim.value: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), + dim: gtir_to_tasklet.SymbolExpr( + dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim)), index_dtype, ) for dim in self.field_domain.keys() @@ -170,30 +172,34 @@ def build(self) -> list[SDFGField]: # the last transient node can be deleted # TODO: not needed to store the node `dtype` after type inference is in place - dtype = output_expr.node.desc(self.sdfg).dtype + output_desc = output_expr.node.desc(self.sdfg) self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field + field_dims = self.field_type.dims.copy() field_shape = [ # diff between upper and lower bound - self.field_domain[dim][1] - self.field_domain[dim][0] + self.field_domain[dim.value][1] - self.field_domain[dim.value][0] for dim in self.field_type.dims ] + if isinstance(output_desc, dace.data.Array): + # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) + field_dims.extend(Dimension(f"local_dim{i}") for i in range(len(output_desc.shape))) + field_shape.extend(output_desc.shape) + # TODO: use `self.field_type` without overriding `dtype` when type inference is in place - field_dtype = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) - field_node = self.add_local_storage( - ts.FieldType(self.field_type.dims, field_dtype), field_shape - ) + field_dtype = dace_fieldview_util.as_scalar_type(str(output_desc.dtype.as_numpy_dtype())) + field_node = self.add_local_storage(ts.FieldType(field_dims, field_dtype), field_shape) # assume tasklet with single output - output_index = ",".join( - dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims - ) - output_memlet = dace.Memlet(data=field_node.data, subset=output_index) + output_subset = [dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims] + if isinstance(output_desc, dace.data.Array): + output_subset.extend(f"0:{size}" for size in output_desc.shape) + output_memlet = dace.Memlet(data=field_node.data, subset=",".join(output_subset)) # create map range corresponding to the field operator domain map_ranges = { - dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" + dimension_index_fmt.format(dim=dim): f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items() } me, mx = self.head_state.add_map("field_op", map_ranges) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index f3527eadd3..da2ed8a4b3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -233,7 +233,8 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) if isinstance(target_symbol_type, ts.FieldType): subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_symbol_type.dims + f"{domain[dim.value][0]}:{domain[dim.value][1]}" + for dim in target_symbol_type.dims ) else: assert len(domain) == 0 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 4da021a6ad..19f0ab89c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -70,7 +70,7 @@ class IteratorExpr: InputConnection: TypeAlias = tuple[ dace.nodes.AccessNode, sbs.Range, - dace.nodes.Tasklet, + dace.nodes.NestedSDFG | dace.nodes.Tasklet, str, ] @@ -132,6 +132,56 @@ class IteratorExpr: } +def build_neighbors_sdfg( + field_dtype: dace.typeclass, + field_shape: tuple[int], + neighbors_shape: tuple[int], + index_dtype: dace.typeclass, +) -> tuple[dace.SDFG, str, str, str]: + assert len(field_shape) == len(neighbors_shape) + + sdfg = dace.SDFG("neighbors") + state = sdfg.add_state() + me, mx = state.add_map( + "neighbors", + {f"__idx_{i}": sbs.Range([(0, size - 1, 1)]) for i, size in enumerate(neighbors_shape)}, + ) + neighbor_index = ",".join(f"__idx_{i}" for i in range(len(neighbors_shape))) + + field_name, field_array = sdfg.add_array("field", field_shape, field_dtype) + index_name, _ = sdfg.add_array("indexes", neighbors_shape, index_dtype) + var_name, _ = sdfg.add_array("values", neighbors_shape, field_dtype) + tasklet_node = state.add_tasklet( + "gather_neighbors", + {"__field", "__index"}, + {"__val"}, + "__val = __field[__index]", + ) + state.add_memlet_path( + state.add_access(field_name), + me, + tasklet_node, + dst_conn="__field", + memlet=dace.Memlet.from_array(field_name, field_array), + ) + state.add_memlet_path( + state.add_access(index_name), + me, + tasklet_node, + dst_conn="__index", + memlet=dace.Memlet(data=index_name, subset=neighbor_index), + ) + state.add_memlet_path( + tasklet_node, + mx, + state.add_access(var_name), + src_conn="__val", + memlet=dace.Memlet(data=var_name, subset=neighbor_index), + ) + + return sdfg, field_name, index_name, var_name + + class LambdaToTasklet(eve.NodeVisitor): """Translates an `ir.Lambda` expression to a dataflow graph. @@ -168,18 +218,29 @@ def _add_input_connection( self.input_connections.append((src, subset, dst, dst_connector)) def _get_tasklet_result( - self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str + self, + dtype: dace.typeclass, + src_node: dace.nodes.NestedSDFG | dace.nodes.Tasklet, + src_connector: str, + shape: Optional[tuple[int,]] = None, ) -> ValueExpr: - scalar_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) - scalar_node = self.state.add_access(scalar_name) + if shape: + var_name, _ = self.sdfg.add_array( + "var", shape, dtype, transient=True, find_new_name=True + ) + subset = ",".join(f"0:{size}" for size in shape) + else: + var_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) + subset = "0" + var_node = self.state.add_access(var_name) self.state.add_edge( src_node, src_connector, - scalar_node, + var_node, None, - dace.Memlet(data=scalar_node.data, subset="0"), + dace.Memlet(data=var_node.data, subset=subset), ) - return ValueExpr(scalar_node) + return ValueExpr(var_node) def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 @@ -245,6 +306,82 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert isinstance(it, MemletExpr) return it + def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: + assert len(node.args) == 2 + + assert isinstance(node.args[0], itir.OffsetLiteral) + offset = node.args[0].value + assert isinstance(offset, str) + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, Connectivity) + + it = self.visit(node.args[1]) + assert isinstance(it, IteratorExpr) + assert offset_provider.neighbor_axis.value in it.dimensions + assert offset_provider.origin_axis.value in it.indices + origin_index = it.indices[offset_provider.origin_axis.value] + assert isinstance(origin_index, SymbolExpr) + assert offset_provider.origin_axis.value not in it.dimensions + assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) + + field_desc = it.field.desc(self.sdfg) + offset_table = dace_fieldview_util.connectivity_identifier(offset) + # initially, the storage for the connectivty tables is created as transient; + # when the tables are used, the storage is changed to non-transient, + # so the corresponding arrays are supposed to be allocated by the SDFG caller + self.sdfg.arrays[offset_table].transient = False + offset_table_node = self.state.add_access(offset_table) + + field_array_shape = tuple( + shape + for dim, shape in zip(it.dimensions, field_desc.shape, strict=True) + if dim == offset_provider.neighbor_axis.value + ) + assert len(field_array_shape) == 1 + + # we build a nested SDFG to gather all neighbors for each point in the field domain + # it can be seen as a library node + nsdfg, field_name, index_name, output_name = build_neighbors_sdfg( + field_desc.dtype, + field_array_shape, + (offset_provider.max_neighbors,), + self.sdfg.arrays[offset_table].dtype, + ) + + neighbors_node = self.state.add_nested_sdfg( + nsdfg, self.sdfg, {field_name, index_name}, {output_name} + ) + + self._add_input_connection( + it.field, + sbs.Range( + [ + (0, size - 1, 1) + if dim == offset_provider.neighbor_axis.value + else (it.indices[dim].value, it.indices[dim].value, 1) # type: ignore[union-attr] + for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + ] + ), + neighbors_node, + field_name, + ) + + self._add_input_connection( + offset_table_node, + sbs.Range( + [ + (origin_index.value, origin_index.value, 1), + (0, offset_provider.max_neighbors - 1, 1), + ] + ), + neighbors_node, + index_name, + ) + + return self._get_tasklet_result( + field_desc.dtype, neighbors_node, output_name, shape=(offset_provider.max_neighbors,) + ) + def _split_shift_args( self, args: list[itir.Expr] ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: @@ -461,12 +598,22 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_call_to(node, "neighbors"): + return self._visit_neighbors(node) + elif cpm.is_call_to(node.fun, "shift"): return self._visit_shift(node) else: assert isinstance(node.fun, itir.SymRef) + # create a tasklet node implementing the builtin function + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin_name] + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + node_internals = [] node_connections: dict[str, MemletExpr | ValueExpr] = {} for i, arg in enumerate(node.args): @@ -481,13 +628,8 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value # use the argument value without adding any connector node_internals.append(arg_expr.value) - # create a tasklet node implementing the builtin function - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[builtin_name] - code = fmt.format(*node_internals) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") + # use tasklet connectors as expression arguments + code = fmt.format(*node_internals) out_connector = "result" tasklet_node = self.state.add_tasklet( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 8777ac3eda..f9919e230e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,7 +16,7 @@ import dace -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import Connectivity from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet @@ -69,7 +69,7 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne def get_domain( node: itir.Expr, -) -> dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: +) -> dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Specialized visit method for domain expressions. @@ -83,13 +83,12 @@ def get_domain( assert len(named_range.args) == 3 axis = named_range.args[0] assert isinstance(axis, itir.AxisLiteral) - dim = Dimension(axis.value) bounds = [] for arg in named_range.args[1:3]: sym_str = get_symbolic_expr(arg) sym_val = dace.symbolic.SymExpr(sym_str) bounds.append(sym_val) - domain[dim] = (bounds[0], bounds[1]) + domain[axis.value] = (bounds[0], bounds[1]) return domain diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 7c1cc5dd7f..442dbd269a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -27,6 +27,7 @@ Edge, IDim, MeshDescriptor, + V2EDim, Vertex, simple_mesh, ) @@ -803,3 +804,63 @@ def test_gtir_connectivity_shift_chain(): __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], ) assert np.allclose(e_out, ref) + + +def test_gtir_neighbors(): + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + v2e_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + im.call("named_range")( + itir.AxisLiteral(value=V2EDim.value), + 0, + SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + ), + ) + testee = itir.Program( + id=f"neighbors", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="v2e_field"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + domain=v2e_domain, + target=itir.SymRef(id="v2e_field"), + ) + ], + ) + + V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + + arg_types = [EFTYPE, V2E_FTYPE, SIZE_TYPE] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, NeighborTable) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v2e_field = np.empty([SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors], dtype=e.dtype) + + sdfg( + edges=e, + v2e_field=v2e_field, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **CSYMBOLS, + __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v2e_field, e[connectivity_V2E.table]) From 73008640bf5e8a4f314046370996cb065eebddc9 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 22 May 2024 11:43:43 +0200 Subject: [PATCH 073/235] Add support for reduce builtin --- .../gtir_builtin_translators.py | 138 ++++++++++++++---- .../runners/dace_fieldview/gtir_to_tasklet.py | 45 +++--- .../runners_tests/test_dace_fieldview.py | 64 ++++++++ 3 files changed, 192 insertions(+), 55 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 46ece624a2..e24b043110 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -35,6 +35,8 @@ SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] +DIMENSION_INDEX_FMT = "i_{dim}" + @dataclass(frozen=True) class PrimitiveTranslator(abc.ABC): @@ -129,38 +131,73 @@ def __init__( self.stencil_expr = stencil_expr self.stencil_args = stencil_args - def build(self) -> list[SDFGField]: - dimension_index_fmt = "i_{dim}" - # type of variables used for field indexing - index_dtype = dace.int32 - # first visit the list of arguments and build a symbol map - stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in self.stencil_args: - arg_nodes = arg() - assert len(arg_nodes) == 1 - data_node, arg_type = arg_nodes[0] - # require all argument nodes to be data access nodes (no symbols) - assert isinstance(data_node, dace.nodes.AccessNode) + def build_reduce_node( + self, stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] + ) -> list[SDFGField]: + """Use dace library node for reduction.""" + node = self.stencil_expr.expr + assert isinstance(node, itir.FunCall) + + assert len(node.args) == 3 + op_name = node.args[0] + assert isinstance(op_name, itir.SymRef) + reduce_identity = node.args[1] + assert isinstance(reduce_identity, itir.Literal) + field_arg = node.args[2] + assert isinstance(field_arg, itir.FunCall) + + assert len(self.stencil_expr.params) == 1 + it_param = str(self.stencil_expr.params[0].id) + assert len(stencil_args) == 1 + it = stencil_args[0] + assert isinstance(it, gtir_to_tasklet.IteratorExpr) + + taskgen = gtir_to_tasklet.LambdaToTasklet( + self.sdfg, self.head_state, self.offset_provider, symbol_map={it_param: it} + ) + field_expr: gtir_to_tasklet.MemletExpr | gtir_to_tasklet.ValueExpr = taskgen.visit( + field_arg + ) - if isinstance(arg_type, ts.ScalarType): - scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) - stencil_args.append(scalar_arg) - else: - assert isinstance(arg_type, ts.FieldType) - indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { - dim: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim)), - index_dtype, - ) - for dim in self.field_domain.keys() - } - iterator_arg = gtir_to_tasklet.IteratorExpr( - data_node, - [dim.value for dim in arg_type.dims], - indices, - ) - stencil_args.append(iterator_arg) + input_desc = field_expr.node.desc(self.sdfg) + assert isinstance(input_desc, dace.data.Array) + + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + + if len(input_desc.shape) > 1: + ndims = len(input_desc.shape) - 1 + assert ndims == len(it.dimensions) + reduce_axes = [ndims] + else: + ndims = 0 + reduce_axes = None + + reduce_wcr = "lambda x, y: " + gtir_to_tasklet.MATH_BUILTINS_MAPPING[str(op_name)].format( + "x", "y" + ) + reduce_node = self.head_state.add_reduce(reduce_wcr, reduce_axes, reduce_identity) + self.head_state.add_nedge( + field_expr.node, reduce_node, dace.Memlet.from_array(field_expr.node.data, input_desc) + ) + + field_type: ts.FieldType | ts.ScalarType + if ndims > 0: + field_type = ts.FieldType([Dimension(dim) for dim in it.dimensions], node_type) + output_node = self.add_local_storage(field_type, input_desc.shape[0:ndims]) + output_memlet = dace.Memlet.from_array(output_node.data, output_node.desc(self.sdfg)) + else: + field_type = node_type + output_node = self.add_local_storage(node_type, []) + output_memlet = dace.Memlet(data=output_node.data, subset="0") + + self.head_state.add_nedge(reduce_node, output_node, output_memlet) + return [(output_node, field_type)] + + def build_tasklet_node( + self, stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] + ) -> list[SDFGField]: # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) @@ -192,14 +229,14 @@ def build(self) -> list[SDFGField]: field_node = self.add_local_storage(ts.FieldType(field_dims, field_dtype), field_shape) # assume tasklet with single output - output_subset = [dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims] + output_subset = [DIMENSION_INDEX_FMT.format(dim=dim.value) for dim in self.field_type.dims] if isinstance(output_desc, dace.data.Array): output_subset.extend(f"0:{size}" for size in output_desc.shape) output_memlet = dace.Memlet(data=field_node.data, subset=",".join(output_subset)) # create map range corresponding to the field operator domain map_ranges = { - dimension_index_fmt.format(dim=dim): f"{lb}:{ub}" + DIMENSION_INDEX_FMT.format(dim=dim): f"{lb}:{ub}" for dim, (lb, ub) in self.field_domain.items() } me, mx = self.head_state.add_map("field_op", map_ranges) @@ -227,6 +264,43 @@ def build(self) -> list[SDFGField]: return [(field_node, self.field_type)] + def build(self) -> list[SDFGField]: + # type of variables used for field indexing + index_dtype = dace.int32 + # first visit the list of arguments and build a symbol map + stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] + for arg in self.stencil_args: + arg_nodes = arg() + assert len(arg_nodes) == 1 + data_node, arg_type = arg_nodes[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(data_node, dace.nodes.AccessNode) + + if isinstance(arg_type, ts.ScalarType): + scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) + stencil_args.append(scalar_arg) + else: + assert isinstance(arg_type, ts.FieldType) + indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { + dim: gtir_to_tasklet.SymbolExpr( + dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim)), + index_dtype, + ) + for dim in self.field_domain.keys() + } + iterator_arg = gtir_to_tasklet.IteratorExpr( + data_node, + [dim.value for dim in arg_type.dims], + indices, + ) + stencil_args.append(iterator_arg) + + if cpm.is_call_to(self.stencil_expr.expr, "reduce"): + return self.build_reduce_node(stencil_args) + + else: + return self.build_tasklet_node(stencil_args) + class Select(PrimitiveTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 19f0ab89c9..e8bf33480e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -35,7 +35,7 @@ class MemletExpr: """Scalar or array data access thorugh a memlet.""" - source: dace.nodes.AccessNode + node: dace.nodes.AccessNode subset: sbs.Indices | sbs.Range @@ -70,7 +70,7 @@ class IteratorExpr: InputConnection: TypeAlias = tuple[ dace.nodes.AccessNode, sbs.Range, - dace.nodes.NestedSDFG | dace.nodes.Tasklet, + dace.nodes.Node, str, ] @@ -201,18 +201,19 @@ def __init__( sdfg: dace.SDFG, state: dace.SDFGState, offset_provider: dict[str, Connectivity | Dimension], + symbol_map: Optional[dict[str, IteratorExpr | MemletExpr | SymbolExpr]] = None, ): self.sdfg = sdfg self.state = state self.input_connections = [] self.offset_provider = offset_provider - self.symbol_map = {} + self.symbol_map = symbol_map if symbol_map is not None else {} def _add_input_connection( self, src: dace.nodes.AccessNode, subset: sbs.Range, - dst: dace.nodes.Tasklet, + dst: dace.nodes.Node, dst_connector: str, ) -> None: self.input_connections.append((src, subset, dst, dst_connector)) @@ -220,15 +221,14 @@ def _add_input_connection( def _get_tasklet_result( self, dtype: dace.typeclass, - src_node: dace.nodes.NestedSDFG | dace.nodes.Tasklet, - src_connector: str, - shape: Optional[tuple[int,]] = None, + src_node: dace.nodes.Node, + src_connector: Optional[str] = None, + subset: Optional[sbs.Range] = None, ) -> ValueExpr: - if shape: + if subset: var_name, _ = self.sdfg.add_array( - "var", shape, dtype, transient=True, find_new_name=True + "var", subset.size(), dtype, transient=True, find_new_name=True ) - subset = ",".join(f"0:{size}" for size in shape) else: var_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) subset = "0" @@ -282,7 +282,7 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim) if isinstance(index_expr, MemletExpr): self._add_input_connection( - index_expr.source, + index_expr.node, index_expr.subset, deref_node, deref_connector, @@ -379,7 +379,10 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: ) return self._get_tasklet_result( - field_desc.dtype, neighbors_node, output_name, shape=(offset_provider.max_neighbors,) + field_desc.dtype, + neighbors_node, + output_name, + subset=sbs.Range([(0, offset_provider.max_neighbors - 1, 1)]), ) def _split_shift_args( @@ -439,9 +442,9 @@ def _make_cartesian_shift( for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: if isinstance(input_expr, MemletExpr): if input_connector == "index": - dtype = input_expr.source.desc(self.sdfg).dtype + dtype = input_expr.node.desc(self.sdfg).dtype self._add_input_connection( - input_expr.source, + input_expr.node, input_expr.subset, dynamic_offset_tasklet, input_connector, @@ -500,7 +503,7 @@ def _make_dynamic_neighbor_offset( ) if isinstance(offset_expr, MemletExpr): self._add_input_connection( - offset_expr.source, + offset_expr.node, offset_expr.subset, tasklet_node, "offset", @@ -649,13 +652,11 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: - self._add_input_connection( - arg_expr.source, arg_expr.subset, tasklet_node, connector - ) + self._add_input_connection(arg_expr.node, arg_expr.subset, tasklet_node, connector) # TODO: use type inference to determine the result type if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): - dtype = node_connections["__inp_0"].source.desc(self.sdfg).dtype + dtype = node_connections["__inp_0"].node.desc(self.sdfg).dtype else: node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) dtype = dace_fieldview_util.as_dace_type(node_type) @@ -673,11 +674,9 @@ def visit_Lambda( if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.source.desc(self.sdfg).dtype + output_dtype = output_expr.node.desc(self.sdfg).dtype tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_connection( - output_expr.source, output_expr.subset, tasklet_node, "__inp" - ) + self._add_input_connection(output_expr.node, output_expr.subset, tasklet_node, "__inp") else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 442dbd269a..083dbf4ace 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -33,6 +33,7 @@ ) from next_tests.integration_tests.cases import EField, IFloatField, VField +from functools import reduce import numpy as np import pytest @@ -864,3 +865,66 @@ def test_gtir_neighbors(): __v2e_field_stride_1=1, ) assert np.allclose(v2e_field, e[connectivity_V2E.table]) + + +def test_gtir_reduce(): + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + testee = itir.Program( + id=f"neighbors_sum", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call("reduce")("plus", im.literal_from_value(0), im.deref("it")) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges") + ), + domain=vertex_domain, + target=itir.SymRef(id="vertices"), + ) + ], + ) + + arg_types = [EFTYPE, VFTYPE, SIZE_TYPE] + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, NeighborTable) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v_ref = [ + reduce(lambda x, y: x + y, e[v2e_neighbors], 0.0) + for v2e_neighbors in connectivity_V2E.table + ] + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **CSYMBOLS, + __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v, v_ref) From 55adbd534625f329207dd42d1fc8c93f073f5bc6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 23 May 2024 08:09:59 +0200 Subject: [PATCH 074/235] Refactoring --- .../ir_utils/common_pattern_matcher.py | 20 +++++ .../next/iterator/transforms/fuse_maps.py | 15 +--- .../next/iterator/transforms/unroll_reduce.py | 17 +--- .../gtir_builtin_translators.py | 84 +++++-------------- .../runners/dace_fieldview/gtir_to_tasklet.py | 44 +++++++++- .../runners_tests/test_dace_fieldview.py | 4 +- 6 files changed, 92 insertions(+), 92 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 4933307c53..e01d6ea51f 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -27,6 +27,26 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `reduce(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "reduce" + ) + + +def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `shift(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "shift" + ) + + def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: """Match expression of the form `(λ(...) → ...)(...)`.""" return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index c10cb6f3e7..ef7d6eabe6 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -18,6 +18,7 @@ from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.transforms import inline_lambdas @@ -29,14 +30,6 @@ def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]: ) -def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: - return ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.FunCall) - and node.fun.fun == ir.SymRef(id="reduce") - ) - - @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ @@ -71,7 +64,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: def visit_FunCall(self, node: ir.FunCall, **kwargs): node = self.generic_visit(node) - if _is_map(node) or _is_reduce(node): + if _is_map(node) or cpm.is_applied_reduce(node): if any(_is_map(arg) for arg in node.args): first_param = ( 0 if _is_map(node) else 1 @@ -83,7 +76,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): inlined_args = [] new_params = [] new_args = [] - if _is_reduce(node): + if cpm.is_applied_reduce(node): # param corresponding to reduce acc inlined_args.append(ir.SymRef(id=outer_op.params[0].id)) new_params.append(outer_op.params[0]) @@ -119,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args ) - else: # _is_reduce(node) + else: # is_applied_reduce return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]), args=new_args, diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index b058fc0a7b..75cde58723 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,17 +20,10 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift -def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: - return ( - isinstance(arg, itir.FunCall) - and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="shift") - ) - - def _is_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunCall]: return isinstance(arg, itir.FunCall) and arg.fun == itir.SymRef(id="neighbors") @@ -68,16 +61,12 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: return [_get_partial_offset_tag(arg) for arg in _get_neighbors_args(reduce_args)] -def _is_reduce(node: itir.FunCall) -> TypeGuard[itir.FunCall]: - return isinstance(node.fun, itir.FunCall) and node.fun.fun == itir.SymRef(id="reduce") - - def _get_connectivity( applied_reduce_node: itir.FunCall, offset_provider: dict[str, common.Dimension | common.Connectivity], ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): + if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] @@ -166,6 +155,6 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Expr: node = self.generic_visit(node, **kwargs) - if _is_reduce(node): + if cpm.is_applied_reduce(node): return self._visit_reduce(node, **kwargs) return node diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e24b043110..d73d368f73 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -132,76 +132,30 @@ def __init__( self.stencil_args = stencil_args def build_reduce_node( - self, stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] + self, + input_connections: list[gtir_to_tasklet.InputConnection], + output_expr: gtir_to_tasklet.ValueExpr, ) -> list[SDFGField]: """Use dace library node for reduction.""" - node = self.stencil_expr.expr - assert isinstance(node, itir.FunCall) - - assert len(node.args) == 3 - op_name = node.args[0] - assert isinstance(op_name, itir.SymRef) - reduce_identity = node.args[1] - assert isinstance(reduce_identity, itir.Literal) - field_arg = node.args[2] - assert isinstance(field_arg, itir.FunCall) - - assert len(self.stencil_expr.params) == 1 - it_param = str(self.stencil_expr.params[0].id) - assert len(stencil_args) == 1 - it = stencil_args[0] - assert isinstance(it, gtir_to_tasklet.IteratorExpr) - - taskgen = gtir_to_tasklet.LambdaToTasklet( - self.sdfg, self.head_state, self.offset_provider, symbol_map={it_param: it} - ) - field_expr: gtir_to_tasklet.MemletExpr | gtir_to_tasklet.ValueExpr = taskgen.visit( - field_arg - ) - - input_desc = field_expr.node.desc(self.sdfg) - assert isinstance(input_desc, dace.data.Array) - - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - if len(input_desc.shape) > 1: - ndims = len(input_desc.shape) - 1 - assert ndims == len(it.dimensions) - reduce_axes = [ndims] - else: - ndims = 0 - reduce_axes = None - - reduce_wcr = "lambda x, y: " + gtir_to_tasklet.MATH_BUILTINS_MAPPING[str(op_name)].format( - "x", "y" - ) - reduce_node = self.head_state.add_reduce(reduce_wcr, reduce_axes, reduce_identity) + assert len(input_connections) == 1 + source_node, source_subset, reduce_node, reduce_connector = input_connections[0] + assert reduce_connector is None self.head_state.add_nedge( - field_expr.node, reduce_node, dace.Memlet.from_array(field_expr.node.data, input_desc) + source_node, + reduce_node, + dace.Memlet(data=source_node.data, subset=source_subset), ) - field_type: ts.FieldType | ts.ScalarType - if ndims > 0: - field_type = ts.FieldType([Dimension(dim) for dim in it.dimensions], node_type) - output_node = self.add_local_storage(field_type, input_desc.shape[0:ndims]) - output_memlet = dace.Memlet.from_array(output_node.data, output_node.desc(self.sdfg)) - else: - field_type = node_type - output_node = self.add_local_storage(node_type, []) - output_memlet = dace.Memlet(data=output_node.data, subset="0") - - self.head_state.add_nedge(reduce_node, output_node, output_memlet) - return [(output_node, field_type)] + return [(output_expr.node, self.field_type)] def build_tasklet_node( - self, stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] + self, + input_connections: list[gtir_to_tasklet.InputConnection], + output_expr: gtir_to_tasklet.ValueExpr, ) -> list[SDFGField]: - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) - assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + """Translates the field operator to a mapped tasklet graph, which will range over the field domain.""" # retrieve the tasklet node which writes the result output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src @@ -295,11 +249,15 @@ def build(self) -> list[SDFGField]: ) stencil_args.append(iterator_arg) - if cpm.is_call_to(self.stencil_expr.expr, "reduce"): - return self.build_reduce_node(stencil_args) + taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) + input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + + if cpm.is_applied_reduce(self.stencil_expr.expr): + return self.build_reduce_node(input_connections, output_expr) else: - return self.build_tasklet_node(stencil_args) + return self.build_tasklet_node(input_connections, output_expr) class Select(PrimitiveTranslator): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index e8bf33480e..7f53e1abe3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -71,7 +71,7 @@ class IteratorExpr: dace.nodes.AccessNode, sbs.Range, dace.nodes.Node, - str, + Optional[str], ] INDEX_CONNECTOR_FMT = "__index_{dim}" @@ -214,7 +214,7 @@ def _add_input_connection( src: dace.nodes.AccessNode, subset: sbs.Range, dst: dace.nodes.Node, - dst_connector: str, + dst_connector: Optional[str] = None, ) -> None: self.input_connections.append((src, subset, dst, dst_connector)) @@ -385,6 +385,41 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: subset=sbs.Range([(0, offset_provider.max_neighbors - 1, 1)]), ) + def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: + assert isinstance(node.fun, itir.FunCall) + assert len(node.fun.args) == 2 + op_name = node.fun.args[0] + assert isinstance(op_name, itir.SymRef) + reduce_identity = node.fun.args[1] + assert isinstance(reduce_identity, itir.Literal) + + assert len(node.args) == 1 + input_expr = self.visit(node.args[0]) + assert isinstance(input_expr, MemletExpr | ValueExpr) + input_desc = input_expr.node.desc(self.sdfg) + + assert isinstance(input_desc, dace.data.Array) + if len(input_desc.shape) > 1: + ndims = len(input_desc.shape) - 1 + reduce_axes = [ndims] + else: + ndims = 0 + reduce_axes = None + + reduce_wcr = "lambda x, y: " + MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") + reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_identity) + + input_subset = sbs.Range([(0, dim_size - 1, 1) for dim_size in input_desc.shape]) + if ndims > 0: + result_subset = sbs.Range(input_subset[0:ndims]) + else: + result_subset = None + + self._add_input_connection(input_expr.node, input_subset, reduce_node) + + # TODO: use type inference to determine the result type + return self._get_tasklet_result(input_desc.dtype, reduce_node, None, result_subset) + def _split_shift_args( self, args: list[itir.Expr] ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: @@ -604,7 +639,10 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) - elif cpm.is_call_to(node.fun, "shift"): + elif cpm.is_applied_reduce(node): + return self._visit_reduce(node) + + elif cpm.is_applied_shift(node.fun): return self._visit_shift(node) else: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 083dbf4ace..b84e08e8a7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -885,7 +885,9 @@ def test_gtir_reduce(): expr=im.call( im.call("as_fieldop")( im.lambda_("it")( - im.call("reduce")("plus", im.literal_from_value(0), im.deref("it")) + im.call(im.call("reduce")("plus", im.literal_from_value(0)))( + im.deref("it") + ) ), vertex_domain, ) From ad21dc4b46becac7017e778fff07a012b758001c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 23 May 2024 08:35:25 +0200 Subject: [PATCH 075/235] Add support for both inlined and fieldview neighbor reduction --- .../gtir_builtin_translators.py | 4 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 9 +- .../runners_tests/test_dace_fieldview.py | 104 ++++++++++-------- 3 files changed, 69 insertions(+), 48 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index d73d368f73..67b8e5ef1e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -253,7 +253,9 @@ def build(self) -> list[SDFGField]: input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) - if cpm.is_applied_reduce(self.stencil_expr.expr): + if cpm.is_applied_reduce(self.stencil_expr.expr) and cpm.is_call_to( + self.stencil_expr.expr.args[0], "deref" + ): return self.build_reduce_node(input_connections, output_expr) else: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 7f53e1abe3..d996d19b12 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -415,7 +415,14 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: else: result_subset = None - self._add_input_connection(input_expr.node, input_subset, reduce_node) + if isinstance(input_expr, MemletExpr): + self._add_input_connection(input_expr.node, input_subset, reduce_node) + else: + self.state.add_nedge( + input_expr.node, + reduce_node, + dace.Memlet(data=input_expr.node.data, subset=input_subset), + ) # TODO: use type inference to determine the result type return self._get_tasklet_result(input_desc.dtype, reduce_node, None, result_subset) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index b84e08e8a7..d0ca207b63 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -871,62 +871,74 @@ def test_gtir_reduce(): vertex_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), ) - testee = itir.Program( - id=f"neighbors_sum", - function_definitions=[], - params=[ - itir.Sym(id="edges"), - itir.Sym(id="vertices"), - itir.Sym(id="nvertices"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(0)))( - im.deref("it") - ) - ), - vertex_domain, - ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.neighbors("V2E", "it")), - vertex_domain, - ) - )("edges") - ), - domain=vertex_domain, - target=itir.SymRef(id="vertices"), + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(0)))( + im.neighbors("V2E", "it") + ) + ), + vertex_domain, + ) + )("edges") + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(0)))(im.deref("it")) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, ) - ], + )("edges") ) arg_types = [EFTYPE, VFTYPE, SIZE_TYPE] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) - connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) v_ref = [ reduce(lambda x, y: x + y, e[v2e_neighbors], 0.0) for v2e_neighbors in connectivity_V2E.table ] - sdfg( - edges=e, - vertices=v, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **CSYMBOLS, - __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, - __v2e_field_stride_1=1, - ) - assert np.allclose(v, v_ref) + for stencil in [stencil_inlined, stencil_fieldview]: + testee = itir.Program( + id=f"neighbors_sum", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=vertex_domain, + target=itir.SymRef(id="vertices"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + + # new empty output field + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **CSYMBOLS, + __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v, v_ref) From bb9123b2493a5252789a0649edb3ee638c209bd0 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 23 May 2024 08:56:19 +0200 Subject: [PATCH 076/235] Minor edit --- .../runners/dace_fieldview/gtir_builtin_translators.py | 4 +--- .../runners/dace_fieldview/gtir_to_tasklet.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 67b8e5ef1e..25aace306b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -118,9 +118,7 @@ def __init__( domain = dace_fieldview_util.get_domain(domain_expr) # define field domain with all dimensions in alphabetical order - sorted_domain_dims = sorted( - [Dimension(dim) for dim in domain.keys()], key=lambda x: x.value - ) + sorted_domain_dims = [Dimension(dim) for dim in sorted(domain.keys())] # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index d996d19b12..141d9a3427 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -201,13 +201,12 @@ def __init__( sdfg: dace.SDFG, state: dace.SDFGState, offset_provider: dict[str, Connectivity | Dimension], - symbol_map: Optional[dict[str, IteratorExpr | MemletExpr | SymbolExpr]] = None, ): self.sdfg = sdfg self.state = state self.input_connections = [] self.offset_provider = offset_provider - self.symbol_map = symbol_map if symbol_map is not None else {} + self.symbol_map = {} def _add_input_connection( self, From 0025d773b4d4d343db5b7fc8c40cb18ab953e60f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 23 May 2024 09:31:33 +0200 Subject: [PATCH 077/235] Code refactoring --- .../runners/dace_fieldview/gtir_to_tasklet.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 141d9a3427..8856daca5a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -398,22 +398,19 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: input_desc = input_expr.node.desc(self.sdfg) assert isinstance(input_desc, dace.data.Array) + input_subset = sbs.Range([(0, dim_size - 1, 1) for dim_size in input_desc.shape]) + if len(input_desc.shape) > 1: ndims = len(input_desc.shape) - 1 reduce_axes = [ndims] + result_subset = sbs.Range(input_subset[0:ndims]) else: - ndims = 0 reduce_axes = None + result_subset = None reduce_wcr = "lambda x, y: " + MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_identity) - input_subset = sbs.Range([(0, dim_size - 1, 1) for dim_size in input_desc.shape]) - if ndims > 0: - result_subset = sbs.Range(input_subset[0:ndims]) - else: - result_subset = None - if isinstance(input_expr, MemletExpr): self._add_input_connection(input_expr.node, input_subset, reduce_node) else: @@ -648,7 +645,7 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value elif cpm.is_applied_reduce(node): return self._visit_reduce(node) - elif cpm.is_applied_shift(node.fun): + elif cpm.is_applied_shift(node): return self._visit_shift(node) else: From 9926d7d95ff0515fdb34910534dbe9a3e9785b44 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 23 May 2024 11:55:32 +0200 Subject: [PATCH 078/235] Add support for skip values ONLY for inlined GTIR --- .../runners/dace_fieldview/gtir_to_tasklet.py | 28 +++- .../runners_tests/test_dace_fieldview.py | 148 +++++++++++++----- 2 files changed, 135 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 8856daca5a..b8a2fbf9e3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -24,7 +24,7 @@ from gt4py import eve from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util @@ -137,6 +137,7 @@ def build_neighbors_sdfg( field_shape: tuple[int], neighbors_shape: tuple[int], index_dtype: dace.typeclass, + reduce_identity: Optional[SymbolExpr] = None, ) -> tuple[dace.SDFG, str, str, str]: assert len(field_shape) == len(neighbors_shape) @@ -151,11 +152,17 @@ def build_neighbors_sdfg( field_name, field_array = sdfg.add_array("field", field_shape, field_dtype) index_name, _ = sdfg.add_array("indexes", neighbors_shape, index_dtype) var_name, _ = sdfg.add_array("values", neighbors_shape, field_dtype) + + if reduce_identity is not None: + typed_identity_value = f"{reduce_identity.dtype}({reduce_identity.value})" + skip_value_code = f" if __index != {neighbor_skip_value} else {typed_identity_value}" + else: + skip_value_code = "" tasklet_node = state.add_tasklet( "gather_neighbors", {"__field", "__index"}, {"__val"}, - "__val = __field[__index]", + "__val = __field[__index]" + skip_value_code, ) state.add_memlet_path( state.add_access(field_name), @@ -195,18 +202,21 @@ class LambdaToTasklet(eve.NodeVisitor): input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] + reduce_identity: Optional[SymbolExpr] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, offset_provider: dict[str, Connectivity | Dimension], + reduce_identity: Optional[SymbolExpr] = None, ): self.sdfg = sdfg self.state = state self.input_connections = [] self.offset_provider = offset_provider self.symbol_map = {} + self.reduce_identity = reduce_identity def _add_input_connection( self, @@ -313,6 +323,8 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: assert isinstance(offset, str) offset_provider = self.offset_provider[offset] assert isinstance(offset_provider, Connectivity) + if offset_provider.has_skip_values: + assert self.reduce_identity is not None it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) @@ -345,6 +357,7 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: field_array_shape, (offset_provider.max_neighbors,), self.sdfg.arrays[offset_table].dtype, + self.reduce_identity, ) neighbors_node = self.state.add_nested_sdfg( @@ -392,9 +405,17 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: reduce_identity = node.fun.args[1] assert isinstance(reduce_identity, itir.Literal) + # TODO: use type inference to determine the result type + dtype = dace.float64 + + # we store the value of reduce identity in the visitor context while visiting the input to reduction + # this value will be returned by the neighbors builtin function for skip values + prev_reduce_identity = self.reduce_identity + self.reduce_identity = SymbolExpr(reduce_identity.value, dtype) assert len(node.args) == 1 input_expr = self.visit(node.args[0]) assert isinstance(input_expr, MemletExpr | ValueExpr) + self.reduce_identity = prev_reduce_identity input_desc = input_expr.node.desc(self.sdfg) assert isinstance(input_desc, dace.data.Array) @@ -420,8 +441,7 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: dace.Memlet(data=input_expr.node.data, subset=input_subset), ) - # TODO: use type inference to determine the result type - return self._get_tasklet_result(input_desc.dtype, reduce_node, None, result_subset) + return self._get_tasklet_result(dtype, reduce_node, None, result_subset) def _split_shift_args( self, args: list[itir.Expr] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index d0ca207b63..0e06944d58 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -30,6 +30,7 @@ V2EDim, Vertex, simple_mesh, + skip_value_mesh, ) from next_tests.integration_tests.cases import EField, IFloatField, VField @@ -46,6 +47,7 @@ EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) SIMPLE_MESH: MeshDescriptor = simple_mesh() +SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, @@ -58,33 +60,36 @@ __z_stride_0=1, size=N, ) -CSYMBOLS = dict( - ncells=SIMPLE_MESH.num_cells, - nedges=SIMPLE_MESH.num_edges, - nvertices=SIMPLE_MESH.num_vertices, - __cells_size_0=SIMPLE_MESH.num_cells, - __cells_stride_0=1, - __edges_size_0=SIMPLE_MESH.num_edges, - __edges_stride_0=1, - __vertices_size_0=SIMPLE_MESH.num_vertices, - __vertices_stride_0=1, - __connectivity_C2E_size_0=SIMPLE_MESH.num_cells, - __connectivity_C2E_size_1=SIMPLE_MESH.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=SIMPLE_MESH.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_1=1, - __connectivity_C2V_size_0=SIMPLE_MESH.num_cells, - __connectivity_C2V_size_1=SIMPLE_MESH.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=SIMPLE_MESH.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_1=1, - __connectivity_E2V_size_0=SIMPLE_MESH.num_edges, - __connectivity_E2V_size_1=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_1=1, - __connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, - __connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_1=1, -) + + +def make_mesh_symbols(mesh: MeshDescriptor): + return dict( + ncells=mesh.num_cells, + nedges=mesh.num_edges, + nvertices=mesh.num_vertices, + __cells_size_0=mesh.num_cells, + __cells_stride_0=1, + __edges_size_0=mesh.num_edges, + __edges_stride_0=1, + __vertices_size_0=mesh.num_vertices, + __vertices_stride_0=1, + __connectivity_C2E_size_0=mesh.num_cells, + __connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_stride_1=1, + __connectivity_C2V_size_0=mesh.num_cells, + __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_stride_1=1, + __connectivity_E2V_size_0=mesh.num_edges, + __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_stride_1=1, + __connectivity_V2E_size_0=mesh.num_vertices, + __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_1=1, + ) def test_gtir_copy(): @@ -729,7 +734,7 @@ def test_gtir_connectivity_shift(): connectivity_C2E=connectivity_C2E.table, connectivity_C2V=connectivity_C2V.table, **FSYMBOLS, - **CSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), __ve_field_size_0=SIMPLE_MESH.num_vertices, __ve_field_size_1=SIMPLE_MESH.num_edges, __ve_field_stride_0=SIMPLE_MESH.num_edges, @@ -800,9 +805,9 @@ def test_gtir_connectivity_shift_chain(): connectivity_E2V=connectivity_E2V.table, connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - **CSYMBOLS, - __edges_out_size_0=CSYMBOLS["__edges_size_0"], - __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], + **make_mesh_symbols(SIMPLE_MESH), + __edges_out_size_0=SIMPLE_MESH.num_edges, + __edges_out_stride_0=1, ) assert np.allclose(e_out, ref) @@ -858,7 +863,7 @@ def test_gtir_neighbors(): v2e_field=v2e_field, connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - **CSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.max_neighbors, __v2e_field_stride_0=connectivity_V2E.max_neighbors, @@ -935,10 +940,79 @@ def test_gtir_reduce(): vertices=v, connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - **CSYMBOLS, - __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, - __v2e_field_stride_1=1, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) + + +def test_gtir_reduce_with_skip_values(): + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(0)))( + im.neighbors("V2E", "it") + ) + ), + vertex_domain, + ) + )("edges") + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(0)))(im.deref("it")) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges") + ) + + arg_types = [EFTYPE, VFTYPE, SIZE_TYPE] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, NeighborTable) + + e = np.random.rand(SKIP_VALUE_MESH.num_edges) + v_ref = [ + reduce(lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], 0.0) + for v2e_neighbors in connectivity_V2E.table + ] + + for stencil in [stencil_inlined]: + testee = itir.Program( + id=f"neighbors_sum", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=vertex_domain, + target=itir.SymRef(id="vertices"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SKIP_VALUE_MESH.offset_provider) + + # new empty output field + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SKIP_VALUE_MESH), ) assert np.allclose(v, v_ref) From 172f19e77aab6de146c9d3132749bdda97411829 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 27 May 2024 09:48:28 +0200 Subject: [PATCH 079/235] Masked array implementation based on connectivity table --- .../gtir_builtin_translators.py | 145 ++++---- .../dace_fieldview/gtir_dace_backend.py | 4 + .../runners/dace_fieldview/gtir_to_sdfg.py | 3 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 323 +++++++++++++++--- .../runners_tests/test_dace_fieldview.py | 2 +- 5 files changed, 345 insertions(+), 132 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 25aace306b..b40806ae61 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -15,7 +15,7 @@ import abc from dataclasses import dataclass -from typing import Callable, TypeAlias, final +from typing import Callable, Optional, TypeAlias, final import dace import dace.subsets as sbs @@ -33,7 +33,7 @@ # Define aliases for return types SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] -SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] +SDFGFieldBuilder: TypeAlias = Callable[[], tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]] DIMENSION_INDEX_FMT = "i_{dim}" @@ -44,7 +44,7 @@ class PrimitiveTranslator(abc.ABC): head_state: dace.SDFGState @final - def __call__(self) -> list[SDFGField]: + def __call__(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: """The callable interface is used to build the dataflow graph. It allows to build the dataflow graph inside a given state starting @@ -73,7 +73,7 @@ def add_local_storage( return self.head_state.add_access(name) @abc.abstractmethod - def build(self) -> list[SDFGField]: + def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: """Creates the dataflow subgraph representing a GTIR builtin function. This method is used by derived classes to build a specialized subgraph @@ -129,31 +129,41 @@ def __init__( self.stencil_expr = stencil_expr self.stencil_args = stencil_args - def build_reduce_node( - self, - input_connections: list[gtir_to_tasklet.InputConnection], - output_expr: gtir_to_tasklet.ValueExpr, - ) -> list[SDFGField]: - """Use dace library node for reduction.""" - - assert len(input_connections) == 1 - source_node, source_subset, reduce_node, reduce_connector = input_connections[0] - assert reduce_connector is None - - self.head_state.add_nedge( - source_node, - reduce_node, - dace.Memlet(data=source_node.data, subset=source_subset), - ) + def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: + # type of variables used for field indexing + index_dtype = dace.int32 + # first visit the list of arguments and build a symbol map + stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] + for arg in self.stencil_args: + arg_nodes, mask_node = arg() + assert len(arg_nodes) == 1 + data_node, arg_type = arg_nodes[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(data_node, dace.nodes.AccessNode) - return [(output_expr.node, self.field_type)] + if isinstance(arg_type, ts.ScalarType): + scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) + stencil_args.append(scalar_arg) + else: + assert isinstance(arg_type, ts.FieldType) + indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { + dim: gtir_to_tasklet.SymbolExpr( + dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim)), + index_dtype, + ) + for dim in self.field_domain.keys() + } + iterator_arg = gtir_to_tasklet.IteratorExpr( + data_node, + mask_node, + [dim.value for dim in arg_type.dims], + indices, + ) + stencil_args.append(iterator_arg) - def build_tasklet_node( - self, - input_connections: list[gtir_to_tasklet.InputConnection], - output_expr: gtir_to_tasklet.ValueExpr, - ) -> list[SDFGField]: - """Translates the field operator to a mapped tasklet graph, which will range over the field domain.""" + taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) + input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) # retrieve the tasklet node which writes the result output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src @@ -176,15 +186,15 @@ def build_tasklet_node( field_dims.extend(Dimension(f"local_dim{i}") for i in range(len(output_desc.shape))) field_shape.extend(output_desc.shape) - # TODO: use `self.field_type` without overriding `dtype` when type inference is in place + # TODO: use `self.field_type.field_dtype` without overriding `dtype` when type inference is in place field_dtype = dace_fieldview_util.as_scalar_type(str(output_desc.dtype.as_numpy_dtype())) - field_node = self.add_local_storage(ts.FieldType(field_dims, field_dtype), field_shape) + field_type = ts.FieldType(field_dims, field_dtype) + field_node = self.add_local_storage(field_type, field_shape) # assume tasklet with single output output_subset = [DIMENSION_INDEX_FMT.format(dim=dim.value) for dim in self.field_type.dims] if isinstance(output_desc, dace.data.Array): output_subset.extend(f"0:{size}" for size in output_desc.shape) - output_memlet = dace.Memlet(data=field_node.data, subset=",".join(output_subset)) # create map range corresponding to the field operator domain map_ranges = { @@ -211,53 +221,26 @@ def build_tasklet_node( mx, field_node, src_conn=output_tasklet_connector, - memlet=output_memlet, + memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), ) - return [(field_node, self.field_type)] - - def build(self) -> list[SDFGField]: - # type of variables used for field indexing - index_dtype = dace.int32 - # first visit the list of arguments and build a symbol map - stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in self.stencil_args: - arg_nodes = arg() - assert len(arg_nodes) == 1 - data_node, arg_type = arg_nodes[0] - # require all argument nodes to be data access nodes (no symbols) - assert isinstance(data_node, dace.nodes.AccessNode) - - if isinstance(arg_type, ts.ScalarType): - scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) - stencil_args.append(scalar_arg) - else: - assert isinstance(arg_type, ts.FieldType) - indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { - dim: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim)), - index_dtype, - ) - for dim in self.field_domain.keys() - } - iterator_arg = gtir_to_tasklet.IteratorExpr( - data_node, - [dim.value for dim in arg_type.dims], - indices, - ) - stencil_args.append(iterator_arg) - - taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) - assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) - - if cpm.is_applied_reduce(self.stencil_expr.expr) and cpm.is_call_to( - self.stencil_expr.expr.args[0], "deref" - ): - return self.build_reduce_node(input_connections, output_expr) + if isinstance(output_expr, gtir_to_tasklet.MaskedValueExpr): + # this is the case of neighbors with skip values: the value expression also contains the neighbor indices + mask_numpy_dtype = self.sdfg.arrays[output_expr.mask.data].dtype.as_numpy_dtype() + mask_dtype = dace_fieldview_util.as_scalar_type(str(mask_numpy_dtype)) + mask_node = self.add_local_storage(ts.FieldType(field_dims, mask_dtype), field_shape) + + self.head_state.add_memlet_path( + output_expr.mask, + mx, + mask_node, + memlet=dace.Memlet(data=mask_node.data, subset=",".join(output_subset)), + ) else: - return self.build_tasklet_node(input_connections, output_expr) + mask_node = None + + return [(field_node, self.field_type)], mask_node class Select(PrimitiveTranslator): @@ -314,7 +297,7 @@ def __init__( false_expr, sdfg=sdfg, head_state=false_state ) - def build(self) -> list[SDFGField]: + def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: # retrieve true/false states as predecessors of head state branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) assert len(branch_states) == 2 @@ -323,8 +306,8 @@ def build(self) -> list[SDFGField]: else: false_state, true_state = branch_states - true_br_args = self.true_br_builder() - false_br_args = self.false_br_builder() + true_br_args, true_br_mask = self.true_br_builder() + false_br_args, false_br_mask = self.false_br_builder() output_nodes = [] for true_br, false_br in zip(true_br_args, false_br_args, strict=True): @@ -355,7 +338,11 @@ def build(self) -> list[SDFGField]: false_br_output_node.data, false_br_output_node.desc(self.sdfg) ), ) - return output_nodes + + # TODO: add support for masked array values in select statements, if this lowering path is needed + assert not (true_br_mask or false_br_mask) + + return output_nodes, None class SymbolRef(PrimitiveTranslator): @@ -375,7 +362,7 @@ def __init__( self.sym_name = sym_name self.sym_type = sym_type - def build(self) -> list[SDFGField]: + def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: if isinstance(self.sym_type, ts.FieldType): # add access node to current state sym_node = self.head_state.add_access(self.sym_name) @@ -399,4 +386,4 @@ def build(self) -> list[SDFGField]: dace.Memlet(data=sym_node.data, subset="0"), ) - return [(sym_node, self.sym_type)] + return [(sym_node, self.sym_type)], None diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py index c8c798292a..bcbe390aca 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dace +from dace.sdfg import utils as sdutils from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir @@ -31,5 +32,8 @@ def build_sdfg_from_gtir( sdfg = sdfg_genenerator.visit(program) assert isinstance(sdfg, dace.SDFG) + # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct + sdutils.inline_loop_blocks(sdfg) + sdfg.simplify() return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index da2ed8a4b3..f875d4d6a8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -133,7 +133,8 @@ def _visit_expression( field_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( node, sdfg=sdfg, head_state=head_state ) - results = field_builder() + results, mask_node = field_builder() + assert mask_node is None field_nodes = [] for node, _ in results: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index b8a2fbf9e3..13c46d1729 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -62,6 +62,7 @@ class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode + mask: Optional[dace.nodes.AccessNode] dimensions: list[str] indices: dict[str, IteratorIndexExpr] @@ -74,6 +75,21 @@ class IteratorExpr: Optional[str], ] + +@dataclass(frozen=True) +class MaskedMemletExpr(MemletExpr): + """Scalar or array data access thorugh a memlet.""" + + mask: dace.nodes.AccessNode + + +@dataclass(frozen=True) +class MaskedValueExpr(ValueExpr): + """Result of the computation implemented by a tasklet node.""" + + mask: dace.nodes.AccessNode + + INDEX_CONNECTOR_FMT = "__index_{dim}" @@ -137,7 +153,7 @@ def build_neighbors_sdfg( field_shape: tuple[int], neighbors_shape: tuple[int], index_dtype: dace.typeclass, - reduce_identity: Optional[SymbolExpr] = None, + with_skip_values: bool, ) -> tuple[dace.SDFG, str, str, str]: assert len(field_shape) == len(neighbors_shape) @@ -152,10 +168,10 @@ def build_neighbors_sdfg( field_name, field_array = sdfg.add_array("field", field_shape, field_dtype) index_name, _ = sdfg.add_array("indexes", neighbors_shape, index_dtype) var_name, _ = sdfg.add_array("values", neighbors_shape, field_dtype) + index_node = state.add_access(index_name) - if reduce_identity is not None: - typed_identity_value = f"{reduce_identity.dtype}({reduce_identity.value})" - skip_value_code = f" if __index != {neighbor_skip_value} else {typed_identity_value}" + if with_skip_values: + skip_value_code = f" if __index != {neighbor_skip_value} else {field_dtype}(0)" else: skip_value_code = "" tasklet_node = state.add_tasklet( @@ -172,7 +188,7 @@ def build_neighbors_sdfg( memlet=dace.Memlet.from_array(field_name, field_array), ) state.add_memlet_path( - state.add_access(index_name), + index_node, me, tasklet_node, dst_conn="__index", @@ -189,6 +205,109 @@ def build_neighbors_sdfg( return sdfg, field_name, index_name, var_name +def build_reduce_sdfg( + op_name: itir.SymRef, + init_value: itir.Literal, + result_dtype: dace.typeclass, + values_desc: dace.data.Array, + indices_desc: Optional[dace.data.Array] = None, +) -> tuple[dace.SDFG, str, Optional[str], str]: + sdfg = dace.SDFG("reduce") + + neighbors_len = values_desc.shape[-1] + + input_var, input_desc = sdfg.add_array("values", (neighbors_len,), values_desc.dtype) + acc_var, _ = sdfg.add_scalar("var", result_dtype) + + if indices_desc: + assert values_desc.shape == indices_desc.shape + indices_var, _ = sdfg.add_array("indices", (neighbors_len,), indices_desc.dtype) + + neighbor_idx = "__idx" + reduce_loop = dace.sdfg.state.LoopRegion( + label="reduce", + loop_var=neighbor_idx, + initialize_expr=f"{neighbor_idx} = 0", + condition_expr=f"{neighbor_idx} < {neighbors_len}", + update_expr=f"{neighbor_idx} = {neighbor_idx} + 1", + inverted=False, + ) + sdfg.add_node(reduce_loop) + reduce_state = reduce_loop.add_state("loop") + acc_code = MATH_BUILTINS_MAPPING[str(op_name)].format("acc", "val") + reduce_tasklet = reduce_state.add_tasklet( + "reduce_with_skip_values", + {"acc", "val", "idx"}, + {"res"}, + f"res = {acc_code} if idx != {neighbor_skip_value} else acc", + ) + reduce_state.add_edge( + reduce_state.add_access(acc_var), + None, + reduce_tasklet, + "acc", + dace.Memlet(data=acc_var, subset="0"), + ) + reduce_state.add_edge( + reduce_state.add_access(input_var), + None, + reduce_tasklet, + "val", + dace.Memlet(data=input_var, subset=neighbor_idx), + ) + reduce_state.add_edge( + reduce_state.add_access(indices_var), + None, + reduce_tasklet, + "idx", + dace.Memlet(data=indices_var, subset=neighbor_idx), + ) + reduce_state.add_edge( + reduce_tasklet, + "res", + reduce_state.add_access(acc_var), + None, + dace.Memlet(data=acc_var, subset="0"), + ) + + init_state = sdfg.add_state("init", is_start_block=True) + init_tasklet = init_state.add_tasklet( + "init_reduce", + {}, + {"val"}, + f"val = {init_value}", + ) + init_state.add_edge( + init_tasklet, + "val", + init_state.add_access(acc_var), + None, + dace.Memlet(data=acc_var, subset="0"), + ) + sdfg.add_edge(init_state, reduce_loop, dace.InterstateEdge()) + + else: + state = sdfg.add_state("main") + + reduce_wcr = "lambda x, y: " + MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") + reduce_node = state.add_reduce(reduce_wcr, None, init_value) + + state.add_nedge( + state.add_access(input_var), + reduce_node, + dace.Memlet.from_array(input_var, input_desc), + ) + state.add_nedge( + reduce_node, + state.add_access(acc_var), + dace.Memlet(data=acc_var, subset="0"), + ) + + indices_var = None + + return sdfg, input_var, indices_var, acc_var + + class LambdaToTasklet(eve.NodeVisitor): """Translates an `ir.Lambda` expression to a dataflow graph. @@ -202,21 +321,18 @@ class LambdaToTasklet(eve.NodeVisitor): input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - reduce_identity: Optional[SymbolExpr] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, offset_provider: dict[str, Connectivity | Dimension], - reduce_identity: Optional[SymbolExpr] = None, ): self.sdfg = sdfg self.state = state self.input_connections = [] self.offset_provider = offset_provider self.symbol_map = {} - self.reduce_identity = reduce_identity def _add_input_connection( self, @@ -259,9 +375,16 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # when all indices are symblic expressions, we can perform direct field access through a memlet data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] - return MemletExpr(it.field, data_index) + return ( + MemletExpr(it.field, data_index) + if it.mask is None + else MaskedMemletExpr(it.field, data_index, it.mask) + ) else: + # masked array not supported with indirect field access + assert it.mask is None + # we use a tasklet to perform dereferencing of a generic iterator assert all(dim in it.indices.keys() for dim in it.dimensions) field_indices = [(dim, it.indices[dim]) for dim in it.dimensions] @@ -323,8 +446,6 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: assert isinstance(offset, str) offset_provider = self.offset_provider[offset] assert isinstance(offset_provider, Connectivity) - if offset_provider.has_skip_values: - assert self.reduce_identity is not None it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) @@ -336,12 +457,13 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) field_desc = it.field.desc(self.sdfg) - offset_table = dace_fieldview_util.connectivity_identifier(offset) + connectivity = dace_fieldview_util.connectivity_identifier(offset) # initially, the storage for the connectivty tables is created as transient; # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller - self.sdfg.arrays[offset_table].transient = False - offset_table_node = self.state.add_access(offset_table) + connectivity_desc = self.sdfg.arrays[connectivity] + connectivity_desc.transient = False + connectivity_node = self.state.add_access(connectivity) field_array_shape = tuple( shape @@ -356,8 +478,8 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: field_desc.dtype, field_array_shape, (offset_provider.max_neighbors,), - self.sdfg.arrays[offset_table].dtype, - self.reduce_identity, + connectivity_desc.dtype, + offset_provider.has_skip_values, ) neighbors_node = self.state.add_nested_sdfg( @@ -379,7 +501,7 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: ) self._add_input_connection( - offset_table_node, + connectivity_node, sbs.Range( [ (origin_index.value, origin_index.value, 1), @@ -390,58 +512,154 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: index_name, ) - return self._get_tasklet_result( - field_desc.dtype, - neighbors_node, - output_name, - subset=sbs.Range([(0, offset_provider.max_neighbors - 1, 1)]), + if offset_provider.has_skip_values: + # simulate pattern of masked array, using the connctivity table as a mask + neighbor_val_name, neighbor_val_array = self.sdfg.add_array( + "neighbor_val", + (offset_provider.max_neighbors,), + field_desc.dtype, + transient=True, + find_new_name=True, + ) + neighbor_idx_name, neighbor_idx_array = self.sdfg.add_array( + "neighbor_idx", + (offset_provider.max_neighbors,), + connectivity_desc.dtype, + transient=True, + find_new_name=True, + ) + + neighbor_val_node = self.state.add_access(neighbor_val_name) + self.state.add_edge( + neighbors_node, + output_name, + neighbor_val_node, + None, + dace.Memlet(data=neighbor_val_name, subset=f"0:{offset_provider.max_neighbors}"), + ) + + neighbor_idx_node = self.state.add_access(neighbor_idx_name) + self._add_input_connection( + connectivity_node, + sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), + neighbor_idx_node, + ) + + return MaskedValueExpr(neighbor_val_node, neighbor_idx_node) + + else: + return self._get_tasklet_result( + field_desc.dtype, + neighbors_node, + output_name, + subset=sbs.Range([(0, offset_provider.max_neighbors - 1, 1)]), + ) + + def _make_reduce_with_skip_values( + self, + op_name: itir.SymRef, + init_value: itir.Literal, + result_dtype: dace.typeclass, + reduce_expr: MaskedMemletExpr | MaskedValueExpr, + ) -> ValueExpr: + values_desc = reduce_expr.node.desc(self.sdfg) + indices_desc = reduce_expr.mask.desc(self.sdfg) + assert indices_desc.shape == values_desc.shape + + nsdfg, field_name, index_name, output_name = build_reduce_sdfg( + op_name, init_value, result_dtype, values_desc, indices_desc ) + reduce_node = self.state.add_nested_sdfg( + nsdfg, self.sdfg, {field_name, index_name}, {output_name} + ) + + if isinstance(reduce_expr, MaskedMemletExpr): + assert isinstance(reduce_expr.subset, sbs.Indices) + ndims = len(reduce_expr.subset) + assert len(values_desc.shape) == ndims + 1 + local_size = values_desc.shape[ndims] + input_subset = sbs.Range.from_indices(reduce_expr.subset) + sbs.Range.from_string( + f"0:{local_size}" + ) + self._add_input_connection( + reduce_expr.node, + input_subset, + reduce_node, + field_name, + ) + self._add_input_connection( + reduce_expr.mask, + input_subset, + reduce_node, + index_name, + ) + else: + self.state.add_edge( + reduce_expr.node, + None, + reduce_node, + field_name, + dace.Memlet.from_array(reduce_expr.node.data, values_desc), + ) + self.state.add_edge( + reduce_expr.mask, + None, + reduce_node, + index_name, + dace.Memlet.from_array(reduce_expr.mask.data, indices_desc), + ) + + return self._get_tasklet_result(result_dtype, reduce_node, output_name) + def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: assert isinstance(node.fun, itir.FunCall) assert len(node.fun.args) == 2 op_name = node.fun.args[0] assert isinstance(op_name, itir.SymRef) - reduce_identity = node.fun.args[1] - assert isinstance(reduce_identity, itir.Literal) - - # TODO: use type inference to determine the result type - dtype = dace.float64 + init_value = node.fun.args[1] + assert isinstance(init_value, itir.Literal) - # we store the value of reduce identity in the visitor context while visiting the input to reduction - # this value will be returned by the neighbors builtin function for skip values - prev_reduce_identity = self.reduce_identity - self.reduce_identity = SymbolExpr(reduce_identity.value, dtype) assert len(node.args) == 1 input_expr = self.visit(node.args[0]) assert isinstance(input_expr, MemletExpr | ValueExpr) - self.reduce_identity = prev_reduce_identity - input_desc = input_expr.node.desc(self.sdfg) - assert isinstance(input_desc, dace.data.Array) - input_subset = sbs.Range([(0, dim_size - 1, 1) for dim_size in input_desc.shape]) + # TODO: use type inference to determine the result type + result_dtype = dace.float64 - if len(input_desc.shape) > 1: - ndims = len(input_desc.shape) - 1 - reduce_axes = [ndims] - result_subset = sbs.Range(input_subset[0:ndims]) - else: - reduce_axes = None - result_subset = None + if isinstance(input_expr, (MaskedMemletExpr, MaskedValueExpr)): + return self._make_reduce_with_skip_values(op_name, init_value, result_dtype, input_expr) - reduce_wcr = "lambda x, y: " + MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") - reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_identity) + # handle below the reduction with full connectivity + input_desc = input_expr.node.desc(self.sdfg) + assert isinstance(input_desc, dace.data.Array) + nsdfg, field_name, _, output_name = build_reduce_sdfg( + op_name, init_value, result_dtype, input_desc + ) - if isinstance(input_expr, MemletExpr): - self._add_input_connection(input_expr.node, input_subset, reduce_node) - else: - self.state.add_nedge( + reduce_node = self.state.add_nested_sdfg(nsdfg, self.sdfg, {field_name}, {output_name}) + + if isinstance(input_expr, ValueExpr): + assert len(input_desc.shape) == 1 + self.state.add_edge( input_expr.node, + None, reduce_node, - dace.Memlet(data=input_expr.node.data, subset=input_subset), + field_name, + dace.Memlet.from_array(input_expr.node.data, input_desc), + ) + else: + ndims = len(input_expr.subset) + assert len(input_desc.shape) == ndims + 1 + local_size = input_desc.shape[ndims] + reduce_values_subset = sbs.Range.from_indices( + input_expr.subset + ) + sbs.Range.from_string(f"0:{local_size}") + self._add_input_connection( + input_expr.node, reduce_values_subset, reduce_node, field_name ) - return self._get_tasklet_result(dtype, reduce_node, None, result_subset) + return self._get_tasklet_result(result_dtype, reduce_node, output_name) def _split_shift_args( self, args: list[itir.Expr] @@ -527,6 +745,7 @@ def _make_cartesian_shift( # a new iterator with a shifted index along one dimension return IteratorExpr( it.field, + it.mask, it.dimensions, { dim: (new_index if dim == offset_dim.value else index) @@ -611,7 +830,7 @@ def _make_unstructured_shift( shifted_indices = it.indices | {neighbor_dim: dynamic_offset_value} - return IteratorExpr(it.field, it.dimensions, shifted_indices) + return IteratorExpr(it.field, it.mask, it.dimensions, shifted_indices) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun @@ -626,6 +845,8 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: # the iterator to be shifted is the argument to the function node it = self.visit(node.args[0]) assert isinstance(it, IteratorExpr) + # skip values (implemented as an array mask) not supported with shift operator + assert it.mask is None # first argument of the shift node is the offset provider assert isinstance(head[0], itir.OffsetLiteral) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 0e06944d58..6f6845c701 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -985,7 +985,7 @@ def test_gtir_reduce_with_skip_values(): for v2e_neighbors in connectivity_V2E.table ] - for stencil in [stencil_inlined]: + for stencil in [stencil_inlined, stencil_fieldview]: testee = itir.Program( id=f"neighbors_sum", function_definitions=[], From b1f4a478d54ae8a14785390081bacb0852534ca8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 27 May 2024 18:09:55 +0200 Subject: [PATCH 080/235] Merge 2 different implementations of reduce --- .../runners/dace_fieldview/gtir_to_tasklet.py | 324 +++++++++--------- 1 file changed, 157 insertions(+), 167 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 13c46d1729..70c7303b96 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -206,106 +206,101 @@ def build_neighbors_sdfg( def build_reduce_sdfg( - op_name: itir.SymRef, + code: str, + params: list[str], + acc_param: str, init_value: itir.Literal, result_dtype: dace.typeclass, values_desc: dace.data.Array, indices_desc: Optional[dace.data.Array] = None, -) -> tuple[dace.SDFG, str, Optional[str], str]: +) -> tuple[dace.SDFG, str, list[str], Optional[str]]: sdfg = dace.SDFG("reduce") neighbors_len = values_desc.shape[-1] - input_var, input_desc = sdfg.add_array("values", (neighbors_len,), values_desc.dtype) - acc_var, _ = sdfg.add_scalar("var", result_dtype) - + acc_var = "__acc" + assert acc_param != acc_var + input_vars = [f"__{p}" for p in params] if indices_desc: assert values_desc.shape == indices_desc.shape - indices_var, _ = sdfg.add_array("indices", (neighbors_len,), indices_desc.dtype) - - neighbor_idx = "__idx" - reduce_loop = dace.sdfg.state.LoopRegion( - label="reduce", - loop_var=neighbor_idx, - initialize_expr=f"{neighbor_idx} = 0", - condition_expr=f"{neighbor_idx} < {neighbors_len}", - update_expr=f"{neighbor_idx} = {neighbor_idx} + 1", - inverted=False, - ) - sdfg.add_node(reduce_loop) - reduce_state = reduce_loop.add_state("loop") - acc_code = MATH_BUILTINS_MAPPING[str(op_name)].format("acc", "val") - reduce_tasklet = reduce_state.add_tasklet( - "reduce_with_skip_values", - {"acc", "val", "idx"}, - {"res"}, - f"res = {acc_code} if idx != {neighbor_skip_value} else acc", - ) - reduce_state.add_edge( - reduce_state.add_access(acc_var), - None, - reduce_tasklet, - "acc", - dace.Memlet(data=acc_var, subset="0"), - ) + mask_var, _ = sdfg.add_array("__mask", (neighbors_len,), indices_desc.dtype) + tasklet_params = {acc_param, *params, "mask"} + tasklet_code = f"res = {code} if mask != {neighbor_skip_value} else {acc_param}" + else: + mask_var = None + tasklet_params = {acc_param, *params} + tasklet_code = f"res = {code}" + + neighbor_idx = "__idx" + reduce_loop = dace.sdfg.state.LoopRegion( + label="reduce", + loop_var=neighbor_idx, + initialize_expr=f"{neighbor_idx} = 0", + condition_expr=f"{neighbor_idx} < {neighbors_len}", + update_expr=f"{neighbor_idx} = {neighbor_idx} + 1", + inverted=False, + ) + sdfg.add_node(reduce_loop) + reduce_state = reduce_loop.add_state("loop") + + reduce_tasklet = reduce_state.add_tasklet( + "reduce", + tasklet_params, + {"res"}, + tasklet_code, + ) + + sdfg.add_scalar(acc_var, result_dtype) + reduce_state.add_edge( + reduce_state.add_access(acc_var), + None, + reduce_tasklet, + acc_param, + dace.Memlet(data=acc_var, subset="0"), + ) + + for inner_var, input_var in zip(params, input_vars): + sdfg.add_array(input_var, (neighbors_len,), values_desc.dtype) reduce_state.add_edge( reduce_state.add_access(input_var), None, reduce_tasklet, - "val", + inner_var, dace.Memlet(data=input_var, subset=neighbor_idx), ) + if indices_desc: reduce_state.add_edge( - reduce_state.add_access(indices_var), + reduce_state.add_access(mask_var), None, reduce_tasklet, - "idx", - dace.Memlet(data=indices_var, subset=neighbor_idx), - ) - reduce_state.add_edge( - reduce_tasklet, - "res", - reduce_state.add_access(acc_var), - None, - dace.Memlet(data=acc_var, subset="0"), - ) - - init_state = sdfg.add_state("init", is_start_block=True) - init_tasklet = init_state.add_tasklet( - "init_reduce", - {}, - {"val"}, - f"val = {init_value}", - ) - init_state.add_edge( - init_tasklet, - "val", - init_state.add_access(acc_var), - None, - dace.Memlet(data=acc_var, subset="0"), - ) - sdfg.add_edge(init_state, reduce_loop, dace.InterstateEdge()) - - else: - state = sdfg.add_state("main") - - reduce_wcr = "lambda x, y: " + MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") - reduce_node = state.add_reduce(reduce_wcr, None, init_value) - - state.add_nedge( - state.add_access(input_var), - reduce_node, - dace.Memlet.from_array(input_var, input_desc), - ) - state.add_nedge( - reduce_node, - state.add_access(acc_var), - dace.Memlet(data=acc_var, subset="0"), + "mask", + dace.Memlet(data=mask_var, subset=neighbor_idx), ) + reduce_state.add_edge( + reduce_tasklet, + "res", + reduce_state.add_access(acc_var), + None, + dace.Memlet(data=acc_var, subset="0"), + ) - indices_var = None + init_state = sdfg.add_state("init", is_start_block=True) + init_tasklet = init_state.add_tasklet( + "init_reduce", + {}, + {"val"}, + f"val = {init_value}", + ) + init_state.add_edge( + init_tasklet, + "val", + init_state.add_access(acc_var), + None, + dace.Memlet(data=acc_var, subset="0"), + ) + sdfg.add_edge(init_state, reduce_loop, dace.InterstateEdge()) - return sdfg, input_var, indices_var, acc_var + return sdfg, acc_var, input_vars, mask_var class LambdaToTasklet(eve.NodeVisitor): @@ -555,111 +550,106 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: subset=sbs.Range([(0, offset_provider.max_neighbors - 1, 1)]), ) - def _make_reduce_with_skip_values( - self, - op_name: itir.SymRef, - init_value: itir.Literal, - result_dtype: dace.typeclass, - reduce_expr: MaskedMemletExpr | MaskedValueExpr, - ) -> ValueExpr: - values_desc = reduce_expr.node.desc(self.sdfg) - indices_desc = reduce_expr.mask.desc(self.sdfg) - assert indices_desc.shape == values_desc.shape + def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: + # TODO: use type inference to determine the result type + result_dtype = dace.float64 - nsdfg, field_name, index_name, output_name = build_reduce_sdfg( - op_name, init_value, result_dtype, values_desc, indices_desc + assert isinstance(node.fun, itir.FunCall) + assert len(node.fun.args) == 2 + reduce_acc_init = node.fun.args[1] + assert isinstance(reduce_acc_init, itir.Literal) + + if isinstance(node.fun.args[0], itir.SymRef): + assert len(node.args) == 1 + op_name = str(node.fun.args[0].id) + assert op_name in MATH_BUILTINS_MAPPING + reduce_acc_param = "acc" + reduce_params = ["val"] + reduce_code = MATH_BUILTINS_MAPPING[op_name].format("acc", "val") + else: + assert isinstance(node.fun.args[0], itir.Lambda) + assert len(node.args) >= 1 + # the +1 is for the accumulator value + assert len(node.fun.args[0].params) == len(node.args) + 1 + reduce_acc_param = str(node.fun.args[0].params[0].id) + reduce_params = [str(p.id) for p in node.fun.args[0].params[1:]] + reduce_code = PythonCodegen().visit(node.fun.args[0].expr) + + node_args: list[MemletExpr | ValueExpr] = [self.visit(arg) for arg in node.args] + reduce_args: list[tuple[str, MemletExpr | ValueExpr]] = list( + zip(reduce_params, node_args, strict=True) ) - reduce_node = self.state.add_nested_sdfg( - nsdfg, self.sdfg, {field_name, index_name}, {output_name} + _, first_expr = reduce_args[0] + values_desc = first_expr.node.desc(self.sdfg) + if isinstance(first_expr, (MaskedMemletExpr, MaskedValueExpr)): + indices_desc = first_expr.mask.desc(self.sdfg) + assert indices_desc.shape == values_desc.shape + else: + indices_desc = None + + nsdfg, sdfg_output, sdfg_inputs, mask_input = build_reduce_sdfg( + reduce_code, + reduce_params, + reduce_acc_param, + reduce_acc_init, + result_dtype, + values_desc, + indices_desc, ) - if isinstance(reduce_expr, MaskedMemletExpr): - assert isinstance(reduce_expr.subset, sbs.Indices) - ndims = len(reduce_expr.subset) - assert len(values_desc.shape) == ndims + 1 - local_size = values_desc.shape[ndims] - input_subset = sbs.Range.from_indices(reduce_expr.subset) + sbs.Range.from_string( - f"0:{local_size}" + if isinstance(first_expr, (MaskedMemletExpr, MaskedValueExpr)): + assert mask_input is not None + reduce_node = self.state.add_nested_sdfg( + nsdfg, self.sdfg, {*sdfg_inputs, mask_input}, {sdfg_output} ) - self._add_input_connection( - reduce_expr.node, - input_subset, - reduce_node, - field_name, + else: + assert mask_input is None + reduce_node = self.state.add_nested_sdfg( + nsdfg, self.sdfg, {*sdfg_inputs}, {sdfg_output} ) + + for sdfg_connector, (_, reduce_expr) in zip(sdfg_inputs, reduce_args, strict=True): + if isinstance(reduce_expr, MemletExpr): + assert isinstance(reduce_expr.subset, sbs.Indices) + ndims = len(reduce_expr.subset) + assert len(values_desc.shape) == ndims + 1 + local_size = values_desc.shape[ndims] + input_subset = sbs.Range.from_indices(reduce_expr.subset) + sbs.Range.from_string( + f"0:{local_size}" + ) + self._add_input_connection( + reduce_expr.node, + input_subset, + reduce_node, + sdfg_connector, + ) + else: + self.state.add_edge( + reduce_expr.node, + None, + reduce_node, + sdfg_connector, + dace.Memlet.from_array(reduce_expr.node.data, values_desc), + ) + + if isinstance(first_expr, MaskedMemletExpr): self._add_input_connection( - reduce_expr.mask, + first_expr.mask, input_subset, reduce_node, - index_name, - ) - else: - self.state.add_edge( - reduce_expr.node, - None, - reduce_node, - field_name, - dace.Memlet.from_array(reduce_expr.node.data, values_desc), - ) - self.state.add_edge( - reduce_expr.mask, - None, - reduce_node, - index_name, - dace.Memlet.from_array(reduce_expr.mask.data, indices_desc), + mask_input, ) - - return self._get_tasklet_result(result_dtype, reduce_node, output_name) - - def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: - assert isinstance(node.fun, itir.FunCall) - assert len(node.fun.args) == 2 - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - init_value = node.fun.args[1] - assert isinstance(init_value, itir.Literal) - - assert len(node.args) == 1 - input_expr = self.visit(node.args[0]) - assert isinstance(input_expr, MemletExpr | ValueExpr) - - # TODO: use type inference to determine the result type - result_dtype = dace.float64 - - if isinstance(input_expr, (MaskedMemletExpr, MaskedValueExpr)): - return self._make_reduce_with_skip_values(op_name, init_value, result_dtype, input_expr) - - # handle below the reduction with full connectivity - input_desc = input_expr.node.desc(self.sdfg) - assert isinstance(input_desc, dace.data.Array) - nsdfg, field_name, _, output_name = build_reduce_sdfg( - op_name, init_value, result_dtype, input_desc - ) - - reduce_node = self.state.add_nested_sdfg(nsdfg, self.sdfg, {field_name}, {output_name}) - - if isinstance(input_expr, ValueExpr): - assert len(input_desc.shape) == 1 + elif isinstance(first_expr, MaskedValueExpr): self.state.add_edge( - input_expr.node, + first_expr.mask, None, reduce_node, - field_name, - dace.Memlet.from_array(input_expr.node.data, input_desc), - ) - else: - ndims = len(input_expr.subset) - assert len(input_desc.shape) == ndims + 1 - local_size = input_desc.shape[ndims] - reduce_values_subset = sbs.Range.from_indices( - input_expr.subset - ) + sbs.Range.from_string(f"0:{local_size}") - self._add_input_connection( - input_expr.node, reduce_values_subset, reduce_node, field_name + mask_input, + dace.Memlet.from_array(first_expr.mask.data, indices_desc), ) - return self._get_tasklet_result(result_dtype, reduce_node, output_name) + return self._get_tasklet_result(result_dtype, reduce_node, sdfg_output) def _split_shift_args( self, args: list[itir.Expr] From 63e6e92ab3d5a295e5e8a202403961cf4a1498b8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 28 May 2024 13:50:41 +0200 Subject: [PATCH 081/235] Add support for reduce lambda function --- .../gtir_builtin_translators.py | 50 ++++--- .../runners/dace_fieldview/gtir_to_tasklet.py | 99 +++++++------ .../runners/dace_fieldview/utility.py | 35 ++++- .../runners_tests/test_dace_fieldview.py | 138 ++++++++++++++++-- 4 files changed, 237 insertions(+), 85 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index b40806ae61..2882e9bd32 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -93,8 +93,8 @@ class AsFieldOp(PrimitiveTranslator): stencil_expr: itir.Lambda stencil_args: list[SDFGFieldBuilder] - field_domain: dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] - field_type: ts.FieldType + field_domain: list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] + field_dtype: ts.ScalarType offset_provider: dict[str, Connectivity | Dimension] def __init__( @@ -116,16 +116,12 @@ def __init__( # the domain of the field operator is passed as second argument assert isinstance(domain_expr, itir.FunCall) - domain = dace_fieldview_util.get_domain(domain_expr) - # define field domain with all dimensions in alphabetical order - sorted_domain_dims = [Dimension(dim) for dim in sorted(domain.keys())] - # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - self.field_domain = domain - self.field_type = ts.FieldType(sorted_domain_dims, node_type) + self.field_domain = dace_fieldview_util.get_field_domain(domain_expr) + self.field_dtype = node_type self.stencil_expr = stencil_expr self.stencil_args = stencil_args @@ -146,17 +142,17 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) - indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { + indices: dict[Dimension, gtir_to_tasklet.IteratorIndexExpr] = { dim: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim)), + dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim.value)), index_dtype, ) - for dim in self.field_domain.keys() + for dim, _, _ in self.field_domain } iterator_arg = gtir_to_tasklet.IteratorExpr( data_node, mask_node, - [dim.value for dim in arg_type.dims], + arg_type.dims, indices, ) stencil_args.append(iterator_arg) @@ -175,31 +171,39 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field - field_dims = self.field_type.dims.copy() + field_dims = [dim for dim, _, _ in self.field_domain] field_shape = [ # diff between upper and lower bound - self.field_domain[dim.value][1] - self.field_domain[dim.value][0] - for dim in self.field_type.dims + (ub - lb) + for _, lb, ub in self.field_domain ] if isinstance(output_desc, dace.data.Array): # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) - field_dims.extend(Dimension(f"local_dim{i}") for i in range(len(output_desc.shape))) + assert isinstance(output_expr.field_type, ts.FieldType) + # TODO: enable `assert output_expr.field_type.dtype == self.field_dtype`, remove variable `dtype` + dtype = output_expr.field_type.dtype + field_dims.extend(output_expr.field_type.dims) field_shape.extend(output_desc.shape) + else: + assert isinstance(output_expr.field_type, ts.ScalarType) + # TODO: enable `assert output_expr.field_type == self.field_dtype`, remove variable `dtype` + dtype = output_expr.field_type - # TODO: use `self.field_type.field_dtype` without overriding `dtype` when type inference is in place - field_dtype = dace_fieldview_util.as_scalar_type(str(output_desc.dtype.as_numpy_dtype())) - field_type = ts.FieldType(field_dims, field_dtype) + # TODO: use `self.field_dtype` directly, without passing through `dtype` + field_type = ts.FieldType(field_dims, dtype) field_node = self.add_local_storage(field_type, field_shape) # assume tasklet with single output - output_subset = [DIMENSION_INDEX_FMT.format(dim=dim.value) for dim in self.field_type.dims] + output_subset = [ + DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in self.field_domain + ] if isinstance(output_desc, dace.data.Array): output_subset.extend(f"0:{size}" for size in output_desc.shape) # create map range corresponding to the field operator domain map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim): f"{lb}:{ub}" - for dim, (lb, ub) in self.field_domain.items() + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" + for dim, lb, ub in self.field_domain } me, mx = self.head_state.add_map("field_op", map_ranges) @@ -240,7 +244,7 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: else: mask_node = None - return [(field_node, self.field_type)], mask_node + return [(field_node, field_type)], mask_node class Select(PrimitiveTranslator): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 70c7303b96..b23f987ee9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -52,6 +52,7 @@ class ValueExpr: """Result of the computation implemented by a tasklet node.""" node: dace.nodes.AccessNode + field_type: ts.FieldType | ts.ScalarType IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr @@ -63,8 +64,8 @@ class IteratorExpr: field: dace.nodes.AccessNode mask: Optional[dace.nodes.AccessNode] - dimensions: list[str] - indices: dict[str, IteratorIndexExpr] + dimensions: list[Dimension] + indices: dict[Dimension, IteratorIndexExpr] # Define alias for the elements needed to setup input connections to a map scope @@ -343,37 +344,52 @@ def _get_tasklet_result( dtype: dace.typeclass, src_node: dace.nodes.Node, src_connector: Optional[str] = None, - subset: Optional[sbs.Range] = None, + offset: Optional[str] = None, ) -> ValueExpr: - if subset: + data_type: ts.FieldType | ts.ScalarType + if offset: + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, Connectivity) var_name, _ = self.sdfg.add_array( - "var", subset.size(), dtype, transient=True, find_new_name=True + "var", (offset_provider.max_neighbors,), dtype, transient=True, find_new_name=True ) + var_subset = f"0:{offset_provider.max_neighbors}" + data_type = dace_fieldview_util.get_neighbors_field_type(offset, dtype) else: var_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) - subset = "0" + var_subset = "0" + data_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) var_node = self.state.add_access(var_name) self.state.add_edge( src_node, src_connector, var_node, None, - dace.Memlet(data=var_node.data, subset=subset), + dace.Memlet(data=var_node.data, subset=var_subset), ) - return ValueExpr(var_node) + return ValueExpr(var_node, data_type) def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) if isinstance(it, IteratorExpr): + field_desc = it.field.desc(self.sdfg) + assert len(field_desc.shape) == len(it.dimensions) if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # when all indices are symblic expressions, we can perform direct field access through a memlet - data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] + field_subset = sbs.Range( + [ + (it.indices[dim].value, it.indices[dim].value, 1) # type: ignore[union-attr] + if dim in it.indices + else (0, size - 1, 1) + for dim, size in zip(it.dimensions, field_desc.shape) + ] + ) return ( - MemletExpr(it.field, data_index) + MemletExpr(it.field, field_subset) if it.mask is None - else MaskedMemletExpr(it.field, data_index, it.mask) + else MaskedMemletExpr(it.field, field_subset, it.mask) ) else: @@ -381,17 +397,17 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert it.mask is None # we use a tasklet to perform dereferencing of a generic iterator - assert all(dim in it.indices.keys() for dim in it.dimensions) + assert all(dim in it.indices for dim in it.dimensions) field_indices = [(dim, it.indices[dim]) for dim in it.dimensions] index_connectors = [ - INDEX_CONNECTOR_FMT.format(dim=dim) + INDEX_CONNECTOR_FMT.format(dim=dim.value) for dim, index in field_indices if not isinstance(index, SymbolExpr) ] index_internals = ",".join( str(index.value) if isinstance(index, SymbolExpr) - else INDEX_CONNECTOR_FMT.format(dim=dim) + else INDEX_CONNECTOR_FMT.format(dim=dim.value) for dim, index in field_indices ) deref_node = self.state.add_tasklet( @@ -401,12 +417,11 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: code=f"val = field[{index_internals}]", ) # add new termination point for this field parameter - field_desc = it.field.desc(self.sdfg) field_fullset = sbs.Range.from_array(field_desc) self._add_input_connection(it.field, field_fullset, deref_node, "field") for dim, index_expr in field_indices: - deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim) + deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim.value) if isinstance(index_expr, MemletExpr): self._add_input_connection( index_expr.node, @@ -444,11 +459,11 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.neighbor_axis.value in it.dimensions - assert offset_provider.origin_axis.value in it.indices - origin_index = it.indices[offset_provider.origin_axis.value] + assert offset_provider.neighbor_axis in it.dimensions + assert offset_provider.origin_axis in it.indices + origin_index = it.indices[offset_provider.origin_axis] assert isinstance(origin_index, SymbolExpr) - assert offset_provider.origin_axis.value not in it.dimensions + assert offset_provider.origin_axis not in it.dimensions assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) field_desc = it.field.desc(self.sdfg) @@ -463,7 +478,7 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: field_array_shape = tuple( shape for dim, shape in zip(it.dimensions, field_desc.shape, strict=True) - if dim == offset_provider.neighbor_axis.value + if dim == offset_provider.neighbor_axis ) assert len(field_array_shape) == 1 @@ -486,8 +501,8 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: sbs.Range( [ (0, size - 1, 1) - if dim == offset_provider.neighbor_axis.value - else (it.indices[dim].value, it.indices[dim].value, 1) # type: ignore[union-attr] + if dim == offset_provider.neighbor_axis + else sbs.Indices(it.indices[dim].value) # type: ignore[union-attr] for dim, size in zip(it.dimensions, field_desc.shape, strict=True) ] ), @@ -530,7 +545,7 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: output_name, neighbor_val_node, None, - dace.Memlet(data=neighbor_val_name, subset=f"0:{offset_provider.max_neighbors}"), + dace.Memlet.from_array(neighbor_val_name, neighbor_val_array), ) neighbor_idx_node = self.state.add_access(neighbor_idx_name) @@ -540,14 +555,17 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: neighbor_idx_node, ) - return MaskedValueExpr(neighbor_val_node, neighbor_idx_node) + neighbors_field_type = dace_fieldview_util.get_neighbors_field_type( + offset, field_desc.dtype + ) + return MaskedValueExpr(neighbor_val_node, neighbors_field_type, neighbor_idx_node) else: return self._get_tasklet_result( field_desc.dtype, neighbors_node, output_name, - subset=sbs.Range([(0, offset_provider.max_neighbors - 1, 1)]), + offset=offset, ) def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: @@ -611,16 +629,10 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: for sdfg_connector, (_, reduce_expr) in zip(sdfg_inputs, reduce_args, strict=True): if isinstance(reduce_expr, MemletExpr): - assert isinstance(reduce_expr.subset, sbs.Indices) - ndims = len(reduce_expr.subset) - assert len(values_desc.shape) == ndims + 1 - local_size = values_desc.shape[ndims] - input_subset = sbs.Range.from_indices(reduce_expr.subset) + sbs.Range.from_string( - f"0:{local_size}" - ) + assert isinstance(reduce_expr.subset, sbs.Subset) self._add_input_connection( reduce_expr.node, - input_subset, + reduce_expr.subset, reduce_node, sdfg_connector, ) @@ -636,7 +648,7 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: if isinstance(first_expr, MaskedMemletExpr): self._add_input_connection( first_expr.mask, - input_subset, + first_expr.subset, reduce_node, mask_input, ) @@ -674,10 +686,10 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: Dimension, offset_expr: IteratorIndexExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim.value in it.dimensions + assert offset_dim in it.dimensions new_index: SymbolExpr | ValueExpr - assert offset_dim.value in it.indices - index_expr = it.indices[offset_dim.value] + assert offset_dim in it.indices + index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr(index_expr.value + offset_expr.value, index_expr.dtype) @@ -737,10 +749,7 @@ def _make_cartesian_shift( it.field, it.mask, it.dimensions, - { - dim: (new_index if dim == offset_dim.value else index) - for dim, index in it.indices.items() - }, + {dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items()}, ) def _make_dynamic_neighbor_offset( @@ -795,11 +804,11 @@ def _make_unstructured_shift( offset_expr: IteratorIndexExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - neighbor_dim = connectivity.neighbor_axis.value - assert neighbor_dim in it.dimensions + assert connectivity.neighbor_axis in it.dimensions + neighbor_dim = connectivity.neighbor_axis assert neighbor_dim not in it.indices - origin_dim = connectivity.origin_axis.value + origin_dim = connectivity.origin_axis assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index f9919e230e..6d10614a89 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,7 +16,7 @@ import dace -from gt4py.next.common import Connectivity +from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet @@ -67,9 +67,9 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne } -def get_domain( +def get_field_domain( node: itir.Expr, -) -> dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: +) -> list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Specialized visit method for domain expressions. @@ -77,7 +77,7 @@ def get_domain( """ assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - domain = {} + domain = [] for named_range in node.args: assert cpm.is_call_to(named_range, "named_range") assert len(named_range.args) == 3 @@ -88,11 +88,27 @@ def get_domain( sym_str = get_symbolic_expr(arg) sym_val = dace.symbolic.SymExpr(sym_str) bounds.append(sym_val) - domain[axis.value] = (bounds[0], bounds[1]) + size_value = str(bounds[1] - bounds[0]) + if size_value.isdigit(): + dim = Dimension(axis.value, DimensionKind.LOCAL) + else: + dim = Dimension(axis.value, DimensionKind.HORIZONTAL) + domain.append((dim, bounds[0], bounds[1])) return domain +def get_domain( + node: itir.Expr, +) -> dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: + """ + Returns domain represented in dictionary form. + """ + field_domain = get_field_domain(node) + + return {dim.value: (lb, ub) for dim, lb, ub in field_domain} + + def get_symbolic_expr(node: itir.Expr) -> str: """ Specialized visit method for symbolic expressions. @@ -101,3 +117,12 @@ def get_symbolic_expr(node: itir.Expr) -> str: or symbolic array shape. """ return gtir_to_tasklet.PythonCodegen().visit(node) + + +def get_neighbors_field_type(offset: str, dtype: dace.typeclass) -> ts.FieldType: + """Utility function to obtain the descriptor for a local field of neighbors.""" + scalar_type = as_scalar_type(str(dtype.as_numpy_dtype())) + return ts.FieldType( + [Dimension(offset, DimensionKind.LOCAL)], + scalar_type, + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 6f6845c701..2b140929da 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -32,8 +32,6 @@ simple_mesh, skip_value_mesh, ) -from next_tests.integration_tests.cases import EField, IFloatField, VField - from functools import reduce import numpy as np import pytest @@ -873,13 +871,14 @@ def test_gtir_neighbors(): def test_gtir_reduce(): + init_value = np.random.rand() vertex_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), ) stencil_inlined = im.call( im.call("as_fieldop")( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(0)))( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( im.neighbors("V2E", "it") ) ), @@ -889,7 +888,9 @@ def test_gtir_reduce(): stencil_fieldview = im.call( im.call("as_fieldop")( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(0)))(im.deref("it")) + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) ), vertex_domain, ) @@ -908,13 +909,13 @@ def test_gtir_reduce(): e = np.random.rand(SIMPLE_MESH.num_edges) v_ref = [ - reduce(lambda x, y: x + y, e[v2e_neighbors], 0.0) + reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) for v2e_neighbors in connectivity_V2E.table ] - for stencil in [stencil_inlined, stencil_fieldview]: + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): testee = itir.Program( - id=f"neighbors_sum", + id=f"reduce_{i}", function_definitions=[], params=[ itir.Sym(id="edges"), @@ -946,13 +947,14 @@ def test_gtir_reduce(): def test_gtir_reduce_with_skip_values(): + init_value = np.random.rand() vertex_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), ) stencil_inlined = im.call( im.call("as_fieldop")( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(0)))( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( im.neighbors("V2E", "it") ) ), @@ -962,7 +964,9 @@ def test_gtir_reduce_with_skip_values(): stencil_fieldview = im.call( im.call("as_fieldop")( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(0)))(im.deref("it")) + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) ), vertex_domain, ) @@ -981,13 +985,13 @@ def test_gtir_reduce_with_skip_values(): e = np.random.rand(SKIP_VALUE_MESH.num_edges) v_ref = [ - reduce(lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], 0.0) + reduce(lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value) for v2e_neighbors in connectivity_V2E.table ] - for stencil in [stencil_inlined, stencil_fieldview]: + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): testee = itir.Program( - id=f"neighbors_sum", + id=f"reduce_with_skip_values_{i}", function_definitions=[], params=[ itir.Sym(id="edges"), @@ -1016,3 +1020,113 @@ def test_gtir_reduce_with_skip_values(): **make_mesh_symbols(SKIP_VALUE_MESH), ) assert np.allclose(v, v_ref) + + +def test_gtir_reduce_with_lambda(): + init_value = np.random.rand() + edge_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "nedges"), + ) + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + v2e_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + im.call("named_range")( + itir.AxisLiteral(value=V2EDim.value), + 0, + SIMPLE_MESH.offset_provider["V2E"].max_neighbors, + ), + ) + reduce_lambda = im.lambda_("acc", "a", "b")(im.plus(im.multiplies_("a", "b"), "acc")) + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("itA", "itB")( + im.call(im.call("reduce")(reduce_lambda, im.literal_from_value(init_value)))( + im.neighbors("V2E", "itA"), im.neighbors("V2E", "itB") + ) + ), + vertex_domain, + ) + )( + "edges", + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.plus(im.deref("it"), 1)), + edge_domain, + ) + )("edges"), + ) + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("a_it", "b_it")( + im.call(im.call("reduce")(reduce_lambda, im.literal_from_value(init_value)))( + im.deref("a_it"), im.deref("b_it") + ) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.plus(im.deref("it"), 1)), + v2e_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges") + ), + ) + + arg_types = [EFTYPE, VFTYPE, SIZE_TYPE, SIZE_TYPE] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, NeighborTable) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v_ref = [ + reduce(lambda x, y: x + y, e[v2e_neighbors] * (e[v2e_neighbors] + 1), init_value) + for v2e_neighbors in connectivity_V2E.table + ] + + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): + testee = itir.Program( + id=f"reduce_with_lambda_{i}", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nedges"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=vertex_domain, + target=itir.SymRef(id="vertices"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + + # new empty output field + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) From 107e295ad30a6aee1f6bcb14504afa78c6e10fc2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 29 May 2024 08:29:33 +0200 Subject: [PATCH 082/235] Add support for neighbors masked array returned by select statements --- .../gtir_builtin_translators.py | 60 +++++++++- .../runners_tests/test_dace_fieldview.py | 111 +++++++++++++++++- 2 files changed, 163 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 2882e9bd32..982b5aa02d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -247,6 +247,23 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: return [(field_node, field_type)], mask_node +def _add_full_mask(mask_var: str, mask_desc: dace.data.Array, state: dace.SDFGState) -> None: + """Fill a connectivity table with a valid neighbor index, to mimic full connectivity mask.""" + state.add_mapped_tasklet( + "set_full_mask", + map_ranges={f"_i{i}": f"0:{size}" for i, size in enumerate(mask_desc.shape)}, + inputs={}, + code="val = 0", + outputs={ + "val": dace.Memlet( + data=mask_var, + subset=",".join(f"_i{i}" for i in range(len(mask_desc.shape))), + ) + }, + external_edges=True, + ) + + class Select(PrimitiveTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" @@ -319,9 +336,9 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: assert isinstance(true_br_node, dace.nodes.AccessNode) false_br_node, false_br_type = false_br assert isinstance(false_br_node, dace.nodes.AccessNode) - assert true_br_type == false_br_type - array_type = self.sdfg.arrays[true_br_node.data] - access_node = self.add_local_storage(true_br_type, array_type.shape) + desc = true_br_node.desc(self.sdfg) + assert false_br_node.desc(self.sdfg) == desc + access_node = self.add_local_storage(true_br_type, desc.shape) output_nodes.append((access_node, true_br_type)) data_name = access_node.data @@ -343,10 +360,41 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: ), ) - # TODO: add support for masked array values in select statements, if this lowering path is needed - assert not (true_br_mask or false_br_mask) + # Check if any of the true/false branches produces a mask array for neighbors with skip values; + # if only one branch produces the mask array, add a mapped tasklet on the other branch to produce + # a mask array filled with 0s to mimic all valid neighbor values. + if true_br_mask: + mask_desc = true_br_mask.desc(self.sdfg) + mask_var, _ = self.sdfg.add_temp_transient_like(mask_desc) + if false_br_mask: + assert mask_desc == false_br_mask.desc(self.sdfg) + false_state.add_nedge( + false_br_mask, + false_state.add_access(mask_var), + dace.Memlet.from_array(false_br_mask.data, mask_desc), + ) + else: + _add_full_mask(mask_var, mask_desc, false_state) + true_state.add_nedge( + true_br_mask, + true_state.add_access(mask_var), + dace.Memlet.from_array(true_br_mask.data, mask_desc), + ) + return output_nodes, self.head_state.add_access(mask_var) - return output_nodes, None + elif false_br_mask: + mask_desc = false_br_mask.desc(self.sdfg) + mask_var, _ = self.sdfg.add_temp_transient_like(mask_desc) + _add_full_mask(mask_var, mask_desc, true_state) + false_state.add_nedge( + false_br_mask, + false_state.add_access(mask_var), + dace.Memlet.from_array(false_br_mask.data, mask_desc), + ) + return output_nodes, self.head_state.add_access(mask_var) + + else: + return output_nodes, None class SymbolRef(PrimitiveTranslator): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 2b140929da..db3151b6ac 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -17,6 +17,7 @@ Note: this test module covers the fieldview flavour of ITIR. """ +import copy from gt4py.next.common import NeighborTable from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -128,7 +129,7 @@ def test_gtir_update(): im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( - id="gtir_copy", + id="gtir_update", function_definitions=[], params=[itir.Sym(id="x"), itir.Sym(id="size")], declarations=[], @@ -195,7 +196,7 @@ def test_gtir_sum2_sym(): im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( - id="sum_2fields", + id="sum_2fields_sym", function_definitions=[], params=[itir.Sym(id="x"), itir.Sym(id="z"), itir.Sym(id="size")], declarations=[], @@ -1130,3 +1131,109 @@ def test_gtir_reduce_with_lambda(): **make_mesh_symbols(SIMPLE_MESH), ) assert np.allclose(v, v_ref) + + +def test_gtir_reduce_with_select_neighbors(): + init_value = np.random.rand() + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + testee = itir.Program( + id=f"reduce_with_select_neighbors", + function_definitions=[], + params=[ + itir.Sym(id="cond"), + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.call( + im.call("select")( + im.deref("cond"), + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E_FULL", "it")), + vertex_domain, + ) + )("edges"), + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + ), + )() + ), + domain=vertex_domain, + target=itir.SymRef(id="vertices"), + ) + ], + ) + + connectivity_V2E_simple = SIMPLE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E_simple, NeighborTable) + connectivity_V2E_skip_values = copy.deepcopy(SKIP_VALUE_MESH.offset_provider["V2E"]) + assert isinstance(connectivity_V2E_skip_values, NeighborTable) + assert SKIP_VALUE_MESH.num_vertices <= SIMPLE_MESH.num_vertices + connectivity_V2E_skip_values.table = np.concatenate( + ( + connectivity_V2E_skip_values.table[:, 0 : connectivity_V2E_simple.max_neighbors], + connectivity_V2E_simple.table[SKIP_VALUE_MESH.num_vertices :, :], + ), + axis=0, + ) + connectivity_V2E_skip_values.max_neighbors = connectivity_V2E_simple.max_neighbors + + e = np.random.rand(SIMPLE_MESH.num_edges) + + arg_types = [ + ts.ScalarType(ts.ScalarKind.BOOL), + EFTYPE, + VFTYPE, + SIZE_TYPE, + ] + + for use_full in [False, True]: + sdfg = dace_backend.build_sdfg_from_gtir( + testee, + arg_types, + SIMPLE_MESH.offset_provider | {"V2E_FULL": connectivity_V2E_skip_values}, + ) + + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v_ref = [ + reduce( + lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + ) + for v2e_neighbors in ( + connectivity_V2E_simple.table if use_full else connectivity_V2E_skip_values.table + ) + ] + sdfg( + cond=np.bool_(use_full), + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E_skip_values.table, + connectivity_V2E_FULL=connectivity_V2E_simple.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + __connectivity_V2E_FULL_size_0=SIMPLE_MESH.num_edges, + __connectivity_V2E_FULL_size_1=connectivity_V2E_skip_values.max_neighbors, + __connectivity_V2E_FULL_stride_0=connectivity_V2E_skip_values.max_neighbors, + __connectivity_V2E_FULL_stride_1=1, + ) + assert np.allclose(v, v_ref) From 3c71efa50c84e308178d504e5bdc053a5b511f65 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 29 May 2024 08:54:33 +0200 Subject: [PATCH 083/235] Import changes from neighbors branch --- .../gtir_builtin_translators.py | 68 +++++++++--------- .../runners/dace_fieldview/gtir_to_sdfg.py | 3 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 69 ++++++++++--------- .../runners/dace_fieldview/utility.py | 27 ++++++-- .../runners_tests/test_dace_fieldview.py | 4 +- 5 files changed, 96 insertions(+), 75 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 657b2b9c43..da23d3e864 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -35,6 +35,8 @@ SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] +DIMENSION_INDEX_FMT = "i_{dim}" + @dataclass(frozen=True) class PrimitiveTranslator(abc.ABC): @@ -91,8 +93,8 @@ class AsFieldOp(PrimitiveTranslator): stencil_expr: itir.Lambda stencil_args: list[SDFGFieldBuilder] - field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] - field_type: ts.FieldType + field_domain: list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] + field_dtype: ts.ScalarType offset_provider: dict[str, Connectivity | Dimension] def __init__( @@ -114,21 +116,16 @@ def __init__( # the domain of the field operator is passed as second argument assert isinstance(domain_expr, itir.FunCall) - domain = dace_fieldview_util.get_domain(domain_expr) - # define field domain with all dimensions in alphabetical order - sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) - # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - self.field_domain = domain - self.field_type = ts.FieldType(sorted_domain_dims, node_type) + self.field_domain = dace_fieldview_util.get_field_domain(domain_expr) + self.field_dtype = node_type self.stencil_expr = stencil_expr self.stencil_args = stencil_args def build(self) -> list[SDFGField]: - dimension_index_fmt = "i_{dim}" # type of variables used for field indexing index_dtype = dace.int32 # first visit the list of arguments and build a symbol map @@ -145,16 +142,16 @@ def build(self) -> list[SDFGField]: stencil_args.append(scalar_arg) else: assert isinstance(arg_type, ts.FieldType) - indices: dict[str, gtir_to_tasklet.IteratorIndexExpr] = { - dim.value: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), + indices: dict[Dimension, gtir_to_tasklet.IteratorIndexExpr] = { + dim: gtir_to_tasklet.SymbolExpr( + dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim.value)), index_dtype, ) - for dim in self.field_domain.keys() + for dim, _, _ in self.field_domain } iterator_arg = gtir_to_tasklet.IteratorExpr( data_node, - [dim.value for dim in arg_type.dims], + arg_type.dims, indices, ) stencil_args.append(iterator_arg) @@ -170,31 +167,36 @@ def build(self) -> list[SDFGField]: # the last transient node can be deleted # TODO: not needed to store the node `dtype` after type inference is in place - dtype = output_expr.node.desc(self.sdfg).dtype + output_desc = output_expr.node.desc(self.sdfg) self.head_state.remove_node(output_expr.node) # allocate local temporary storage for the result field + field_dims = [dim for dim, _, _ in self.field_domain] field_shape = [ # diff between upper and lower bound - self.field_domain[dim][1] - self.field_domain[dim][0] - for dim in self.field_type.dims + (ub - lb) + for _, lb, ub in self.field_domain ] - # TODO: use `self.field_type` without overriding `dtype` when type inference is in place - field_dtype = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) - field_node = self.add_local_storage( - ts.FieldType(self.field_type.dims, field_dtype), field_shape - ) + if isinstance(output_desc, dace.data.Array): + raise NotImplementedError + else: + assert isinstance(output_expr.field_type, ts.ScalarType) + # TODO: enable `assert output_expr.field_type == self.field_dtype`, remove variable `dtype` + dtype = output_expr.field_type + + # TODO: use `self.field_dtype` directly, without passing through `dtype` + field_type = ts.FieldType(field_dims, dtype) + field_node = self.add_local_storage(field_type, field_shape) # assume tasklet with single output - output_index = ",".join( - dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims - ) - output_memlet = dace.Memlet(data=field_node.data, subset=output_index) + output_subset = [ + DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in self.field_domain + ] # create map range corresponding to the field operator domain map_ranges = { - dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" - for dim, (lb, ub) in self.field_domain.items() + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" + for dim, lb, ub in self.field_domain } me, mx = self.head_state.add_map("field_op", map_ranges) @@ -216,10 +218,10 @@ def build(self) -> list[SDFGField]: mx, field_node, src_conn=output_tasklet_connector, - memlet=output_memlet, + memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), ) - return [(field_node, self.field_type)] + return [(field_node, field_type)] class Select(PrimitiveTranslator): @@ -294,9 +296,9 @@ def build(self) -> list[SDFGField]: assert isinstance(true_br_node, dace.nodes.AccessNode) false_br_node, false_br_type = false_br assert isinstance(false_br_node, dace.nodes.AccessNode) - assert true_br_type == false_br_type - array_type = self.sdfg.arrays[true_br_node.data] - access_node = self.add_local_storage(true_br_type, array_type.shape) + desc = true_br_node.desc(self.sdfg) + assert false_br_node.desc(self.sdfg) == desc + access_node = self.add_local_storage(true_br_type, desc.shape) output_nodes.append((access_node, true_br_type)) data_name = access_node.data diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 43c4fce1a3..07adb89d6a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -214,7 +214,8 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) if isinstance(target_symbol_type, ts.FieldType): subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_symbol_type.dims + f"{domain[dim.value][0]}:{domain[dim.value][1]}" + for dim in target_symbol_type.dims ) else: assert len(domain) == 0 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 6758c8b497..9ca904983d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -34,7 +34,7 @@ class MemletExpr: """Scalar or array data access thorugh a memlet.""" - source: dace.nodes.AccessNode + node: dace.nodes.AccessNode subset: sbs.Indices | sbs.Range @@ -51,6 +51,7 @@ class ValueExpr: """Result of the computation implemented by a tasklet node.""" node: dace.nodes.AccessNode + field_type: ts.FieldType | ts.ScalarType IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr @@ -61,8 +62,8 @@ class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode - dimensions: list[str] - indices: dict[str, IteratorIndexExpr] + dimensions: list[Dimension] + indices: dict[Dimension, IteratorIndexExpr] # Define alias for the elements needed to setup input connections to a map scope @@ -74,6 +75,9 @@ class IteratorExpr: ] +INDEX_CONNECTOR_FMT = "__index_{dim}" + + MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -165,28 +169,35 @@ def _add_input_connection( self.input_connections.append((src, subset, dst, dst_connector)) def _get_tasklet_result( - self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str + self, + dtype: dace.typeclass, + src_node: dace.nodes.Tasklet, + src_connector: str, ) -> ValueExpr: - scalar_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) - scalar_node = self.state.add_access(scalar_name) + var_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) + var_subset = "0" + data_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) + var_node = self.state.add_access(var_name) self.state.add_edge( src_node, src_connector, - scalar_node, + var_node, None, - dace.Memlet(data=scalar_node.data, subset="0"), + dace.Memlet(data=var_node.data, subset=var_subset), ) - return ValueExpr(scalar_node) + return ValueExpr(var_node, data_type) def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) if isinstance(it, IteratorExpr): + field_desc = it.field.desc(self.sdfg) + assert len(field_desc.shape) == len(it.dimensions) if all(isinstance(index, SymbolExpr) for index in it.indices.values()): - # use direct field access through memlet subset - data_index = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] - return MemletExpr(it.field, data_index) + # when all indices are symblic expressions, we can perform direct field access through a memlet + field_subset = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] + return MemletExpr(it.field, field_subset) else: raise NotImplementedError @@ -195,19 +206,20 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert isinstance(it, MemletExpr) return it - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: - raise NotImplementedError - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) - elif cpm.is_call_to(node.fun, "shift"): - return self._visit_shift(node) - else: assert isinstance(node.fun, itir.SymRef) + # create a tasklet node implementing the builtin function + builtin_name = str(node.fun.id) + if builtin_name in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin_name] + else: + raise NotImplementedError(f"'{builtin_name}' not implemented.") + node_internals = [] node_connections: dict[str, MemletExpr | ValueExpr] = {} for i, arg in enumerate(node.args): @@ -222,13 +234,8 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value # use the argument value without adding any connector node_internals.append(arg_expr.value) - # create a tasklet node implementing the builtin function - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[builtin_name] - code = fmt.format(*node_internals) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") + # use tasklet connectors as expression arguments + code = fmt.format(*node_internals) out_connector = "result" tasklet_node = self.state.add_tasklet( @@ -248,13 +255,11 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: - self._add_input_connection( - arg_expr.source, arg_expr.subset, tasklet_node, connector - ) + self._add_input_connection(arg_expr.node, arg_expr.subset, tasklet_node, connector) # TODO: use type inference to determine the result type if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): - dtype = node_connections["__inp_0"].source.desc(self.sdfg).dtype + dtype = node_connections["__inp_0"].node.desc(self.sdfg).dtype else: node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) dtype = dace_fieldview_util.as_dace_type(node_type) @@ -272,11 +277,9 @@ def visit_Lambda( if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.source.desc(self.sdfg).dtype + output_dtype = output_expr.node.desc(self.sdfg).dtype tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_connection( - output_expr.source, output_expr.subset, tasklet_node, "__inp" - ) + self._add_input_connection(output_expr.node, output_expr.subset, tasklet_node, "__inp") else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 366b45e199..c9111c70e9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,7 +16,7 @@ import dace -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet @@ -63,9 +63,9 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne } -def get_domain( +def get_field_domain( node: itir.Expr, -) -> dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: +) -> list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Specialized visit method for domain expressions. @@ -73,23 +73,38 @@ def get_domain( """ assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - domain = {} + domain = [] for named_range in node.args: assert cpm.is_call_to(named_range, "named_range") assert len(named_range.args) == 3 axis = named_range.args[0] assert isinstance(axis, itir.AxisLiteral) - dim = Dimension(axis.value) bounds = [] for arg in named_range.args[1:3]: sym_str = get_symbolic_expr(arg) sym_val = dace.symbolic.SymExpr(sym_str) bounds.append(sym_val) - domain[dim] = (bounds[0], bounds[1]) + size_value = str(bounds[1] - bounds[0]) + if size_value.isdigit(): + dim = Dimension(axis.value, DimensionKind.LOCAL) + else: + dim = Dimension(axis.value, DimensionKind.HORIZONTAL) + domain.append((dim, bounds[0], bounds[1])) return domain +def get_domain( + node: itir.Expr, +) -> dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: + """ + Returns domain represented in dictionary form. + """ + field_domain = get_field_domain(node) + + return {dim.value: (lb, ub) for dim, lb, ub in field_domain} + + def get_symbolic_expr(node: itir.Expr) -> str: """ Specialized visit method for symbolic expressions. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 3344dc5bf5..77ce24bbe2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -83,7 +83,7 @@ def test_gtir_update(): im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( - id="gtir_copy", + id="gtir_update", function_definitions=[], params=[itir.Sym(id="x"), itir.Sym(id="size")], declarations=[], @@ -150,7 +150,7 @@ def test_gtir_sum2_sym(): im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) testee = itir.Program( - id="sum_2fields", + id="sum_2fields_sym", function_definitions=[], params=[itir.Sym(id="x"), itir.Sym(id="z"), itir.Sym(id="size")], declarations=[], From d0bd277a2d1b1a0d65b42fc3f7d6ab6b0f0f14ba Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 29 May 2024 09:06:52 +0200 Subject: [PATCH 084/235] Import changes from neighbors branch --- .../ir_utils/common_pattern_matcher.py | 10 +++ .../next/iterator/transforms/unroll_reduce.py | 8 --- .../runners/dace_fieldview/gtir_to_tasklet.py | 37 +++++----- .../runners_tests/test_dace_fieldview.py | 67 ++++++++++--------- 4 files changed, 62 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 4933307c53..e3dac7a578 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -27,6 +27,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `shift(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "shift" + ) + + def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: """Match expression of the form `(λ(...) → ...)(...)`.""" return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index b058fc0a7b..47b8556c4e 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -23,14 +23,6 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift -def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: - return ( - isinstance(arg, itir.FunCall) - and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="shift") - ) - - def _is_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunCall]: return isinstance(arg, itir.FunCall) and arg.fun == itir.SymRef(id="neighbors") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 6a6c94c841..9a08da7d25 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -204,17 +204,17 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: else: # we use a tasklet to perform dereferencing of a generic iterator - assert all(dim in it.indices.keys() for dim in it.dimensions) + assert all(dim in it.indices for dim in it.dimensions) field_indices = [(dim, it.indices[dim]) for dim in it.dimensions] index_connectors = [ - INDEX_CONNECTOR_FMT.format(dim=dim) + INDEX_CONNECTOR_FMT.format(dim=dim.value) for dim, index in field_indices if not isinstance(index, SymbolExpr) ] index_internals = ",".join( str(index.value) if isinstance(index, SymbolExpr) - else INDEX_CONNECTOR_FMT.format(dim=dim) + else INDEX_CONNECTOR_FMT.format(dim=dim.value) for dim, index in field_indices ) deref_node = self.state.add_tasklet( @@ -224,15 +224,14 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: code=f"val = field[{index_internals}]", ) # add new termination point for this field parameter - field_desc = it.field.desc(self.sdfg) field_fullset = sbs.Range.from_array(field_desc) self._add_input_connection(it.field, field_fullset, deref_node, "field") for dim, index_expr in field_indices: - deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim) + deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim.value) if isinstance(index_expr, MemletExpr): self._add_input_connection( - index_expr.source, + index_expr.node, index_expr.subset, deref_node, deref_connector, @@ -279,10 +278,10 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: Dimension, offset_expr: IteratorIndexExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim.value in it.dimensions + assert offset_dim in it.dimensions new_index: SymbolExpr | ValueExpr - assert offset_dim.value in it.indices - index_expr = it.indices[offset_dim.value] + assert offset_dim in it.indices + index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr(index_expr.value + offset_expr.value, index_expr.dtype) @@ -313,9 +312,9 @@ def _make_cartesian_shift( for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: if isinstance(input_expr, MemletExpr): if input_connector == "index": - dtype = input_expr.source.desc(self.sdfg).dtype + dtype = input_expr.node.desc(self.sdfg).dtype self._add_input_connection( - input_expr.source, + input_expr.node, input_expr.subset, dynamic_offset_tasklet, input_connector, @@ -341,10 +340,7 @@ def _make_cartesian_shift( return IteratorExpr( it.field, it.dimensions, - { - dim: (new_index if dim == offset_dim.value else index) - for dim, index in it.indices.items() - }, + {dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items()}, ) def _make_dynamic_neighbor_offset( @@ -374,7 +370,7 @@ def _make_dynamic_neighbor_offset( ) if isinstance(offset_expr, MemletExpr): self._add_input_connection( - offset_expr.source, + offset_expr.node, offset_expr.subset, tasklet_node, "offset", @@ -399,11 +395,11 @@ def _make_unstructured_shift( offset_expr: IteratorIndexExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - neighbor_dim = connectivity.neighbor_axis.value - assert neighbor_dim in it.dimensions + assert connectivity.neighbor_axis in it.dimensions + neighbor_dim = connectivity.neighbor_axis assert neighbor_dim not in it.indices - origin_dim = connectivity.origin_axis.value + origin_dim = connectivity.origin_axis assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -472,6 +468,9 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_applied_shift(node): + return self._visit_shift(node) + else: assert isinstance(node.fun, itir.SymRef) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index bd2d9edb05..d9c1022242 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -30,8 +30,6 @@ Vertex, simple_mesh, ) -from next_tests.integration_tests.cases import EField, IFloatField, VField - import numpy as np import pytest @@ -56,33 +54,36 @@ __z_stride_0=1, size=N, ) -CSYMBOLS = dict( - ncells=SIMPLE_MESH.num_cells, - nedges=SIMPLE_MESH.num_edges, - nvertices=SIMPLE_MESH.num_vertices, - __cells_size_0=SIMPLE_MESH.num_cells, - __cells_stride_0=1, - __edges_size_0=SIMPLE_MESH.num_edges, - __edges_stride_0=1, - __vertices_size_0=SIMPLE_MESH.num_vertices, - __vertices_stride_0=1, - __connectivity_C2E_size_0=SIMPLE_MESH.num_cells, - __connectivity_C2E_size_1=SIMPLE_MESH.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=SIMPLE_MESH.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_1=1, - __connectivity_C2V_size_0=SIMPLE_MESH.num_cells, - __connectivity_C2V_size_1=SIMPLE_MESH.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=SIMPLE_MESH.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_1=1, - __connectivity_E2V_size_0=SIMPLE_MESH.num_edges, - __connectivity_E2V_size_1=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=SIMPLE_MESH.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_1=1, - __connectivity_V2E_size_0=SIMPLE_MESH.num_vertices, - __connectivity_V2E_size_1=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=SIMPLE_MESH.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_1=1, -) + + +def make_mesh_symbols(mesh: MeshDescriptor): + return dict( + ncells=mesh.num_cells, + nedges=mesh.num_edges, + nvertices=mesh.num_vertices, + __cells_size_0=mesh.num_cells, + __cells_stride_0=1, + __edges_size_0=mesh.num_edges, + __edges_stride_0=1, + __vertices_size_0=mesh.num_vertices, + __vertices_stride_0=1, + __connectivity_C2E_size_0=mesh.num_cells, + __connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_stride_1=1, + __connectivity_C2V_size_0=mesh.num_cells, + __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_stride_1=1, + __connectivity_E2V_size_0=mesh.num_edges, + __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_stride_1=1, + __connectivity_V2E_size_0=mesh.num_vertices, + __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_stride_1=1, + ) def test_gtir_copy(): @@ -727,7 +728,7 @@ def test_gtir_connectivity_shift(): connectivity_C2E=connectivity_C2E.table, connectivity_C2V=connectivity_C2V.table, **FSYMBOLS, - **CSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), __ve_field_size_0=SIMPLE_MESH.num_vertices, __ve_field_size_1=SIMPLE_MESH.num_edges, __ve_field_stride_0=SIMPLE_MESH.num_edges, @@ -798,8 +799,8 @@ def test_gtir_connectivity_shift_chain(): connectivity_E2V=connectivity_E2V.table, connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - **CSYMBOLS, - __edges_out_size_0=CSYMBOLS["__edges_size_0"], - __edges_out_stride_0=CSYMBOLS["__edges_stride_0"], + **make_mesh_symbols(SIMPLE_MESH), + __edges_out_size_0=SIMPLE_MESH.num_edges, + __edges_out_stride_0=1, ) assert np.allclose(e_out, ref) From 2f75cfb799cca9318afadceddaa05fab80c23d9d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 29 May 2024 09:43:51 +0200 Subject: [PATCH 085/235] Add debuginfo for ir.Program and ir.Stmt nodes --- .../runners/dace_fieldview/gtir_to_sdfg.py | 3 +++ .../runners/dace_fieldview/utility.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 07adb89d6a..1197fbe3b6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -164,6 +164,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: ) sdfg = dace.SDFG(node.id) + sdfg.debuginfo = dace_fieldview_util.debuginfo(node) entry_state = sdfg.add_state("program_entry", is_start_block=True) # declarations of temporaries result in transient array definitions in the SDFG @@ -185,7 +186,9 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # visit one statement at a time and expand the SDFG from the current head state for i, stmt in enumerate(node.body): + # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy head_state = sdfg.add_state_after(head_state, f"stmt_{i}") + head_state._debuginfo = dace_fieldview_util.debuginfo(stmt) self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index c9111c70e9..b9152f1fb3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Mapping +from typing import Any, Mapping, Optional import dace @@ -49,6 +49,21 @@ def as_scalar_type(typestr: str) -> ts.ScalarType: return ts.ScalarType(kind) +def debuginfo( + node: itir.Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + location = node.location + if location: + return dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, + ) + return debuginfo + + def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: """ Filter offset providers of type `Connectivity`. From 085f307d9fe1f8e94f67420272dae27c7ca1c0df Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 29 May 2024 10:08:22 +0200 Subject: [PATCH 086/235] Fix error in debuginfo --- .../program_processors/runners/dace_fieldview/gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 1197fbe3b6..94ddd463a4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -188,7 +188,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: for i, stmt in enumerate(node.body): # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy head_state = sdfg.add_state_after(head_state, f"stmt_{i}") - head_state._debuginfo = dace_fieldview_util.debuginfo(stmt) + head_state._debuginfo = dace_fieldview_util.debuginfo(stmt, sdfg.debuginfo) self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() From dc1434ceffac4494e6f2d4b6606c2e1e374034fc Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 29 May 2024 10:43:54 +0200 Subject: [PATCH 087/235] Fix error in debuginfo (1) --- .../program_processors/runners/dace_fieldview/gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 94ddd463a4..8150446253 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -164,7 +164,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: ) sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_fieldview_util.debuginfo(node) + sdfg.debuginfo = dace_fieldview_util.debuginfo(node, sdfg.debuginfo) entry_state = sdfg.add_state("program_entry", is_start_block=True) # declarations of temporaries result in transient array definitions in the SDFG From 3769fb52225f6c119aa384f03f7a1b7385f76bbd Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 14 Jun 2024 10:56:17 +0200 Subject: [PATCH 088/235] Remove nested SDFG for neighbors builtin --- .../gtir_builtin_translators.py | 56 ++-- .../runners/dace_fieldview/gtir_to_tasklet.py | 287 +++++++----------- .../runners_tests/test_dace_fieldview.py | 44 ++- 3 files changed, 170 insertions(+), 217 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 9c967436b1..66c664fde3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -157,19 +157,33 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: ) stencil_args.append(iterator_arg) + # create map range corresponding to the field operator domain + map_ranges = { + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" + for dim, lb, ub in self.field_domain + } + me, mx = self.head_state.add_map("field_op", map_ranges) + # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + taskgen = gtir_to_tasklet.LambdaToTasklet( + self.sdfg, self.head_state, me, self.offset_provider + ) + output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + output_desc = output_expr.node.desc(self.sdfg) # retrieve the tasklet node which writes the result - output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src - output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn - - # the last transient node can be deleted - # TODO: not needed to store the node `dtype` after type inference is in place - output_desc = output_expr.node.desc(self.sdfg) - self.head_state.remove_node(output_expr.node) + last_node = self.head_state.in_edges(output_expr.node)[0].src + if isinstance(last_node, dace.nodes.Tasklet): + # the last transient node can be deleted + last_node_connector = self.head_state.in_edges(output_expr.node)[0].src_conn + self.head_state.remove_node(output_expr.node) + if len(last_node.in_connectors) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + self.head_state.add_nedge(me, last_node, dace.Memlet()) + else: + last_node = output_expr.node + last_node_connector = None # allocate local temporary storage for the result field field_dims = [dim for dim, _, _ in self.field_domain] @@ -201,31 +215,11 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: if isinstance(output_desc, dace.data.Array): output_subset.extend(f"0:{size}" for size in output_desc.shape) - # create map range corresponding to the field operator domain - map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" - for dim, lb, ub in self.field_domain - } - me, mx = self.head_state.add_map("field_op", map_ranges) - - if len(input_connections) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - self.head_state.add_nedge(me, output_tasklet_node, dace.Memlet()) - else: - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - self.head_state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) self.head_state.add_memlet_path( - output_tasklet_node, + last_node, mx, field_node, - src_conn=output_tasklet_connector, + src_conn=last_node_connector, memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index b23f987ee9..901bc7a717 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -68,15 +68,6 @@ class IteratorExpr: indices: dict[Dimension, IteratorIndexExpr] -# Define alias for the elements needed to setup input connections to a map scope -InputConnection: TypeAlias = tuple[ - dace.nodes.AccessNode, - sbs.Range, - dace.nodes.Node, - Optional[str], -] - - @dataclass(frozen=True) class MaskedMemletExpr(MemletExpr): """Scalar or array data access thorugh a memlet.""" @@ -149,63 +140,6 @@ class MaskedValueExpr(ValueExpr): } -def build_neighbors_sdfg( - field_dtype: dace.typeclass, - field_shape: tuple[int], - neighbors_shape: tuple[int], - index_dtype: dace.typeclass, - with_skip_values: bool, -) -> tuple[dace.SDFG, str, str, str]: - assert len(field_shape) == len(neighbors_shape) - - sdfg = dace.SDFG("neighbors") - state = sdfg.add_state() - me, mx = state.add_map( - "neighbors", - {f"__idx_{i}": sbs.Range([(0, size - 1, 1)]) for i, size in enumerate(neighbors_shape)}, - ) - neighbor_index = ",".join(f"__idx_{i}" for i in range(len(neighbors_shape))) - - field_name, field_array = sdfg.add_array("field", field_shape, field_dtype) - index_name, _ = sdfg.add_array("indexes", neighbors_shape, index_dtype) - var_name, _ = sdfg.add_array("values", neighbors_shape, field_dtype) - index_node = state.add_access(index_name) - - if with_skip_values: - skip_value_code = f" if __index != {neighbor_skip_value} else {field_dtype}(0)" - else: - skip_value_code = "" - tasklet_node = state.add_tasklet( - "gather_neighbors", - {"__field", "__index"}, - {"__val"}, - "__val = __field[__index]" + skip_value_code, - ) - state.add_memlet_path( - state.add_access(field_name), - me, - tasklet_node, - dst_conn="__field", - memlet=dace.Memlet.from_array(field_name, field_array), - ) - state.add_memlet_path( - index_node, - me, - tasklet_node, - dst_conn="__index", - memlet=dace.Memlet(data=index_name, subset=neighbor_index), - ) - state.add_memlet_path( - tasklet_node, - mx, - state.add_access(var_name), - src_conn="__val", - memlet=dace.Memlet(data=var_name, subset=neighbor_index), - ) - - return sdfg, field_name, index_name, var_name - - def build_reduce_sdfg( code: str, params: list[str], @@ -314,7 +248,6 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -322,22 +255,30 @@ def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, + map_entry: dace.nodes.MapEntry, offset_provider: dict[str, Connectivity | Dimension], ): self.sdfg = sdfg self.state = state - self.input_connections = [] + self.map_entry = map_entry self.offset_provider = offset_provider self.symbol_map = {} - def _add_input_connection( + def _add_entry_memlet_path( self, - src: dace.nodes.AccessNode, - subset: sbs.Range, - dst: dace.nodes.Node, - dst_connector: Optional[str] = None, + *path_nodes: dace.nodes.Node, + memlet: Optional[dace.Memlet] = None, + src_conn: Optional[str] = None, + dst_conn: Optional[str] = None, ) -> None: - self.input_connections.append((src, subset, dst, dst_connector)) + self.state.add_memlet_path( + path_nodes[0], + self.map_entry, + *path_nodes[1:], + memlet=memlet, + src_conn=src_conn, + dst_conn=dst_conn, + ) def _get_tasklet_result( self, @@ -417,17 +358,21 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: code=f"val = field[{index_internals}]", ) # add new termination point for this field parameter - field_fullset = sbs.Range.from_array(field_desc) - self._add_input_connection(it.field, field_fullset, deref_node, "field") + self._add_entry_memlet_path( + it.field, + deref_node, + dst_conn="field", + memlet=dace.Memlet.from_array(it.field.data, field_desc), + ) for dim, index_expr in field_indices: deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim.value) if isinstance(index_expr, MemletExpr): - self._add_input_connection( + self._add_entry_memlet_path( index_expr.node, - index_expr.subset, deref_node, - deref_connector, + dst_conn=deref_connector, + memlet=dace.Memlet(data=index_expr.node.data, subset=index_expr.subset), ) elif isinstance(index_expr, ValueExpr): @@ -460,10 +405,11 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) assert offset_provider.neighbor_axis in it.dimensions + assert offset_provider.neighbor_axis not in it.indices + assert offset_provider.origin_axis not in it.dimensions assert offset_provider.origin_axis in it.indices origin_index = it.indices[offset_provider.origin_axis] assert isinstance(origin_index, SymbolExpr) - assert offset_provider.origin_axis not in it.dimensions assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) field_desc = it.field.desc(self.sdfg) @@ -475,62 +421,65 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: connectivity_desc.transient = False connectivity_node = self.state.add_access(connectivity) - field_array_shape = tuple( - shape - for dim, shape in zip(it.dimensions, field_desc.shape, strict=True) - if dim == offset_provider.neighbor_axis + me, mx = self.state.add_map( + "neighbors", + dict(__neighbor_idx=f"0:{offset_provider.max_neighbors}"), ) - assert len(field_array_shape) == 1 - - # we build a nested SDFG to gather all neighbors for each point in the field domain - # it can be seen as a library node - nsdfg, field_name, index_name, output_name = build_neighbors_sdfg( - field_desc.dtype, - field_array_shape, - (offset_provider.max_neighbors,), - connectivity_desc.dtype, - offset_provider.has_skip_values, + index_connector = "__index" + if offset_provider.has_skip_values: + skip_value_code = ( + f" if {index_connector} != {neighbor_skip_value} else {field_desc.dtype}(0)" + ) + else: + skip_value_code = "" + index_internals = ",".join( + [ + it.indices[dim].value if dim != offset_provider.neighbor_axis else index_connector # type: ignore[union-attr] + for dim in it.dimensions + ] ) - - neighbors_node = self.state.add_nested_sdfg( - nsdfg, self.sdfg, {field_name, index_name}, {output_name} + tasklet_node = self.state.add_tasklet( + "gather_neighbors", + {"__field", index_connector}, + {"__val"}, + f"__val = __field[{index_internals}]" + skip_value_code, ) - - self._add_input_connection( + self._add_entry_memlet_path( it.field, - sbs.Range( - [ - (0, size - 1, 1) - if dim == offset_provider.neighbor_axis - else sbs.Indices(it.indices[dim].value) # type: ignore[union-attr] - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) - ] - ), - neighbors_node, - field_name, + me, + tasklet_node, + dst_conn="__field", + memlet=dace.Memlet.from_array(it.field.data, field_desc), ) - - self._add_input_connection( + self._add_entry_memlet_path( connectivity_node, - sbs.Range( - [ - (origin_index.value, origin_index.value, 1), - (0, offset_provider.max_neighbors - 1, 1), - ] + me, + tasklet_node, + dst_conn=index_connector, + memlet=dace.Memlet( + data=connectivity, subset=sbs.Indices([origin_index.value, "__neighbor_idx"]) ), - neighbors_node, - index_name, ) - + neighbor_val_name, neighbor_val_array = self.sdfg.add_array( + "neighbor_val", + (offset_provider.max_neighbors,), + field_desc.dtype, + transient=True, + find_new_name=True, + ) + neighbor_val_node = self.state.add_access(neighbor_val_name) + self.state.add_memlet_path( + tasklet_node, + mx, + neighbor_val_node, + src_conn="__val", + memlet=dace.Memlet(data=neighbor_val_name, subset="__neighbor_idx"), + ) + neighbors_field_type = dace_fieldview_util.get_neighbors_field_type( + offset, field_desc.dtype + ) if offset_provider.has_skip_values: # simulate pattern of masked array, using the connctivity table as a mask - neighbor_val_name, neighbor_val_array = self.sdfg.add_array( - "neighbor_val", - (offset_provider.max_neighbors,), - field_desc.dtype, - transient=True, - find_new_name=True, - ) neighbor_idx_name, neighbor_idx_array = self.sdfg.add_array( "neighbor_idx", (offset_provider.max_neighbors,), @@ -538,35 +487,19 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: transient=True, find_new_name=True, ) - - neighbor_val_node = self.state.add_access(neighbor_val_name) - self.state.add_edge( - neighbors_node, - output_name, - neighbor_val_node, - None, - dace.Memlet.from_array(neighbor_val_name, neighbor_val_array), - ) - neighbor_idx_node = self.state.add_access(neighbor_idx_name) - self._add_input_connection( + self._add_entry_memlet_path( connectivity_node, - sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), neighbor_idx_node, - ) - - neighbors_field_type = dace_fieldview_util.get_neighbors_field_type( - offset, field_desc.dtype + memlet=dace.Memlet( + data=connectivity, + subset=f"{origin_index.value}, 0:{offset_provider.max_neighbors}", + ), ) return MaskedValueExpr(neighbor_val_node, neighbors_field_type, neighbor_idx_node) else: - return self._get_tasklet_result( - field_desc.dtype, - neighbors_node, - output_name, - offset=offset, - ) + return ValueExpr(neighbor_val_node, neighbors_field_type) def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: # TODO: use type inference to determine the result type @@ -630,11 +563,11 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: for sdfg_connector, (_, reduce_expr) in zip(sdfg_inputs, reduce_args, strict=True): if isinstance(reduce_expr, MemletExpr): assert isinstance(reduce_expr.subset, sbs.Subset) - self._add_input_connection( + self._add_entry_memlet_path( reduce_expr.node, - reduce_expr.subset, reduce_node, - sdfg_connector, + dst_conn=sdfg_connector, + memlet=dace.Memlet(data=reduce_expr.node.data, subset=reduce_expr.subset), ) else: self.state.add_edge( @@ -646,11 +579,11 @@ def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: ) if isinstance(first_expr, MaskedMemletExpr): - self._add_input_connection( + self._add_entry_memlet_path( first_expr.mask, - first_expr.subset, reduce_node, - mask_input, + dst_conn=mask_input, + memlet=dace.Memlet(data=first_expr.mask.data, subset=first_expr.subset), ) elif isinstance(first_expr, MaskedValueExpr): self.state.add_edge( @@ -721,11 +654,11 @@ def _make_cartesian_shift( if isinstance(input_expr, MemletExpr): if input_connector == "index": dtype = input_expr.node.desc(self.sdfg).dtype - self._add_input_connection( + self._add_entry_memlet_path( input_expr.node, - input_expr.subset, dynamic_offset_tasklet, - input_connector, + dst_conn=input_connector, + memlet=dace.Memlet(data=input_expr.node.data, subset=input_expr.subset), ) elif isinstance(input_expr, ValueExpr): if input_connector == "index": @@ -771,18 +704,20 @@ def _make_dynamic_neighbor_offset( {new_index_connector}, f"{new_index_connector} = table[{origin_index.value}, offset]", ) - self._add_input_connection( + self._add_entry_memlet_path( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, - "table", + dst_conn="table", + memlet=dace.Memlet.from_array( + offset_table_node.data, offset_table_node.desc(self.sdfg) + ), ) if isinstance(offset_expr, MemletExpr): - self._add_input_connection( + self._add_entry_memlet_path( offset_expr.node, - offset_expr.subset, tasklet_node, - "offset", + dst_conn="offset", + memlet=dace.Memlet(data=offset_expr.node.data, subset=offset_expr.subset), ) else: self.state.add_edge( @@ -933,7 +868,12 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: - self._add_input_connection(arg_expr.node, arg_expr.subset, tasklet_node, connector) + self._add_entry_memlet_path( + arg_expr.node, + tasklet_node, + dst_conn=connector, + memlet=dace.Memlet(data=arg_expr.node.data, subset=arg_expr.subset), + ) # TODO: use type inference to determine the result type if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): @@ -946,25 +886,32 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value def visit_Lambda( self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[InputConnection], ValueExpr]: + ) -> ValueExpr: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr) if isinstance(output_expr, ValueExpr): - return self.input_connections, output_expr + return output_expr if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.node.desc(self.sdfg).dtype - tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_connection(output_expr.node, output_expr.subset, tasklet_node, "__inp") + dtype = self.sdfg.arrays[output_expr.node.data].dtype + scalar_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) + var, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) + result_node = self.state.add_access(var) + self._add_entry_memlet_path( + output_expr.node, + result_node, + memlet=dace.Memlet(data=output_expr.node.data, subset=output_expr.subset), + ) + return ValueExpr(result_node, scalar_type) else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype tasklet_node = self.state.add_tasklet( "write", {}, {"__out"}, f"__out = {output_expr.value}" ) - return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") + return self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index db3151b6ac..01432adc38 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -62,6 +62,18 @@ def make_mesh_symbols(mesh: MeshDescriptor): + C2E_size_0, C2E_size_1 = mesh.offset_provider["C2E"].table.shape + C2E_stride_0, C2E_stride_1 = C2E_size_1, 1 # mesh.offset_provider["C2E"].table.strides + + C2V_size_0, C2V_size_1 = mesh.offset_provider["C2V"].table.shape + C2V_stride_0, C2V_stride_1 = C2V_size_1, 1 # mesh.offset_provider["C2V"].table.strides + + E2V_size_0, E2V_size_1 = mesh.offset_provider["E2V"].table.shape + E2V_stride_0, E2V_stride_1 = E2V_size_1, 1 # mesh.offset_provider["E2V"].table.strides + + V2E_size_0, V2E_size_1 = mesh.offset_provider["V2E"].table.shape + V2E_stride_0, V2E_stride_1 = V2E_size_1, 1 # mesh.offset_provider["V2E"].table.strides + return dict( ncells=mesh.num_cells, nedges=mesh.num_edges, @@ -72,22 +84,22 @@ def make_mesh_symbols(mesh: MeshDescriptor): __edges_stride_0=1, __vertices_size_0=mesh.num_vertices, __vertices_stride_0=1, - __connectivity_C2E_size_0=mesh.num_cells, - __connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_1=1, - __connectivity_C2V_size_0=mesh.num_cells, - __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_1=1, - __connectivity_E2V_size_0=mesh.num_edges, - __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_1=1, - __connectivity_V2E_size_0=mesh.num_vertices, - __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_1=1, + __connectivity_C2E_size_0=C2E_size_0, + __connectivity_C2E_size_1=C2E_size_1, + __connectivity_C2E_stride_0=C2E_stride_0, + __connectivity_C2E_stride_1=C2E_stride_1, + __connectivity_C2V_size_0=C2V_size_0, + __connectivity_C2V_size_1=C2V_size_1, + __connectivity_C2V_stride_0=C2V_stride_0, + __connectivity_C2V_stride_1=C2V_stride_1, + __connectivity_E2V_size_0=E2V_size_0, + __connectivity_E2V_size_1=E2V_size_1, + __connectivity_E2V_stride_0=E2V_stride_0, + __connectivity_E2V_stride_1=E2V_stride_1, + __connectivity_V2E_size_0=V2E_size_0, + __connectivity_V2E_size_1=V2E_size_1, + __connectivity_V2E_stride_0=V2E_stride_0, + __connectivity_V2E_stride_1=V2E_stride_1, ) From b1b588769becafcfca6b02dc7c655b50cbca8102 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 26 Jun 2024 17:40:20 +0200 Subject: [PATCH 089/235] Remove masked array for skip values, rely on identity value --- .../gtir_builtin_translators.py | 108 ++----- .../dace_fieldview/gtir_dace_backend.py | 4 - .../runners/dace_fieldview/gtir_to_sdfg.py | 3 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 306 +++++------------- .../runners_tests/test_dace_fieldview.py | 148 ++++----- 5 files changed, 167 insertions(+), 402 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 66c664fde3..dad88ccfd6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -33,7 +33,7 @@ # Define aliases for return types SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] -SDFGFieldBuilder: TypeAlias = Callable[[], tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]] +SDFGFieldBuilder: TypeAlias = Callable[[gtir_to_tasklet.SymbolExpr | None], list[SDFGField]] DIMENSION_INDEX_FMT = "i_{dim}" @@ -44,14 +44,14 @@ class PrimitiveTranslator(abc.ABC): head_state: dace.SDFGState @final - def __call__(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: + def __call__(self, reduce_identity: Optional[gtir_to_tasklet.SymbolExpr]) -> list[SDFGField]: """The callable interface is used to build the dataflow graph. It allows to build the dataflow graph inside a given state starting from the innermost nodes, by propagating the intermediate results as access nodes to temporary local storage. """ - return self.build() + return self.build(reduce_identity) @final def add_local_storage( @@ -73,7 +73,7 @@ def add_local_storage( return self.head_state.add_access(name) @abc.abstractmethod - def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: + def build(self, reduce_identity: Optional[gtir_to_tasklet.SymbolExpr]) -> list[SDFGField]: """Creates the dataflow subgraph representing a GTIR builtin function. This method is used by derived classes to build a specialized subgraph @@ -125,13 +125,25 @@ def __init__( self.stencil_expr = stencil_expr self.stencil_args = stencil_args - def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: + def get_reduce_identity(self) -> gtir_to_tasklet.SymbolExpr | None: + assert isinstance(self.stencil_expr, itir.Lambda) + if cpm.is_applied_reduce(self.stencil_expr.expr): + _, _, reduce_dentity = gtir_to_tasklet.get_reduce_params(self.stencil_expr.expr) + else: + reduce_dentity = None + return reduce_dentity + + def build(self, reduce_identity: Optional[gtir_to_tasklet.SymbolExpr]) -> list[SDFGField]: # type of variables used for field indexing index_dtype = dace.int32 + + # retrieve the identity value if this is a reduce node + my_reduce_identity = self.get_reduce_identity() + # first visit the list of arguments and build a symbol map stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] for arg in self.stencil_args: - arg_nodes, mask_node = arg() + arg_nodes = arg(my_reduce_identity) assert len(arg_nodes) == 1 data_node, arg_type = arg_nodes[0] # require all argument nodes to be data access nodes (no symbols) @@ -151,7 +163,6 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: } iterator_arg = gtir_to_tasklet.IteratorExpr( data_node, - mask_node, arg_type.dims, indices, ) @@ -166,7 +177,7 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_to_tasklet.LambdaToTasklet( - self.sdfg, self.head_state, me, self.offset_provider + self.sdfg, self.head_state, me, self.offset_provider, reduce_identity ) output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) @@ -223,40 +234,7 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), ) - if isinstance(output_expr, gtir_to_tasklet.MaskedValueExpr): - # this is the case of neighbors with skip values: the value expression also contains the neighbor indices - mask_numpy_dtype = self.sdfg.arrays[output_expr.mask.data].dtype.as_numpy_dtype() - mask_dtype = dace_fieldview_util.as_scalar_type(str(mask_numpy_dtype)) - mask_node = self.add_local_storage(ts.FieldType(field_dims, mask_dtype), field_shape) - - self.head_state.add_memlet_path( - output_expr.mask, - mx, - mask_node, - memlet=dace.Memlet(data=mask_node.data, subset=",".join(output_subset)), - ) - - else: - mask_node = None - - return [(field_node, field_type)], mask_node - - -def _add_full_mask(mask_var: str, mask_desc: dace.data.Array, state: dace.SDFGState) -> None: - """Fill a connectivity table with a valid neighbor index, to mimic full connectivity mask.""" - state.add_mapped_tasklet( - "set_full_mask", - map_ranges={f"_i{i}": f"0:{size}" for i, size in enumerate(mask_desc.shape)}, - inputs={}, - code="val = 0", - outputs={ - "val": dace.Memlet( - data=mask_var, - subset=",".join(f"_i{i}" for i in range(len(mask_desc.shape))), - ) - }, - external_edges=True, - ) + return [(field_node, field_type)] class Select(PrimitiveTranslator): @@ -313,7 +291,7 @@ def __init__( false_expr, sdfg=sdfg, head_state=false_state ) - def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: + def build(self, reduce_identity: Optional[gtir_to_tasklet.SymbolExpr]) -> list[SDFGField]: # retrieve true/false states as predecessors of head state branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) assert len(branch_states) == 2 @@ -322,8 +300,8 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: else: false_state, true_state = branch_states - true_br_args, true_br_mask = self.true_br_builder() - false_br_args, false_br_mask = self.false_br_builder() + true_br_args = self.true_br_builder(reduce_identity) + false_br_args = self.false_br_builder(reduce_identity) output_nodes = [] for true_br, false_br in zip(true_br_args, false_br_args, strict=True): @@ -355,41 +333,7 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: ), ) - # Check if any of the true/false branches produces a mask array for neighbors with skip values; - # if only one branch produces the mask array, add a mapped tasklet on the other branch to produce - # a mask array filled with 0s to mimic all valid neighbor values. - if true_br_mask: - mask_desc = true_br_mask.desc(self.sdfg) - mask_var, _ = self.sdfg.add_temp_transient_like(mask_desc) - if false_br_mask: - assert mask_desc == false_br_mask.desc(self.sdfg) - false_state.add_nedge( - false_br_mask, - false_state.add_access(mask_var), - dace.Memlet.from_array(false_br_mask.data, mask_desc), - ) - else: - _add_full_mask(mask_var, mask_desc, false_state) - true_state.add_nedge( - true_br_mask, - true_state.add_access(mask_var), - dace.Memlet.from_array(true_br_mask.data, mask_desc), - ) - return output_nodes, self.head_state.add_access(mask_var) - - elif false_br_mask: - mask_desc = false_br_mask.desc(self.sdfg) - mask_var, _ = self.sdfg.add_temp_transient_like(mask_desc) - _add_full_mask(mask_var, mask_desc, true_state) - false_state.add_nedge( - false_br_mask, - false_state.add_access(mask_var), - dace.Memlet.from_array(false_br_mask.data, mask_desc), - ) - return output_nodes, self.head_state.add_access(mask_var) - - else: - return output_nodes, None + return output_nodes class SymbolRef(PrimitiveTranslator): @@ -409,7 +353,7 @@ def __init__( self.sym_name = sym_name self.sym_type = sym_type - def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: + def build(self, reduce_identity: Optional[gtir_to_tasklet.SymbolExpr]) -> list[SDFGField]: if isinstance(self.sym_type, ts.FieldType): # add access node to current state sym_node = self.head_state.add_access(self.sym_name) @@ -433,4 +377,4 @@ def build(self) -> tuple[list[SDFGField], Optional[dace.nodes.AccessNode]]: dace.Memlet(data=sym_node.data, subset="0"), ) - return [(sym_node, self.sym_type)], None + return [(sym_node, self.sym_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py index bcbe390aca..c8c798292a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py @@ -13,7 +13,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dace -from dace.sdfg import utils as sdutils from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir @@ -32,8 +31,5 @@ def build_sdfg_from_gtir( sdfg = sdfg_genenerator.visit(program) assert isinstance(sdfg, dace.SDFG) - # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct - sdutils.inline_loop_blocks(sdfg) - sdfg.simplify() return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index aeeb9553d5..a6fc82809a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -133,8 +133,7 @@ def _visit_expression( field_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( node, sdfg=sdfg, head_state=head_state ) - results, mask_node = field_builder() - assert mask_node is None + results = field_builder(None) field_nodes = [] for node, _ in results: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 901bc7a717..a662e52dfb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -63,25 +63,10 @@ class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode - mask: Optional[dace.nodes.AccessNode] dimensions: list[Dimension] indices: dict[Dimension, IteratorIndexExpr] -@dataclass(frozen=True) -class MaskedMemletExpr(MemletExpr): - """Scalar or array data access thorugh a memlet.""" - - mask: dace.nodes.AccessNode - - -@dataclass(frozen=True) -class MaskedValueExpr(ValueExpr): - """Result of the computation implemented by a tasklet node.""" - - mask: dace.nodes.AccessNode - - INDEX_CONNECTOR_FMT = "__index_{dim}" @@ -140,102 +125,36 @@ class MaskedValueExpr(ValueExpr): } -def build_reduce_sdfg( - code: str, - params: list[str], - acc_param: str, - init_value: itir.Literal, - result_dtype: dace.typeclass, - values_desc: dace.data.Array, - indices_desc: Optional[dace.data.Array] = None, -) -> tuple[dace.SDFG, str, list[str], Optional[str]]: - sdfg = dace.SDFG("reduce") - - neighbors_len = values_desc.shape[-1] - - acc_var = "__acc" - assert acc_param != acc_var - input_vars = [f"__{p}" for p in params] - if indices_desc: - assert values_desc.shape == indices_desc.shape - mask_var, _ = sdfg.add_array("__mask", (neighbors_len,), indices_desc.dtype) - tasklet_params = {acc_param, *params, "mask"} - tasklet_code = f"res = {code} if mask != {neighbor_skip_value} else {acc_param}" - else: - mask_var = None - tasklet_params = {acc_param, *params} - tasklet_code = f"res = {code}" - - neighbor_idx = "__idx" - reduce_loop = dace.sdfg.state.LoopRegion( - label="reduce", - loop_var=neighbor_idx, - initialize_expr=f"{neighbor_idx} = 0", - condition_expr=f"{neighbor_idx} < {neighbors_len}", - update_expr=f"{neighbor_idx} = {neighbor_idx} + 1", - inverted=False, - ) - sdfg.add_node(reduce_loop) - reduce_state = reduce_loop.add_state("loop") - - reduce_tasklet = reduce_state.add_tasklet( - "reduce", - tasklet_params, - {"res"}, - tasklet_code, - ) - - sdfg.add_scalar(acc_var, result_dtype) - reduce_state.add_edge( - reduce_state.add_access(acc_var), - None, - reduce_tasklet, - acc_param, - dace.Memlet(data=acc_var, subset="0"), - ) - - for inner_var, input_var in zip(params, input_vars): - sdfg.add_array(input_var, (neighbors_len,), values_desc.dtype) - reduce_state.add_edge( - reduce_state.add_access(input_var), - None, - reduce_tasklet, - inner_var, - dace.Memlet(data=input_var, subset=neighbor_idx), - ) - if indices_desc: - reduce_state.add_edge( - reduce_state.add_access(mask_var), - None, - reduce_tasklet, - "mask", - dace.Memlet(data=mask_var, subset=neighbor_idx), - ) - reduce_state.add_edge( - reduce_tasklet, - "res", - reduce_state.add_access(acc_var), - None, - dace.Memlet(data=acc_var, subset="0"), - ) - - init_state = sdfg.add_state("init", is_start_block=True) - init_tasklet = init_state.add_tasklet( - "init_reduce", - {}, - {"val"}, - f"val = {init_value}", - ) - init_state.add_edge( - init_tasklet, - "val", - init_state.add_access(acc_var), - None, - dace.Memlet(data=acc_var, subset="0"), - ) - sdfg.add_edge(init_state, reduce_loop, dace.InterstateEdge()) - - return sdfg, acc_var, input_vars, mask_var +DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { + "minimum": dace.dtypes.ReductionType.Min, + "maximum": dace.dtypes.ReductionType.Max, + "plus": dace.dtypes.ReductionType.Sum, + "multiplies": dace.dtypes.ReductionType.Product, + "and_": dace.dtypes.ReductionType.Logical_And, + "or_": dace.dtypes.ReductionType.Logical_Or, + "xor_": dace.dtypes.ReductionType.Logical_Xor, + "minus": dace.dtypes.ReductionType.Sub, + "divides": dace.dtypes.ReductionType.Div, +} + + +def get_reduce_params(node: itir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: + # TODO: use type inference to determine the result type + dtype = dace.float64 + + assert isinstance(node.fun, itir.FunCall) + assert len(node.fun.args) == 2 + assert isinstance(node.fun.args[0], itir.SymRef) + op_name = str(node.fun.args[0]) + assert isinstance(node.fun.args[1], itir.Literal) + reduce_init = SymbolExpr(node.fun.args[1].value, dtype) + + if op_name not in DACE_REDUCTION_MAPPING: + raise RuntimeError(f"Reduction operation '{op_name}' not supported.") + identity_value = dace.dtypes.reduction_identity(dtype, DACE_REDUCTION_MAPPING[op_name]) + reduce_identity = SymbolExpr(identity_value, dtype) + + return op_name, reduce_init, reduce_identity class LambdaToTasklet(eve.NodeVisitor): @@ -250,6 +169,7 @@ class LambdaToTasklet(eve.NodeVisitor): state: dace.SDFGState offset_provider: dict[str, Connectivity | Dimension] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] + reduce_identity: Optional[SymbolExpr] def __init__( self, @@ -257,12 +177,14 @@ def __init__( state: dace.SDFGState, map_entry: dace.nodes.MapEntry, offset_provider: dict[str, Connectivity | Dimension], + reduce_identity: Optional[SymbolExpr], ): self.sdfg = sdfg self.state = state self.map_entry = map_entry self.offset_provider = offset_provider self.symbol_map = {} + self.reduce_identity = reduce_identity def _add_entry_memlet_path( self, @@ -327,16 +249,9 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: for dim, size in zip(it.dimensions, field_desc.shape) ] ) - return ( - MemletExpr(it.field, field_subset) - if it.mask is None - else MaskedMemletExpr(it.field, field_subset, it.mask) - ) + return MemletExpr(it.field, field_subset) else: - # masked array not supported with indirect field access - assert it.mask is None - # we use a tasklet to perform dereferencing of a generic iterator assert all(dim in it.indices for dim in it.dimensions) field_indices = [(dim, it.indices[dim]) for dim in it.dimensions] @@ -427,9 +342,8 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: ) index_connector = "__index" if offset_provider.has_skip_values: - skip_value_code = ( - f" if {index_connector} != {neighbor_skip_value} else {field_desc.dtype}(0)" - ) + assert self.reduce_identity is not None + skip_value_code = f" if {index_connector} != {neighbor_skip_value} else {self.reduce_identity.dtype}({self.reduce_identity.value})" else: skip_value_code = "" index_internals = ",".join( @@ -478,123 +392,56 @@ def _visit_neighbors(self, node: itir.FunCall) -> ValueExpr: neighbors_field_type = dace_fieldview_util.get_neighbors_field_type( offset, field_desc.dtype ) - if offset_provider.has_skip_values: - # simulate pattern of masked array, using the connctivity table as a mask - neighbor_idx_name, neighbor_idx_array = self.sdfg.add_array( - "neighbor_idx", - (offset_provider.max_neighbors,), - connectivity_desc.dtype, - transient=True, - find_new_name=True, - ) - neighbor_idx_node = self.state.add_access(neighbor_idx_name) - self._add_entry_memlet_path( - connectivity_node, - neighbor_idx_node, - memlet=dace.Memlet( - data=connectivity, - subset=f"{origin_index.value}, 0:{offset_provider.max_neighbors}", - ), - ) - return MaskedValueExpr(neighbor_val_node, neighbors_field_type, neighbor_idx_node) - - else: - return ValueExpr(neighbor_val_node, neighbors_field_type) + return ValueExpr(neighbor_val_node, neighbors_field_type) def _visit_reduce(self, node: itir.FunCall) -> ValueExpr: - # TODO: use type inference to determine the result type - result_dtype = dace.float64 - - assert isinstance(node.fun, itir.FunCall) - assert len(node.fun.args) == 2 - reduce_acc_init = node.fun.args[1] - assert isinstance(reduce_acc_init, itir.Literal) - - if isinstance(node.fun.args[0], itir.SymRef): - assert len(node.args) == 1 - op_name = str(node.fun.args[0].id) - assert op_name in MATH_BUILTINS_MAPPING - reduce_acc_param = "acc" - reduce_params = ["val"] - reduce_code = MATH_BUILTINS_MAPPING[op_name].format("acc", "val") - else: - assert isinstance(node.fun.args[0], itir.Lambda) - assert len(node.args) >= 1 - # the +1 is for the accumulator value - assert len(node.fun.args[0].params) == len(node.args) + 1 - reduce_acc_param = str(node.fun.args[0].params[0].id) - reduce_params = [str(p.id) for p in node.fun.args[0].params[1:]] - reduce_code = PythonCodegen().visit(node.fun.args[0].expr) - - node_args: list[MemletExpr | ValueExpr] = [self.visit(arg) for arg in node.args] - reduce_args: list[tuple[str, MemletExpr | ValueExpr]] = list( - zip(reduce_params, node_args, strict=True) - ) - - _, first_expr = reduce_args[0] - values_desc = first_expr.node.desc(self.sdfg) - if isinstance(first_expr, (MaskedMemletExpr, MaskedValueExpr)): - indices_desc = first_expr.mask.desc(self.sdfg) - assert indices_desc.shape == values_desc.shape - else: - indices_desc = None - - nsdfg, sdfg_output, sdfg_inputs, mask_input = build_reduce_sdfg( - reduce_code, - reduce_params, - reduce_acc_param, - reduce_acc_init, - result_dtype, - values_desc, - indices_desc, - ) + op_name, reduce_init, reduce_identity = get_reduce_params(node) + dtype = reduce_identity.dtype - if isinstance(first_expr, (MaskedMemletExpr, MaskedValueExpr)): - assert mask_input is not None - reduce_node = self.state.add_nested_sdfg( - nsdfg, self.sdfg, {*sdfg_inputs, mask_input}, {sdfg_output} - ) + # we store the value of reduce identity in the visitor context while visiting the input to reduction + # this value will be returned by the neighbors builtin function for skip values + prev_reduce_identity = self.reduce_identity + self.reduce_identity = reduce_identity + assert len(node.args) == 1 + input_expr = self.visit(node.args[0]) + assert isinstance(input_expr, MemletExpr | ValueExpr) + self.reduce_identity = prev_reduce_identity + input_desc = input_expr.node.desc(self.sdfg) + assert isinstance(input_desc, dace.data.Array) + + if len(input_desc.shape) > 1: + assert isinstance(input_expr, MemletExpr) + ndims = len(input_desc.shape) - 1 + assert set(input_expr.subset.size()[0:ndims]) == {1} + reduce_axes = [ndims] else: - assert mask_input is None - reduce_node = self.state.add_nested_sdfg( - nsdfg, self.sdfg, {*sdfg_inputs}, {sdfg_output} - ) + reduce_axes = None + res_var, res_desc = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) + res_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) - for sdfg_connector, (_, reduce_expr) in zip(sdfg_inputs, reduce_args, strict=True): - if isinstance(reduce_expr, MemletExpr): - assert isinstance(reduce_expr.subset, sbs.Subset) - self._add_entry_memlet_path( - reduce_expr.node, - reduce_node, - dst_conn=sdfg_connector, - memlet=dace.Memlet(data=reduce_expr.node.data, subset=reduce_expr.subset), - ) - else: - self.state.add_edge( - reduce_expr.node, - None, - reduce_node, - sdfg_connector, - dace.Memlet.from_array(reduce_expr.node.data, values_desc), - ) + reduce_wcr = "lambda x, y: " + MATH_BUILTINS_MAPPING[op_name].format("x", "y") + reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_init.value) - if isinstance(first_expr, MaskedMemletExpr): + if isinstance(input_expr, MemletExpr): self._add_entry_memlet_path( - first_expr.mask, + input_expr.node, reduce_node, - dst_conn=mask_input, - memlet=dace.Memlet(data=first_expr.mask.data, subset=first_expr.subset), + memlet=dace.Memlet(data=input_expr.node.data, subset=input_expr.subset), ) - elif isinstance(first_expr, MaskedValueExpr): - self.state.add_edge( - first_expr.mask, - None, + else: + self.state.add_nedge( + input_expr.node, reduce_node, - mask_input, - dace.Memlet.from_array(first_expr.mask.data, indices_desc), + dace.Memlet.from_array(input_expr.node.data, input_desc), ) - return self._get_tasklet_result(result_dtype, reduce_node, sdfg_output) + res_node = self.state.add_access(res_var) + self.state.add_nedge( + reduce_node, + res_node, + dace.Memlet.from_array(res_var, res_desc), + ) + return ValueExpr(res_node, res_type) def _split_shift_args( self, args: list[itir.Expr] @@ -680,7 +527,6 @@ def _make_cartesian_shift( # a new iterator with a shifted index along one dimension return IteratorExpr( it.field, - it.mask, it.dimensions, {dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items()}, ) @@ -764,7 +610,7 @@ def _make_unstructured_shift( shifted_indices = it.indices | {neighbor_dim: dynamic_offset_value} - return IteratorExpr(it.field, it.mask, it.dimensions, shifted_indices) + return IteratorExpr(it.field, it.dimensions, shifted_indices) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift_node = node.fun @@ -779,8 +625,6 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: # the iterator to be shifted is the argument to the function node it = self.visit(node.args[0]) assert isinstance(it, IteratorExpr) - # skip values (implemented as an array mask) not supported with shift operator - assert it.mask is None # first argument of the shift node is the offset provider assert isinstance(head[0], itir.OffsetLiteral) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 01432adc38..48515be4d9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -1035,11 +1035,8 @@ def test_gtir_reduce_with_skip_values(): assert np.allclose(v, v_ref) -def test_gtir_reduce_with_lambda(): +def test_gtir_reduce_dot_product(): init_value = np.random.rand() - edge_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "nedges"), - ) vertex_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Vertex.value), 0, "nvertices"), ) @@ -1051,98 +1048,83 @@ def test_gtir_reduce_with_lambda(): SIMPLE_MESH.offset_provider["V2E"].max_neighbors, ), ) - reduce_lambda = im.lambda_("acc", "a", "b")(im.plus(im.multiplies_("a", "b"), "acc")) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("itA", "itB")( - im.call(im.call("reduce")(reduce_lambda, im.literal_from_value(init_value)))( - im.neighbors("V2E", "itA"), im.neighbors("V2E", "itB") - ) - ), - vertex_domain, - ) - )( - "edges", - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.plus(im.deref("it"), 1)), - edge_domain, - ) - )("edges"), - ) - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("a_it", "b_it")( - im.call(im.call("reduce")(reduce_lambda, im.literal_from_value(init_value)))( - im.deref("a_it"), im.deref("b_it") - ) - ), - vertex_domain, - ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.neighbors("V2E", "it")), - vertex_domain, - ) - )("edges"), - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.plus(im.deref("it"), 1)), - v2e_domain, + + testee = itir.Program( + id=f"reduce_dot_product", + function_definitions=[], + params=[ + itir.Sym(id="edges"), + itir.Sym(id="vertices"), + itir.Sym(id="nedges"), + itir.Sym(id="nvertices"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.multiplies_(im.deref("a"), im.deref("b"))), + v2e_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.plus(im.deref("it"), 1)), + v2e_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges") + ), + ), + ), + domain=vertex_domain, + target=itir.SymRef(id="vertices"), ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.neighbors("V2E", "it")), - vertex_domain, - ) - )("edges") - ), + ], ) arg_types = [EFTYPE, VFTYPE, SIZE_TYPE, SIZE_TYPE] connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, NeighborTable) + sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) + e = np.random.rand(SIMPLE_MESH.num_edges) + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) v_ref = [ reduce(lambda x, y: x + y, e[v2e_neighbors] * (e[v2e_neighbors] + 1), init_value) for v2e_neighbors in connectivity_V2E.table ] - for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): - testee = itir.Program( - id=f"reduce_with_lambda_{i}", - function_definitions=[], - params=[ - itir.Sym(id="edges"), - itir.Sym(id="vertices"), - itir.Sym(id="nedges"), - itir.Sym(id="nvertices"), - ], - declarations=[], - body=[ - itir.SetAt( - expr=stencil, - domain=vertex_domain, - target=itir.SymRef(id="vertices"), - ) - ], - ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, SIMPLE_MESH.offset_provider) - - # new empty output field - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) - - sdfg( - edges=e, - vertices=v, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), - ) - assert np.allclose(v, v_ref) + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) def test_gtir_reduce_with_select_neighbors(): From a5b0f416301a5260592a226a4d05a7ababb8b43c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 28 Jun 2024 16:39:37 +0200 Subject: [PATCH 090/235] import changes from neighbors branch --- .../gtir_builtin_translators.py | 58 +++++++++--------- .../runners/dace_fieldview/gtir_to_tasklet.py | 60 +++++++++++-------- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index da23d3e864..7bf0348f7d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -156,19 +156,33 @@ def build(self) -> list[SDFGField]: ) stencil_args.append(iterator_arg) + # create map range corresponding to the field operator domain + map_ranges = { + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" + for dim, lb, ub in self.field_domain + } + me, mx = self.head_state.add_map("field_op", map_ranges) + # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(self.sdfg, self.head_state, self.offset_provider) - input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + taskgen = gtir_to_tasklet.LambdaToTasklet( + self.sdfg, self.head_state, me, self.offset_provider + ) + output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + output_desc = output_expr.node.desc(self.sdfg) # retrieve the tasklet node which writes the result - output_tasklet_node = self.head_state.in_edges(output_expr.node)[0].src - output_tasklet_connector = self.head_state.in_edges(output_expr.node)[0].src_conn - - # the last transient node can be deleted - # TODO: not needed to store the node `dtype` after type inference is in place - output_desc = output_expr.node.desc(self.sdfg) - self.head_state.remove_node(output_expr.node) + last_node = self.head_state.in_edges(output_expr.node)[0].src + if isinstance(last_node, dace.nodes.Tasklet): + # the last transient node can be deleted + last_node_connector = self.head_state.in_edges(output_expr.node)[0].src_conn + self.head_state.remove_node(output_expr.node) + if len(last_node.in_connectors) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + self.head_state.add_nedge(me, last_node, dace.Memlet()) + else: + last_node = output_expr.node + last_node_connector = None # allocate local temporary storage for the result field field_dims = [dim for dim, _, _ in self.field_domain] @@ -192,32 +206,14 @@ def build(self) -> list[SDFGField]: output_subset = [ DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in self.field_domain ] + if isinstance(output_desc, dace.data.Array): + output_subset.extend(f"0:{size}" for size in output_desc.shape) - # create map range corresponding to the field operator domain - map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" - for dim, lb, ub in self.field_domain - } - me, mx = self.head_state.add_map("field_op", map_ranges) - - if len(input_connections) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - self.head_state.add_nedge(me, output_tasklet_node, dace.Memlet()) - else: - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - self.head_state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) self.head_state.add_memlet_path( - output_tasklet_node, + last_node, mx, field_node, - src_conn=output_tasklet_connector, + src_conn=last_node_connector, memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 9ca904983d..42d4e4feae 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -14,7 +14,7 @@ from dataclasses import dataclass -from typing import TypeAlias +from typing import Optional, TypeAlias import dace import dace.subsets as sbs @@ -66,15 +66,6 @@ class IteratorExpr: indices: dict[Dimension, IteratorIndexExpr] -# Define alias for the elements needed to setup input connections to a map scope -InputConnection: TypeAlias = tuple[ - dace.nodes.AccessNode, - sbs.Range, - dace.nodes.Tasklet, - str, -] - - INDEX_CONNECTOR_FMT = "__index_{dim}" @@ -143,7 +134,6 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - input_connections: list[InputConnection] offset_provider: dict[str, Connectivity | Dimension] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -151,22 +141,30 @@ def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, + map_entry: dace.nodes.MapEntry, offset_provider: dict[str, Connectivity | Dimension], ): self.sdfg = sdfg self.state = state - self.input_connections = [] + self.map_entry = map_entry self.offset_provider = offset_provider self.symbol_map = {} - def _add_input_connection( + def _add_entry_memlet_path( self, - src: dace.nodes.AccessNode, - subset: sbs.Range, - dst: dace.nodes.Tasklet, - dst_connector: str, + *path_nodes: dace.nodes.Node, + memlet: Optional[dace.Memlet] = None, + src_conn: Optional[str] = None, + dst_conn: Optional[str] = None, ) -> None: - self.input_connections.append((src, subset, dst, dst_connector)) + self.state.add_memlet_path( + path_nodes[0], + self.map_entry, + *path_nodes[1:], + memlet=memlet, + src_conn=src_conn, + dst_conn=dst_conn, + ) def _get_tasklet_result( self, @@ -255,7 +253,12 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: - self._add_input_connection(arg_expr.node, arg_expr.subset, tasklet_node, connector) + self._add_entry_memlet_path( + arg_expr.node, + tasklet_node, + dst_conn=connector, + memlet=dace.Memlet(data=arg_expr.node.data, subset=arg_expr.subset), + ) # TODO: use type inference to determine the result type if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): @@ -268,25 +271,32 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value def visit_Lambda( self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[InputConnection], ValueExpr]: + ) -> ValueExpr: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr) if isinstance(output_expr, ValueExpr): - return self.input_connections, output_expr + return output_expr if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.node.desc(self.sdfg).dtype - tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_connection(output_expr.node, output_expr.subset, tasklet_node, "__inp") + dtype = self.sdfg.arrays[output_expr.node.data].dtype + scalar_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) + var, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) + result_node = self.state.add_access(var) + self._add_entry_memlet_path( + output_expr.node, + result_node, + memlet=dace.Memlet(data=output_expr.node.data, subset=output_expr.subset), + ) + return ValueExpr(result_node, scalar_type) else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype tasklet_node = self.state.add_tasklet( "write", {}, {"__out"}, f"__out = {output_expr.value}" ) - return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") + return self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) From c61e796144af9dcbe2044894b53bb27e644c8c76 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 28 Jun 2024 16:46:45 +0200 Subject: [PATCH 091/235] import changes from neighbors branch --- .../gtir_builtin_translators.py | 1 + .../runners/dace_fieldview/gtir_to_tasklet.py | 34 +++++++++++-------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 7bf0348f7d..bbac2279bc 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -315,6 +315,7 @@ def build(self) -> list[SDFGField]: false_br_output_node.data, false_br_output_node.desc(self.sdfg) ), ) + return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 8357815547..1dc51e22e7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -220,17 +220,21 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: code=f"val = field[{index_internals}]", ) # add new termination point for this field parameter - field_fullset = sbs.Range.from_array(field_desc) - self._add_input_connection(it.field, field_fullset, deref_node, "field") + self._add_entry_memlet_path( + it.field, + deref_node, + dst_conn="field", + memlet=dace.Memlet.from_array(it.field.data, field_desc), + ) for dim, index_expr in field_indices: deref_connector = INDEX_CONNECTOR_FMT.format(dim=dim.value) if isinstance(index_expr, MemletExpr): - self._add_input_connection( + self._add_entry_memlet_path( index_expr.node, - index_expr.subset, deref_node, - deref_connector, + dst_conn=deref_connector, + memlet=dace.Memlet(data=index_expr.node.data, subset=index_expr.subset), ) elif isinstance(index_expr, ValueExpr): @@ -309,11 +313,11 @@ def _make_cartesian_shift( if isinstance(input_expr, MemletExpr): if input_connector == "index": dtype = input_expr.node.desc(self.sdfg).dtype - self._add_input_connection( + self._add_entry_memlet_path( input_expr.node, - input_expr.subset, dynamic_offset_tasklet, - input_connector, + dst_conn=input_connector, + memlet=dace.Memlet(data=input_expr.node.data, subset=input_expr.subset), ) elif isinstance(input_expr, ValueExpr): if input_connector == "index": @@ -358,18 +362,20 @@ def _make_dynamic_neighbor_offset( {new_index_connector}, f"{new_index_connector} = table[{origin_index.value}, offset]", ) - self._add_input_connection( + self._add_entry_memlet_path( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, - "table", + dst_conn="table", + memlet=dace.Memlet.from_array( + offset_table_node.data, offset_table_node.desc(self.sdfg) + ), ) if isinstance(offset_expr, MemletExpr): - self._add_input_connection( + self._add_entry_memlet_path( offset_expr.node, - offset_expr.subset, tasklet_node, - "offset", + dst_conn="offset", + memlet=dace.Memlet(data=offset_expr.node.data, subset=offset_expr.subset), ) else: self.state.add_edge( From f4d9d898b63accad19408231caa68b85c12d1012 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 3 Jul 2024 09:58:29 +0200 Subject: [PATCH 092/235] Let's see what auto opt can do. --- .../dace_fieldview/gtir_dace_backend.py | 1 + .../transformations/__init__.py | 20 +++++++ .../transformations/auto_opt.py | 52 +++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py index c8c798292a..449d0befcb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py @@ -18,6 +18,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_to_sdfg as gtir_dace_translator, + transformations, # noqa: F401 [unused-import] # For development. ) from gt4py.next.type_system import type_specifications as ts diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py new file mode 100644 index 0000000000..4090950e5f --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -0,0 +1,20 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from .auto_opt import dace_auto_optimize + + +__all__ = [ + "dace_auto_optimize", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py new file mode 100644 index 0000000000..5bab359bef --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -0,0 +1,52 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Fast access to the auto optimization on DaCe.""" + +from typing import Any + +import dace + + +def dace_auto_optimize( + sdfg: dace.SDFG, + device: dace.DeviceType = dace.DeviceType.CPU, + **kwargs: Any, +) -> dace.SDFG: + """This is a convenient wrapper arround DaCe's `auto_optimize` function. + + By default it uses the `CPU` device type. Furthermore, it will first run the + `{In, Out}LocalStorage` transformations of the SDFG. The reason for this is that + empirical observations have shown, that the current auto optimizer has problems + in certain cases and this should prevent some of them. + + Args: + sdfg: The SDFG that should be optimized in place. + device: the device for which optimizations should be done, defaults to CPU. + kwargs: Are forwarded to the underlying auto optimized exposed by DaCe. + """ + from dace.transformation.auto.auto_optimize import auto_optimize as _auto_optimize + from dace.transformation.dataflow import InLocalStorage, OutLocalStorage + + # Now put output storages everywhere to make auto optimizer less likely to fail. + sdfg.apply_transformations_repeated([InLocalStorage, OutLocalStorage]) + + # Now the optimization. + sdfg = _auto_optimize(sdfg, device=device, **kwargs) + + # Now the simplification step. + # This should get rid of some of teh additional transients we have added. + sdfg.simplify() + + return sdfg From 931801106d53a3e7e63067f8a7eea5d46bd1d772 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 09:02:32 +0200 Subject: [PATCH 093/235] Import changes from branch dace-fieldview-neighbors --- .../gtir_builtin_translators.py | 22 ++--- .../dace_fieldview/gtir_dace_backend.py | 8 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 21 ++--- .../runners_tests/test_dace_fieldview.py | 94 +++++++++---------- 4 files changed, 66 insertions(+), 79 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 7bf0348f7d..4b89656161 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -36,6 +36,7 @@ SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] DIMENSION_INDEX_FMT = "i_{dim}" +ITERATOR_INDEX_DTYPE = dace.int32 # type of iterator indexes @dataclass(frozen=True) @@ -126,8 +127,6 @@ def __init__( self.stencil_args = stencil_args def build(self) -> list[SDFGField]: - # type of variables used for field indexing - index_dtype = dace.int32 # first visit the list of arguments and build a symbol map stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] for arg in self.stencil_args: @@ -145,7 +144,7 @@ def build(self) -> list[SDFGField]: indices: dict[Dimension, gtir_to_tasklet.IteratorIndexExpr] = { dim: gtir_to_tasklet.SymbolExpr( dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim.value)), - index_dtype, + ITERATOR_INDEX_DTYPE, ) for dim, _, _ in self.field_domain } @@ -262,13 +261,15 @@ def __init__( # expect true branch as second argument true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) + sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) sdfg.add_edge(true_state, state, dace.InterstateEdge()) self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=(f"not {cond}"))) + sdfg.add_edge( + select_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")) + ) sdfg.add_edge(false_state, state, dace.InterstateEdge()) self.false_br_builder = dataflow_builder.visit( false_expr, sdfg=sdfg, head_state=false_state @@ -290,7 +291,7 @@ def build(self) -> list[SDFGField]: for true_br, false_br in zip(true_br_args, false_br_args, strict=True): true_br_node, true_br_type = true_br assert isinstance(true_br_node, dace.nodes.AccessNode) - false_br_node, false_br_type = false_br + false_br_node, _ = false_br assert isinstance(false_br_node, dace.nodes.AccessNode) desc = true_br_node.desc(self.sdfg) assert false_br_node.desc(self.sdfg) == desc @@ -302,19 +303,16 @@ def build(self) -> list[SDFGField]: true_state.add_nedge( true_br_node, true_br_output_node, - dace.Memlet.from_array( - true_br_output_node.data, true_br_output_node.desc(self.sdfg) - ), + dace.Memlet.from_array(data_name, access_node.desc(self.sdfg)), ) false_br_output_node = false_state.add_access(data_name) false_state.add_nedge( false_br_node, false_br_output_node, - dace.Memlet.from_array( - false_br_output_node.data, false_br_output_node.desc(self.sdfg) - ), + dace.Memlet.from_array(data_name, access_node.desc(self.sdfg)), ) + return output_nodes diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py index c8c798292a..bdf6d401c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py @@ -19,15 +19,17 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_to_sdfg as gtir_dace_translator, ) -from gt4py.next.type_system import type_specifications as ts def build_sdfg_from_gtir( program: itir.Program, - arg_types: list[ts.DataType], offset_provider: dict[str, Connectivity | Dimension], ) -> dace.SDFG: - sdfg_genenerator = gtir_dace_translator.GTIRToSDFG(arg_types, offset_provider) + """ + TODO: enable type inference + program = itir_type_inference.infer(program, offset_provider=offset_provider) + """ + sdfg_genenerator = gtir_dace_translator.GTIRToSDFG(offset_provider) sdfg = sdfg_genenerator.visit(program) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 8150446253..172fc63835 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -45,16 +45,13 @@ class GTIRToSDFG(eve.NodeVisitor): from where to continue building the SDFG. """ - param_types: list[ts.DataType] offset_provider: dict[str, Connectivity | Dimension] symbol_types: dict[str, ts.FieldType | ts.ScalarType] def __init__( self, - param_types: list[ts.DataType], offset_provider: dict[str, Connectivity | Dimension], ): - self.param_types = param_types self.offset_provider = offset_provider self.symbol_types = {} @@ -126,7 +123,9 @@ def _visit_expression( Returns a list of array nodes containing the result fields. - TODO: do we need to return the GT4Py `FieldType`/`ScalarType`? + TODO: Do we need to return the GT4Py `FieldType`/`ScalarType`? It is needed + in case the transient arrays containing the expression result are not guaranteed + to have the same memory layout as the target array. """ field_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( node, sdfg=sdfg, head_state=head_state @@ -158,11 +157,6 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") - if len(node.params) != len(self.param_types): - raise RuntimeError( - "The provided list of parameter types has different length than SDFG parameter list." - ) - sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_fieldview_util.debuginfo(node, sdfg.debuginfo) entry_state = sdfg.add_state("program_entry", is_start_block=True) @@ -174,15 +168,14 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: temp_symbols |= self._add_storage_for_temporary(decl) # define symbols for shape and offsets of temporary arrays as interstate edge symbols - # TODO(edopao): use new `add_state_after` function available in next dace release - head_state = sdfg.add_state_after(entry_state, "init_temps") - sdfg.edges_between(entry_state, head_state)[0].assignments = temp_symbols + head_state = sdfg.add_state_after(entry_state, "init_temps", assignments=temp_symbols) else: head_state = entry_state # add non-transient arrays and/or SDFG symbols for the program arguments - for param, type_ in zip(node.params, self.param_types, strict=True): - self._add_storage(sdfg, str(param.id), type_) + for param in node.params: + assert isinstance(param.type, ts.DataType) + self._add_storage(sdfg, str(param.id), param.type) # visit one statement at a time and expand the SDFG from the current head state for i, stmt in enumerate(node.body): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 77ce24bbe2..f830d68da1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -52,7 +52,11 @@ def test_gtir_copy(): testee = itir.Program( id="gtir_copy", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="size")], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="y", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], declarations=[], body=[ itir.SetAt( @@ -71,8 +75,7 @@ def test_gtir_copy(): a = np.random.rand(N) b = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, SIZE_TYPE] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) sdfg(x=a, y=b, **FSYMBOLS) assert np.allclose(a, b) @@ -85,7 +88,10 @@ def test_gtir_update(): testee = itir.Program( id="gtir_update", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="size")], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], declarations=[], body=[ itir.SetAt( @@ -104,8 +110,7 @@ def test_gtir_update(): a = np.random.rand(N) ref = a + 1.0 - arg_types = [IFTYPE, SIZE_TYPE] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) sdfg(x=a, **FSYMBOLS) assert np.allclose(a, ref) @@ -118,7 +123,12 @@ def test_gtir_sum2(): testee = itir.Program( id="sum_2fields", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="y"), itir.Sym(id="z"), itir.Sym(id="size")], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="y", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], declarations=[], body=[ itir.SetAt( @@ -138,8 +148,7 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, IFTYPE, SIZE_TYPE] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) sdfg(x=a, y=b, z=c, **FSYMBOLS) assert np.allclose(c, (a + b)) @@ -152,7 +161,11 @@ def test_gtir_sum2_sym(): testee = itir.Program( id="sum_2fields_sym", function_definitions=[], - params=[itir.Sym(id="x"), itir.Sym(id="z"), itir.Sym(id="size")], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], declarations=[], body=[ itir.SetAt( @@ -171,8 +184,7 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - arg_types = [IFTYPE, IFTYPE, SIZE_TYPE] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) sdfg(x=a, z=b, **FSYMBOLS) assert np.allclose(b, (a + a)) @@ -209,18 +221,16 @@ def test_gtir_sum3(): b = np.random.rand(N) c = np.random.rand(N) - arg_types = [IFTYPE, IFTYPE, IFTYPE, IFTYPE, SIZE_TYPE] - for i, stencil in enumerate([stencil1, stencil2]): testee = itir.Program( id=f"sum_3fields_{i}", function_definitions=[], params=[ - itir.Sym(id="x"), - itir.Sym(id="y"), - itir.Sym(id="w"), - itir.Sym(id="z"), - itir.Sym(id="size"), + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="y", type=IFTYPE), + itir.Sym(id="w", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ @@ -232,7 +242,7 @@ def test_gtir_sum3(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) d = np.empty_like(a) @@ -248,13 +258,13 @@ def test_gtir_select(): id="select_2sums", function_definitions=[], params=[ - itir.Sym(id="x"), - itir.Sym(id="y"), - itir.Sym(id="w"), - itir.Sym(id="z"), - itir.Sym(id="cond"), - itir.Sym(id="scalar"), - itir.Sym(id="size"), + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="y", type=IFTYPE), + itir.Sym(id="w", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="cond", type=ts.ScalarType(ts.ScalarKind.BOOL)), + itir.Sym(id="scalar", type=ts.ScalarType(ts.ScalarKind.FLOAT64)), + itir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ @@ -294,16 +304,7 @@ def test_gtir_select(): b = np.random.rand(N) c = np.random.rand(N) - arg_types = [ - IFTYPE, - IFTYPE, - IFTYPE, - IFTYPE, - ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.FLOAT64), - SIZE_TYPE, - ] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) for s in [False, True]: d = np.empty_like(a) @@ -319,11 +320,11 @@ def test_gtir_select_nested(): id="select_nested", function_definitions=[], params=[ - itir.Sym(id="x"), - itir.Sym(id="z"), - itir.Sym(id="cond_1"), - itir.Sym(id="cond_2"), - itir.Sym(id="size"), + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="cond_1", type=ts.ScalarType(ts.ScalarKind.BOOL)), + itir.Sym(id="cond_2", type=ts.ScalarType(ts.ScalarKind.BOOL)), + itir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ @@ -364,14 +365,7 @@ def test_gtir_select_nested(): a = np.random.rand(N) - arg_types = [ - IFTYPE, - IFTYPE, - ts.ScalarType(ts.ScalarKind.BOOL), - ts.ScalarType(ts.ScalarKind.BOOL), - SIZE_TYPE, - ] - sdfg = dace_backend.build_sdfg_from_gtir(testee, arg_types, {}) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) for s1 in [False, True]: for s2 in [False, True]: From 25b9048391181cf80ef92a5b4bf36b71941590b8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 09:29:16 +0200 Subject: [PATCH 094/235] Support field with start offset --- .../gtir_builtin_translators.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 4b89656161..3e390d3e23 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -15,7 +15,7 @@ import abc from dataclasses import dataclass -from typing import Callable, TypeAlias, final +from typing import Callable, Optional, TypeAlias, final import dace import dace.subsets as sbs @@ -56,7 +56,10 @@ def __call__(self) -> list[SDFGField]: @final def add_local_storage( - self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] + self, + data_type: ts.FieldType | ts.ScalarType, + shape: Optional[list[dace.symbolic.SymbolicType]] = None, + offset: Optional[list[dace.symbolic.SymbolicType]] = None, ) -> dace.nodes.AccessNode: """ Allocates temporary storage to be used in the local scope for intermediate results. @@ -64,11 +67,14 @@ def add_local_storage( The storage is allocate with unique names to enable SSA optimization in the compilation phase. """ if isinstance(data_type, ts.FieldType): + assert shape assert len(data_type.dims) == len(shape) dtype = dace_fieldview_util.as_dace_type(data_type.dtype) - name, _ = self.sdfg.add_array("var", shape, dtype, find_new_name=True, transient=True) + name, _ = self.sdfg.add_array( + "var", shape, dtype, offset=offset, find_new_name=True, transient=True + ) else: - assert len(shape) == 0 + assert not shape dtype = dace_fieldview_util.as_dace_type(data_type) name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) return self.head_state.add_access(name) @@ -190,6 +196,9 @@ def build(self) -> list[SDFGField]: (ub - lb) for _, lb, ub in self.field_domain ] + field_offset: Optional[list[dace.symbolic.SymbolicType]] = None + if any(lb != 0 for _, lb, _ in self.field_domain): + field_offset = [lb for _, lb, _ in self.field_domain] if isinstance(output_desc, dace.data.Array): raise NotImplementedError else: @@ -199,13 +208,15 @@ def build(self) -> list[SDFGField]: # TODO: use `self.field_dtype` directly, without passing through `dtype` field_type = ts.FieldType(field_dims, dtype) - field_node = self.add_local_storage(field_type, field_shape) + field_node = self.add_local_storage(field_type, field_shape, field_offset) # assume tasklet with single output output_subset = [ DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in self.field_domain ] if isinstance(output_desc, dace.data.Array): + # additional local dimension for neighbors + assert output_desc.offset is None output_subset.extend(f"0:{size}" for size in output_desc.shape) self.head_state.add_memlet_path( @@ -295,22 +306,21 @@ def build(self) -> list[SDFGField]: assert isinstance(false_br_node, dace.nodes.AccessNode) desc = true_br_node.desc(self.sdfg) assert false_br_node.desc(self.sdfg) == desc - access_node = self.add_local_storage(true_br_type, desc.shape) - output_nodes.append((access_node, true_br_type)) + data_name, _ = self.sdfg.add_temp_transient_like(desc) + output_nodes.append((self.head_state.add_access(data_name), true_br_type)) - data_name = access_node.data true_br_output_node = true_state.add_access(data_name) true_state.add_nedge( true_br_node, true_br_output_node, - dace.Memlet.from_array(data_name, access_node.desc(self.sdfg)), + dace.Memlet.from_array(data_name, desc), ) false_br_output_node = false_state.add_access(data_name) false_state.add_nedge( false_br_node, false_br_output_node, - dace.Memlet.from_array(data_name, access_node.desc(self.sdfg)), + dace.Memlet.from_array(data_name, desc), ) return output_nodes @@ -348,7 +358,7 @@ def build(self) -> list[SDFGField]: {"__out"}, f"__out = {self.sym_name}", ) - sym_node = self.add_local_storage(self.sym_type, shape=[]) + sym_node = self.add_local_storage(self.sym_type) self.head_state.add_edge( tasklet_node, "__out", From f6e5b7c60bd304d3653bf52b7021b30ee40fabec Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 09:56:23 +0200 Subject: [PATCH 095/235] Add test coverage for temporary with start offset (cartesian shift) --- .../runners_tests/test_dace_fieldview.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 173ce325b1..45921c2bc9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -419,7 +419,11 @@ def test_gtir_cartesian_shift(): DELTA = 3 OFFSET = 1 domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")( + itir.AxisLiteral(value=IDim.value), + 0, + im.minus(itir.SymRef(id="size"), itir.Literal(value=str(OFFSET), type=SIZE_TYPE)), + ) ) # cartesian shift with literal integer offset @@ -518,7 +522,7 @@ def test_gtir_cartesian_shift(): )(), ) - a = np.random.rand(N + OFFSET) + a = np.random.rand(N) a_offset = np.full(N, OFFSET, dtype=np.int32) b = np.empty(N) @@ -558,10 +562,9 @@ def test_gtir_cartesian_shift(): sdfg = dace_backend.build_sdfg_from_gtir(testee, offset_provider) FSYMBOLS_tmp = FSYMBOLS.copy() - FSYMBOLS_tmp["__x_size_0"] = N + OFFSET FSYMBOLS_tmp["__x_offset_stride_0"] = 1 sdfg(x=a, x_offset=a_offset, y=b, **FSYMBOLS_tmp) - assert np.allclose(a[OFFSET:] + DELTA, b) + assert np.allclose(a[OFFSET:] + DELTA, b[:-OFFSET]) def test_gtir_connectivity_shift(): From d7312fa53264b8c6afe62f744ad6df26bfbbe657 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 09:29:16 +0200 Subject: [PATCH 096/235] Support field with start offset --- .../gtir_builtin_translators.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 4b89656161..3f8eae61ed 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -15,7 +15,7 @@ import abc from dataclasses import dataclass -from typing import Callable, TypeAlias, final +from typing import Callable, Optional, TypeAlias, final import dace import dace.subsets as sbs @@ -56,7 +56,10 @@ def __call__(self) -> list[SDFGField]: @final def add_local_storage( - self, data_type: ts.FieldType | ts.ScalarType, shape: list[str] + self, + data_type: ts.FieldType | ts.ScalarType, + shape: Optional[list[dace.symbolic.SymbolicType]] = None, + offset: Optional[list[dace.symbolic.SymbolicType]] = None, ) -> dace.nodes.AccessNode: """ Allocates temporary storage to be used in the local scope for intermediate results. @@ -64,11 +67,14 @@ def add_local_storage( The storage is allocate with unique names to enable SSA optimization in the compilation phase. """ if isinstance(data_type, ts.FieldType): + assert shape assert len(data_type.dims) == len(shape) dtype = dace_fieldview_util.as_dace_type(data_type.dtype) - name, _ = self.sdfg.add_array("var", shape, dtype, find_new_name=True, transient=True) + name, _ = self.sdfg.add_array( + "var", shape, dtype, offset=offset, find_new_name=True, transient=True + ) else: - assert len(shape) == 0 + assert not shape dtype = dace_fieldview_util.as_dace_type(data_type) name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) return self.head_state.add_access(name) @@ -190,6 +196,9 @@ def build(self) -> list[SDFGField]: (ub - lb) for _, lb, ub in self.field_domain ] + field_offset: Optional[list[dace.symbolic.SymbolicType]] = None + if any(lb != 0 for _, lb, _ in self.field_domain): + field_offset = [lb for _, lb, _ in self.field_domain] if isinstance(output_desc, dace.data.Array): raise NotImplementedError else: @@ -199,13 +208,15 @@ def build(self) -> list[SDFGField]: # TODO: use `self.field_dtype` directly, without passing through `dtype` field_type = ts.FieldType(field_dims, dtype) - field_node = self.add_local_storage(field_type, field_shape) + field_node = self.add_local_storage(field_type, field_shape, field_offset) # assume tasklet with single output output_subset = [ DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in self.field_domain ] if isinstance(output_desc, dace.data.Array): + # additional local dimension for neighbors + assert set(output_desc.offset) == {0} output_subset.extend(f"0:{size}" for size in output_desc.shape) self.head_state.add_memlet_path( @@ -295,22 +306,21 @@ def build(self) -> list[SDFGField]: assert isinstance(false_br_node, dace.nodes.AccessNode) desc = true_br_node.desc(self.sdfg) assert false_br_node.desc(self.sdfg) == desc - access_node = self.add_local_storage(true_br_type, desc.shape) - output_nodes.append((access_node, true_br_type)) + data_name, _ = self.sdfg.add_temp_transient_like(desc) + output_nodes.append((self.head_state.add_access(data_name), true_br_type)) - data_name = access_node.data true_br_output_node = true_state.add_access(data_name) true_state.add_nedge( true_br_node, true_br_output_node, - dace.Memlet.from_array(data_name, access_node.desc(self.sdfg)), + dace.Memlet.from_array(data_name, desc), ) false_br_output_node = false_state.add_access(data_name) false_state.add_nedge( false_br_node, false_br_output_node, - dace.Memlet.from_array(data_name, access_node.desc(self.sdfg)), + dace.Memlet.from_array(data_name, desc), ) return output_nodes @@ -348,7 +358,7 @@ def build(self) -> list[SDFGField]: {"__out"}, f"__out = {self.sym_name}", ) - sym_node = self.add_local_storage(self.sym_type, shape=[]) + sym_node = self.add_local_storage(self.sym_type) self.head_state.add_edge( tasklet_node, "__out", From c4f273873d00fbcbd2a9248b8aac63f1e05dde54 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 13:37:39 +0200 Subject: [PATCH 097/235] Test IR updated for literal operand --- .../runners/dace_fieldview/gtir_builtin_translators.py | 6 ++++-- .../runners/dace_fieldview/gtir_to_sdfg.py | 5 +++++ .../runners_tests/test_dace_fieldview.py | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 3f8eae61ed..7d2b5ceb2e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -351,9 +351,11 @@ def build(self) -> list[SDFGField]: else: # scalar symbols are passed to the SDFG as symbols: build tasklet node # to write the symbol to a scalar access node - assert self.sym_name in self.sdfg.symbols + tasklet_name = ( + f"get_{self.sym_name}" if self.sym_name in self.sdfg.symbols else "get_value" + ) tasklet_node = self.head_state.add_tasklet( - f"get_{self.sym_name}", + tasklet_name, {}, {"__out"}, f"__out = {self.sym_name}", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 172fc63835..cb7e191141 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -253,6 +253,11 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: """ raise RuntimeError("Unexpected 'itir.Lambda' node encountered in GTIR.") + def visit_Literal( + self, node: itir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState + ) -> gtir_builtin_translators.SDFGFieldBuilder: + return gtir_builtin_translators.SymbolRef(sdfg, head_state, node.value, node.type) + def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState ) -> gtir_builtin_translators.SDFGFieldBuilder: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index f830d68da1..ed8a54426d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -97,10 +97,10 @@ def test_gtir_update(): itir.SetAt( expr=im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 1.0)), + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), domain, ) - )("x"), + )("x", 1.0), domain=domain, target=itir.SymRef(id="x"), ) From 0fd0b657888689a2b4a97ca6bc3ad77cf82ab10d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 13:45:32 +0200 Subject: [PATCH 098/235] Add test coverage to previous commit --- .../runners_tests/test_dace_fieldview.py | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index ed8a54426d..03f71e08fb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -85,35 +85,43 @@ def test_gtir_update(): domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( - id="gtir_update", - function_definitions=[], - params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )("x", 1.0), - domain=domain, - target=itir.SymRef(id="x"), - ) - ], - ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1.0)), + domain, + ) + )("x") + stencil2 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )("x", 1.0) - a = np.random.rand(N) - ref = a + 1.0 + for i, stencil in enumerate([stencil1, stencil2]): + testee = itir.Program( + id=f"gtir_update_{i}", + function_definitions=[], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil, + domain=domain, + target=itir.SymRef(id="x"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + a = np.random.rand(N) + ref = a + 1.0 - sdfg(x=a, **FSYMBOLS) - assert np.allclose(a, ref) + sdfg(x=a, **FSYMBOLS) + assert np.allclose(a, ref) def test_gtir_sum2(): @@ -334,25 +342,25 @@ def test_gtir_select_nested(): im.deref("cond_1"), im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 1)), + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), domain, ) - )("x"), + )("x", 1), im.call( im.call("select")( im.deref("cond_2"), im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 2)), + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), domain, ) - )("x"), + )("x", 2), im.call( im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 3)), + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), domain, ) - )("x"), + )("x", 3), ) )(), ) From 38d2720822bafb0f8c098b2cd337f143e41469d0 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 4 Jul 2024 16:33:38 +0200 Subject: [PATCH 099/235] Refactor PrimitiveTranslator interface --- .../gtir_builtin_translators.py | 238 +++++++----------- .../runners/dace_fieldview/gtir_to_sdfg.py | 42 +--- .../runners/dace_fieldview/sdfg_builder.py | 29 +++ 3 files changed, 130 insertions(+), 179 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 7d2b5ceb2e..14b6f4080a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -15,25 +15,24 @@ import abc from dataclasses import dataclass -from typing import Callable, Optional, TypeAlias, final +from typing import Optional, TypeAlias import dace import dace.subsets as sbs -from gt4py import eve -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_to_tasklet, utility as dace_fieldview_util, ) +from gt4py.next.program_processors.runners.dace_fieldview.sdfg_builder import SDFGBuilder from gt4py.next.type_system import type_specifications as ts # Define aliases for return types SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] -SDFGFieldBuilder: TypeAlias = Callable[[], list[SDFGField]] DIMENSION_INDEX_FMT = "i_{dim}" ITERATOR_INDEX_DTYPE = dace.int32 # type of iterator indexes @@ -41,46 +40,14 @@ @dataclass(frozen=True) class PrimitiveTranslator(abc.ABC): - sdfg: dace.SDFG - head_state: dace.SDFGState - - @final - def __call__(self) -> list[SDFGField]: - """The callable interface is used to build the dataflow graph. - - It allows to build the dataflow graph inside a given state starting - from the innermost nodes, by propagating the intermediate results - as access nodes to temporary local storage. - """ - return self.build() - - @final - def add_local_storage( - self, - data_type: ts.FieldType | ts.ScalarType, - shape: Optional[list[dace.symbolic.SymbolicType]] = None, - offset: Optional[list[dace.symbolic.SymbolicType]] = None, - ) -> dace.nodes.AccessNode: - """ - Allocates temporary storage to be used in the local scope for intermediate results. - - The storage is allocate with unique names to enable SSA optimization in the compilation phase. - """ - if isinstance(data_type, ts.FieldType): - assert shape - assert len(data_type.dims) == len(shape) - dtype = dace_fieldview_util.as_dace_type(data_type.dtype) - name, _ = self.sdfg.add_array( - "var", shape, dtype, offset=offset, find_new_name=True, transient=True - ) - else: - assert not shape - dtype = dace_fieldview_util.as_dace_type(data_type) - name, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) - return self.head_state.add_access(name) - @abc.abstractmethod - def build(self) -> list[SDFGField]: + def __call__( + self, + node: itir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: SDFGBuilder, + ) -> list[SDFGField]: # keep only call, same interface for all primitives """Creates the dataflow subgraph representing a GTIR builtin function. This method is used by derived classes to build a specialized subgraph @@ -96,28 +63,21 @@ def build(self) -> list[SDFGField]: class AsFieldOp(PrimitiveTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" - TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] + callable_args: list[itir.Expr] - stencil_expr: itir.Lambda - stencil_args: list[SDFGFieldBuilder] - field_domain: list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] - field_dtype: ts.ScalarType - offset_provider: dict[str, Connectivity | Dimension] + def __init__(self, node_args: list[itir.Expr]): + self.callable_args = node_args - def __init__( + def __call__( self, + node: itir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - node: itir.FunCall, - stencil_args: list[SDFGFieldBuilder], - offset_provider: dict[str, Connectivity | Dimension], - ): - super().__init__(sdfg, state) - self.offset_provider = offset_provider - - assert cpm.is_call_to(node.fun, "as_fieldop") - assert len(node.fun.args) == 2 - stencil_expr, domain_expr = node.fun.args + sdfg_builder: SDFGBuilder, + ) -> list[SDFGField]: + assert cpm.is_call_to(node, "as_fieldop") + assert len(node.args) == 2 + stencil_expr, domain_expr = node.args # expect stencil (represented as a lambda function) as first argument assert isinstance(stencil_expr, itir.Lambda) # the domain of the field operator is passed as second argument @@ -126,19 +86,14 @@ def __init__( # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + field_domain = dace_fieldview_util.get_field_domain(domain_expr) - self.field_domain = dace_fieldview_util.get_field_domain(domain_expr) - self.field_dtype = node_type - self.stencil_expr = stencil_expr - self.stencil_args = stencil_args - - def build(self) -> list[SDFGField]: # first visit the list of arguments and build a symbol map stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in self.stencil_args: - arg_nodes = arg() - assert len(arg_nodes) == 1 - data_node, arg_type = arg_nodes[0] + for arg in self.callable_args: + fields: list[SDFGField] = sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) + assert len(fields) == 1 + data_node, arg_type = fields[0] # require all argument nodes to be data access nodes (no symbols) assert isinstance(data_node, dace.nodes.AccessNode) @@ -152,7 +107,7 @@ def build(self) -> list[SDFGField]: dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim.value)), ITERATOR_INDEX_DTYPE, ) - for dim, _, _ in self.field_domain + for dim, _, _ in field_domain } iterator_arg = gtir_to_tasklet.IteratorExpr( data_node, @@ -163,63 +118,61 @@ def build(self) -> list[SDFGField]: # create map range corresponding to the field operator domain map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" - for dim, lb, ub in self.field_domain + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in field_domain } - me, mx = self.head_state.add_map("field_op", map_ranges) + me, mx = state.add_map("field_op", map_ranges) # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet( - self.sdfg, self.head_state, me, self.offset_provider - ) - output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, me, sdfg_builder.offset_provider) + output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) - output_desc = output_expr.node.desc(self.sdfg) + output_desc = output_expr.node.desc(sdfg) # retrieve the tasklet node which writes the result - last_node = self.head_state.in_edges(output_expr.node)[0].src + last_node = state.in_edges(output_expr.node)[0].src if isinstance(last_node, dace.nodes.Tasklet): # the last transient node can be deleted - last_node_connector = self.head_state.in_edges(output_expr.node)[0].src_conn - self.head_state.remove_node(output_expr.node) + last_node_connector = state.in_edges(output_expr.node)[0].src_conn + state.remove_node(output_expr.node) if len(last_node.in_connectors) == 0: # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - self.head_state.add_nedge(me, last_node, dace.Memlet()) + state.add_nedge(me, last_node, dace.Memlet()) else: last_node = output_expr.node last_node_connector = None # allocate local temporary storage for the result field - field_dims = [dim for dim, _, _ in self.field_domain] + field_dims = [dim for dim, _, _ in field_domain] field_shape = [ # diff between upper and lower bound (ub - lb) - for _, lb, ub in self.field_domain + for _, lb, ub in field_domain ] field_offset: Optional[list[dace.symbolic.SymbolicType]] = None - if any(lb != 0 for _, lb, _ in self.field_domain): - field_offset = [lb for _, lb, _ in self.field_domain] + if any(lb != 0 for _, lb, _ in field_domain): + field_offset = [lb for _, lb, _ in field_domain] if isinstance(output_desc, dace.data.Array): raise NotImplementedError else: assert isinstance(output_expr.field_type, ts.ScalarType) - # TODO: enable `assert output_expr.field_type == self.field_dtype`, remove variable `dtype` - dtype = output_expr.field_type + # TODO: enable `assert output_expr.field_type == node_type`, remove variable `dtype` + node_type = output_expr.field_type - # TODO: use `self.field_dtype` directly, without passing through `dtype` - field_type = ts.FieldType(field_dims, dtype) - field_node = self.add_local_storage(field_type, field_shape, field_offset) + # TODO: use `field_type` directly, without passing through `dtype` + field_type = ts.FieldType(field_dims, node_type) + temp_name, _ = sdfg.add_temp_transient( + field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset + ) + field_node = state.add_access(temp_name) # assume tasklet with single output - output_subset = [ - DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in self.field_domain - ] + output_subset = [DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in field_domain] if isinstance(output_desc, dace.data.Array): # additional local dimension for neighbors assert set(output_desc.offset) == {0} output_subset.extend(f"0:{size}" for size in output_desc.shape) - self.head_state.add_memlet_path( + state.add_memlet_path( last_node, mx, field_node, @@ -233,21 +186,16 @@ def build(self) -> list[SDFGField]: class Select(PrimitiveTranslator): """Generates the dataflow subgraph for the `select` builtin function.""" - true_br_builder: SDFGFieldBuilder - false_br_builder: SDFGFieldBuilder - - def __init__( + def __call__( self, + node: itir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - dataflow_builder: eve.NodeVisitor, - node: itir.FunCall, - ): - super().__init__(sdfg, state) - - assert cpm.is_call_to(node.fun, "select") - assert len(node.fun.args) == 3 - cond_expr, true_expr, false_expr = node.fun.args + sdfg_builder: SDFGBuilder, + ) -> list[SDFGField]: + assert cpm.is_call_to(node, "select") + assert len(node.args) == 3 + cond_expr, true_expr, false_expr = node.args # expect condition as first argument cond = dace_fieldview_util.get_symbolic_expr(cond_expr) @@ -274,7 +222,6 @@ def __init__( true_state = sdfg.add_state(state.label + "_true_branch") sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) sdfg.add_edge(true_state, state, dace.InterstateEdge()) - self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") @@ -282,21 +229,9 @@ def __init__( select_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")) ) sdfg.add_edge(false_state, state, dace.InterstateEdge()) - self.false_br_builder = dataflow_builder.visit( - false_expr, sdfg=sdfg, head_state=false_state - ) - def build(self) -> list[SDFGField]: - # retrieve true/false states as predecessors of head state - branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) - assert len(branch_states) == 2 - if branch_states[0].label.endswith("_true_branch"): - true_state, false_state = branch_states - else: - false_state, true_state = branch_states - - true_br_args = self.true_br_builder() - false_br_args = self.false_br_builder() + true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) + false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state) output_nodes = [] for true_br, false_br in zip(true_br_args, false_br_args, strict=True): @@ -304,10 +239,10 @@ def build(self) -> list[SDFGField]: assert isinstance(true_br_node, dace.nodes.AccessNode) false_br_node, _ = false_br assert isinstance(false_br_node, dace.nodes.AccessNode) - desc = true_br_node.desc(self.sdfg) - assert false_br_node.desc(self.sdfg) == desc - data_name, _ = self.sdfg.add_temp_transient_like(desc) - output_nodes.append((self.head_state.add_access(data_name), true_br_type)) + desc = true_br_node.desc(sdfg) + assert false_br_node.desc(sdfg) == desc + data_name, _ = sdfg.add_temp_transient_like(desc) + output_nodes.append((state.add_access(data_name), true_br_type)) true_br_output_node = true_state.add_access(data_name) true_state.add_nedge( @@ -329,39 +264,44 @@ def build(self) -> list[SDFGField]: class SymbolRef(PrimitiveTranslator): """Generates the dataflow subgraph for a `ir.SymRef` node.""" - sym_name: str - sym_type: ts.FieldType | ts.ScalarType - - def __init__( + def __call__( self, + node: itir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sym_name: str, - sym_type: ts.FieldType | ts.ScalarType, - ): - super().__init__(sdfg, state) - self.sym_name = sym_name - self.sym_type = sym_type - - def build(self) -> list[SDFGField]: - if isinstance(self.sym_type, ts.FieldType): + sdfg_builder: SDFGBuilder, + ) -> list[SDFGField]: + assert isinstance(node, (itir.Literal, itir.SymRef)) + + data_type: ts.FieldType | ts.ScalarType + if isinstance(node, itir.Literal): + sym_value = node.value + data_type = node.type + tasklet_name = "get_literal" + else: + sym_value = str(node.id) + assert sym_value in sdfg_builder.symbol_types + data_type = sdfg_builder.symbol_types[sym_value] + tasklet_name = f"get_{sym_value}" + + if isinstance(data_type, ts.FieldType): # add access node to current state - sym_node = self.head_state.add_access(self.sym_name) + sym_node = state.add_access(sym_value) else: # scalar symbols are passed to the SDFG as symbols: build tasklet node # to write the symbol to a scalar access node - tasklet_name = ( - f"get_{self.sym_name}" if self.sym_name in self.sdfg.symbols else "get_value" - ) - tasklet_node = self.head_state.add_tasklet( + tasklet_node = state.add_tasklet( tasklet_name, {}, {"__out"}, - f"__out = {self.sym_name}", + f"__out = {sym_value}", + ) + temp_name, _ = sdfg.add_temp_transient( + (1,), dace_fieldview_util.as_dace_type(data_type) ) - sym_node = self.add_local_storage(self.sym_type) - self.head_state.add_edge( + sym_node = state.add_access(temp_name) + state.add_edge( tasklet_node, "__out", sym_node, @@ -369,4 +309,4 @@ def build(self) -> list[SDFGField]: dace.Memlet(data=sym_node.data, subset="0"), ) - return [(sym_node, self.sym_type)] + return [(sym_node, data_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index cb7e191141..1122f244f6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -21,7 +21,6 @@ import dace -from gt4py import eve from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -29,10 +28,11 @@ gtir_builtin_translators, utility as dace_fieldview_util, ) +from gt4py.next.program_processors.runners.dace_fieldview.sdfg_builder import SDFGBuilder from gt4py.next.type_system import type_specifications as ts -class GTIRToSDFG(eve.NodeVisitor): +class GTIRToSDFG(SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation @@ -45,15 +45,11 @@ class GTIRToSDFG(eve.NodeVisitor): from where to continue building the SDFG. """ - offset_provider: dict[str, Connectivity | Dimension] - symbol_types: dict[str, ts.FieldType | ts.ScalarType] - def __init__( self, offset_provider: dict[str, Connectivity | Dimension], ): - self.offset_provider = offset_provider - self.symbol_types = {} + super().__init__(offset_provider, symbol_types={}) def _make_array_shape_and_strides( self, name: str, dims: Sequence[Dimension] @@ -127,10 +123,9 @@ def _visit_expression( in case the transient arrays containing the expression result are not guaranteed to have the same memory layout as the target array. """ - field_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( + results: list[gtir_builtin_translators.SDFGField] = self.visit( node, sdfg=sdfg, head_state=head_state ) - results = field_builder() field_nodes = [] for node, _ in results: @@ -225,23 +220,13 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) def visit_FunCall( self, node: itir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> gtir_builtin_translators.SDFGFieldBuilder: - # first visit the argument nodes - arg_builders = [] - for arg in node.args: - arg_builder: gtir_builtin_translators.SDFGFieldBuilder = self.visit( - arg, sdfg=sdfg, head_state=head_state - ) - arg_builders.append(arg_builder) - + ) -> list[gtir_builtin_translators.SDFGField]: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtin_translators.AsFieldOp( - sdfg, head_state, node, arg_builders, self.offset_provider - ) + return gtir_builtin_translators.AsFieldOp(node.args)(node.fun, sdfg, head_state, self) elif cpm.is_call_to(node.fun, "select"): - assert len(arg_builders) == 0 - return gtir_builtin_translators.Select(sdfg, head_state, self, node) + assert len(node.args) == 0 + return gtir_builtin_translators.Select()(node.fun, sdfg, head_state, self) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -255,13 +240,10 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: def visit_Literal( self, node: itir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> gtir_builtin_translators.SDFGFieldBuilder: - return gtir_builtin_translators.SymbolRef(sdfg, head_state, node.value, node.type) + ) -> list[gtir_builtin_translators.SDFGField]: + return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) def visit_SymRef( self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> gtir_builtin_translators.SDFGFieldBuilder: - symbol_name = str(node.id) - assert symbol_name in self.symbol_types - symbol_type = self.symbol_types[symbol_name] - return gtir_builtin_translators.SymbolRef(sdfg, head_state, symbol_name, symbol_type) + ) -> list[gtir_builtin_translators.SDFGField]: + return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py new file mode 100644 index 0000000000..ed5a48fbba --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py @@ -0,0 +1,29 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Visitor interface to build an SDFG dataflow. + +""" + +from dataclasses import dataclass + +from gt4py import eve +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.type_system import type_specifications as ts + + +@dataclass(frozen=True) +class SDFGBuilder(eve.NodeVisitor): + offset_provider: dict[str, Connectivity | Dimension] + symbol_types: dict[str, ts.FieldType | ts.ScalarType] \ No newline at end of file From d3541c15957e251bf6bb16a5e8dc31099616da4b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 08:38:53 +0200 Subject: [PATCH 100/235] Made a small modfication to some code. It is now possible to write `plus(1.0, deref("it"))` befor one had to write `plus(dref("it"), 1.0)`. --- .../runners/dace_fieldview/gtir_to_tasklet.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index a662e52dfb..6c7c0bad36 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -720,8 +720,14 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value ) # TODO: use type inference to determine the result type - if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): - dtype = node_connections["__inp_0"].node.desc(self.sdfg).dtype + if len(node_connections) == 1: + dtype = None + for conn_name in ["__inp_0", "__inp_1"]: + if conn_name in node_connections: + dtype = node_connections[conn_name].node.desc(self.sdfg).dtype + break + if dtype is None: + raise ValueError("Failed to dtermine the type") else: node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) dtype = dace_fieldview_util.as_dace_type(node_type) From e855ef9997fbfb7107ea2b65169f6469e0ad63e2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Jul 2024 09:00:27 +0200 Subject: [PATCH 101/235] Fix formatting --- .../program_processors/runners/dace_fieldview/sdfg_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py index ed5a48fbba..dafddbcd1a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py @@ -26,4 +26,4 @@ @dataclass(frozen=True) class SDFGBuilder(eve.NodeVisitor): offset_provider: dict[str, Connectivity | Dimension] - symbol_types: dict[str, ts.FieldType | ts.ScalarType] \ No newline at end of file + symbol_types: dict[str, ts.FieldType | ts.ScalarType] From 572650983c6f6122a232a1043a9abc516de10700 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 09:19:37 +0200 Subject: [PATCH 102/235] Started with a first nabla stuff. --- my_playground/nambla4.py | 185 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 my_playground/nambla4.py diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py new file mode 100644 index 0000000000..362d31c527 --- /dev/null +++ b/my_playground/nambla4.py @@ -0,0 +1,185 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +""" +Implementation of the Nabla4 Stencil. +""" + +import copy + +from gt4py.next.common import NeighborTable +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.ffront.fbuiltins import Field +from gt4py.next.program_processors.runners import dace_fieldview as dace_backend +from gt4py.next.type_system import type_specifications as ts +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + KDim, + Cell, + Edge, + IDim, + JDim, + MeshDescriptor, + V2EDim, + Vertex, + simple_mesh, + skip_value_mesh, +) +from typing import Sequence, Any +from functools import reduce +import numpy as np + +import dace + +wpfloat = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) + + +def nabla4_np( + N: Field[[Edge, KDim], wpfloat], + T: Field[[Edge, KDim], wpfloat], + **kwargs, # Allows to use the same call argument object as for the SDFG +) -> Field[[Edge, KDim], wpfloat]: + return N + T + + +def dace_strides( + array: np.ndarray, + name: None | str = None, +) -> tuple[int, ...] | dict[str, int]: + if not hasattr(array, "strides"): + return {} + strides = array.strides + if hasattr(array, "itemsize"): + strides = tuple(stride // array.itemsize for stride in strides) + if name is not None: + strides = {f"__{name}_stride_{i}": stride for i, stride in enumerate(strides)} + return strides + + +def dace_shape( + array: np.ndarray, + name: str, +) -> dict[str, int]: + if not hasattr(array, "shape"): + return {} + return {f"__{name}_size_{i}": size for i, size in enumerate(array.shape)} + + +def make_syms(**kwargs: np.ndarray) -> dict[str, int]: + SYMBS = {} + for name, array in kwargs.items(): + SYMBS.update(**dace_shape(array, name)) + SYMBS.update(**dace_strides(array, name)) + return SYMBS + + +def build_nambla4_gtir(): + edge_k_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "num_edges"), + im.call("named_range")(itir.AxisLiteral(value=KDim.value), 0, "num_k_levels"), + ) + edge_k_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "num_edges"), + im.call("named_range")(itir.AxisLiteral(value=KDim.value), 0, "num_k_levels"), + ) + + num_edges = 27 + num_k_levels = 10 + + nabla4prog = itir.Program( + id="nabla4_partial", + function_definitions=[], + params=[ + itir.Sym(id="T"), + itir.Sym(id="N"), + itir.Sym(id="nab4"), + itir.Sym(id="num_edges"), + itir.Sym(id="num_k_levels"), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("N_it", "T_it")(im.plus(im.deref("T_it"), im.deref("N_it"))), + edge_k_domain, + ) + )("N", "T"), + domain=edge_k_domain, + target=itir.SymRef(id="nab4"), + ) + ], + ) + + EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) + + arg_types = [ + EK_FTYPE, + EK_FTYPE, + EK_FTYPE, + SIZE_TYPE, + SIZE_TYPE, + ] + offset_provider = {} + + N = np.random.rand(num_edges, num_k_levels) + T = np.random.rand(num_k_levels) + nab4 = np.empty_like(N) + + sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, arg_types, offset_provider) + + call_args = dict( + T=T, + N=N, + nab4=nab4, + num_edges=num_edges, + num_k_levels=num_k_levels, + ) + SYMBS = make_syms(**call_args) + + sdfg(**call_args, **SYMBS) + ref = nabla4_np(**call_args) + + assert np.allclose(ref, nab4) + + +if "__main__" == __name__: + build_nambla4_gtir() + + """ + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("NpT")( + im.multiplies_(im.deref("NpT"), 4.0) + ), + edge_k_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("N", "T")( + im.plus(im.deref("N"), im.deref("T")) + ), + edge_k_domain, + ) + )("N", "T") + ), + domain=edge_k_domain, + target=itir.SymRef(id="nab4"), + ) + ], + """ From e44f3a258a6aca87f84e81daf17f601f491cee23 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 09:20:56 +0200 Subject: [PATCH 103/235] It seems that local storage does not work well with this transformer. regardless of our experiences in the last transformation stuff. --- .../runners/dace_fieldview/transformations/auto_opt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 5bab359bef..9a7cd5da33 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -37,10 +37,9 @@ def dace_auto_optimize( kwargs: Are forwarded to the underlying auto optimized exposed by DaCe. """ from dace.transformation.auto.auto_optimize import auto_optimize as _auto_optimize - from dace.transformation.dataflow import InLocalStorage, OutLocalStorage # Now put output storages everywhere to make auto optimizer less likely to fail. - sdfg.apply_transformations_repeated([InLocalStorage, OutLocalStorage]) + # sdfg.apply_transformations_repeated([InLocalStorage, OutLocalStorage]) # noqa: ERA001 [commented-out-code] # Now the optimization. sdfg = _auto_optimize(sdfg, device=device, **kwargs) From 4cff0711f6fc4c7be07c18fe68358a420824bdda Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Jul 2024 09:42:25 +0200 Subject: [PATCH 104/235] Fix for domain horzontal/vertical dims --- .../dace_fieldview/gtir_builtin_translators.py | 2 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 5 ++--- .../runners/dace_fieldview/utility.py | 18 +++++++----------- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 14b6f4080a..6f3731a4af 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -86,7 +86,7 @@ def __call__( # add local storage to compute the field operator over the given domain # TODO: use type inference to determine the result type node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - field_domain = dace_fieldview_util.get_field_domain(domain_expr) + field_domain = dace_fieldview_util.get_domain(domain_expr) # first visit the list of arguments and build a symbol map stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 1122f244f6..7cd5c43f8a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -196,7 +196,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) target_nodes = self._visit_expression(stmt.target, sdfg, state) # convert domain expression to dictionary to ease access to dimension boundaries - domain = dace_fieldview_util.get_domain(stmt.domain) + domain = dace_fieldview_util.get_domain_ranges(stmt.domain) for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): target_array = sdfg.arrays[target_node.data] @@ -205,8 +205,7 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) if isinstance(target_symbol_type, ts.FieldType): subset = ",".join( - f"{domain[dim.value][0]}:{domain[dim.value][1]}" - for dim in target_symbol_type.dims + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target_symbol_type.dims ) else: assert len(domain) == 0 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index b9152f1fb3..9a20d06487 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,7 +16,7 @@ import dace -from gt4py.next.common import Connectivity, Dimension, DimensionKind +from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet @@ -78,7 +78,7 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne } -def get_field_domain( +def get_domain( node: itir.Expr, ) -> list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ @@ -99,25 +99,21 @@ def get_field_domain( sym_str = get_symbolic_expr(arg) sym_val = dace.symbolic.SymExpr(sym_str) bounds.append(sym_val) - size_value = str(bounds[1] - bounds[0]) - if size_value.isdigit(): - dim = Dimension(axis.value, DimensionKind.LOCAL) - else: - dim = Dimension(axis.value, DimensionKind.HORIZONTAL) + dim = Dimension(axis.value, axis.kind) domain.append((dim, bounds[0], bounds[1])) return domain -def get_domain( +def get_domain_ranges( node: itir.Expr, -) -> dict[str, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: +) -> dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Returns domain represented in dictionary form. """ - field_domain = get_field_domain(node) + domain = get_domain(node) - return {dim.value: (lb, ub) for dim, lb, ub in field_domain} + return {dim: (lb, ub) for dim, lb, ub in domain} def get_symbolic_expr(node: itir.Expr) -> str: From f642e8576312679c97cdb5855b855d4d482be2ec Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Jul 2024 09:45:43 +0200 Subject: [PATCH 105/235] Fix for type inference on single value expression --- .../runners/dace_fieldview/gtir_to_tasklet.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 42d4e4feae..e8ffa61145 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -261,8 +261,14 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value ) # TODO: use type inference to determine the result type - if len(node_connections) == 1 and isinstance(node_connections["__inp_0"], MemletExpr): - dtype = node_connections["__inp_0"].node.desc(self.sdfg).dtype + if len(node_connections) == 1: + dtype = None + for conn_name in ["__inp_0", "__inp_1"]: + if conn_name in node_connections: + dtype = node_connections[conn_name].node.desc(self.sdfg).dtype + break + if dtype is None: + raise ValueError("Failed to determine the type") else: node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) dtype = dace_fieldview_util.as_dace_type(node_type) From 74bd468199c07e4d201334bae8737f74fb1d73f1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 09:58:52 +0200 Subject: [PATCH 106/235] Updated it now seems to work. --- my_playground/nambla4.py | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py index 362d31c527..848646646f 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nambla4.py @@ -87,26 +87,26 @@ def make_syms(**kwargs: np.ndarray) -> dict[str, int]: def build_nambla4_gtir(): edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "num_edges"), - im.call("named_range")(itir.AxisLiteral(value=KDim.value), 0, "num_k_levels"), - ) - edge_k_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=Edge.value), 0, "num_edges"), - im.call("named_range")(itir.AxisLiteral(value=KDim.value), 0, "num_k_levels"), + im.call("named_range")(itir.AxisLiteral(value=Edge.value, kind=Edge.kind), 0, "num_edges"), + im.call("named_range")( + itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" + ), ) num_edges = 27 num_k_levels = 10 + EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) + nabla4prog = itir.Program( id="nabla4_partial", function_definitions=[], params=[ - itir.Sym(id="T"), - itir.Sym(id="N"), - itir.Sym(id="nab4"), - itir.Sym(id="num_edges"), - itir.Sym(id="num_k_levels"), + itir.Sym(id="T", type=EK_FTYPE), + itir.Sym(id="N", type=EK_FTYPE), + itir.Sym(id="nab4", type=EK_FTYPE), + itir.Sym(id="num_edges", type=SIZE_TYPE), + itir.Sym(id="num_k_levels", type=SIZE_TYPE), ], declarations=[], body=[ @@ -123,22 +123,13 @@ def build_nambla4_gtir(): ], ) - EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) - - arg_types = [ - EK_FTYPE, - EK_FTYPE, - EK_FTYPE, - SIZE_TYPE, - SIZE_TYPE, - ] offset_provider = {} N = np.random.rand(num_edges, num_k_levels) - T = np.random.rand(num_k_levels) + T = np.random.rand(num_edges, num_k_levels) nab4 = np.empty_like(N) - sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, arg_types, offset_provider) + sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, offset_provider) call_args = dict( T=T, @@ -153,6 +144,7 @@ def build_nambla4_gtir(): ref = nabla4_np(**call_args) assert np.allclose(ref, nab4) + print(f"Test succeeded") if "__main__" == __name__: From 667eb7e3ec4ff12da612ca09cdfb0e59de41aec5 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 10:40:36 +0200 Subject: [PATCH 107/235] Updated the nabla4 calculations. --- my_playground/nambla4.py | 130 +++++++++++++++++++++++++++------------ 1 file changed, 92 insertions(+), 38 deletions(-) diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py index 848646646f..898ca89cba 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nambla4.py @@ -47,11 +47,15 @@ def nabla4_np( - N: Field[[Edge, KDim], wpfloat], - T: Field[[Edge, KDim], wpfloat], + nabv_norm: Field[[Edge, KDim], wpfloat], + nabv_tang: Field[[Edge, KDim], wpfloat], + z_nabla2_e: Field[[Edge, KDim], wpfloat], + **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[Edge, KDim], wpfloat]: - return N + T + N = nabv_norm - 2 * z_nabla2_e + T = nabv_tang - 2 * z_nabla2_e + return 4 * (N + T) def dace_strides( @@ -102,8 +106,10 @@ def build_nambla4_gtir(): id="nabla4_partial", function_definitions=[], params=[ - itir.Sym(id="T", type=EK_FTYPE), - itir.Sym(id="N", type=EK_FTYPE), + itir.Sym(id="nabv_norm", type=EK_FTYPE), + itir.Sym(id="nabv_tang", type=EK_FTYPE), + itir.Sym(id="z_nabla2_e", type=EK_FTYPE), + itir.Sym(id="nab4", type=EK_FTYPE), itir.Sym(id="num_edges", type=SIZE_TYPE), itir.Sym(id="num_k_levels", type=SIZE_TYPE), @@ -113,27 +119,101 @@ def build_nambla4_gtir(): itir.SetAt( expr=im.call( im.call("as_fieldop")( - im.lambda_("N_it", "T_it")(im.plus(im.deref("T_it"), im.deref("N_it"))), + im.lambda_("NpT", "const_4")( + im.multiplies_(im.deref("NpT"), im.deref("const_4")) + ), edge_k_domain, ) - )("N", "T"), + )( + # arg: `NpT` + im.call( + im.call("as_fieldop")( + im.lambda_("N", "T")( + im.plus(im.deref("N"), im.deref("T")) + ), + edge_k_domain, + ) + )( + # arg: `N` + im.call( + im.call("as_fieldop")( + im.lambda_("xn", "z_nabla2_e2")( + im.minus(im.deref("xn"), im.deref("z_nabla2_e2")) + ), + edge_k_domain, + ) + )( + # arg: `xn` + "nabv_norm", + + # arg: `z_nabla2_e2` + im.call( + im.call("as_fieldop")( + im.lambda_("z_nabla2_e", "const_2")( + im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + ), + edge_k_domain, + ) + )( + # arg: `z_nabla2_e` + "z_nabla2_e", + # arg: `const_2` + 2.0 + ), + ), + + # arg: `T` + im.call( + im.call("as_fieldop")( + im.lambda_("xt", "z_nabla2_e2")( + im.minus(im.deref("xt"), im.deref("z_nabla2_e2")) + ), + edge_k_domain + ) + )( + # arg: `xt` + "nabv_tang", + + # arg: `z_nabla2_e2` + im.call( + im.call("as_fieldop")( + im.lambda_("z_nabla2_e", "const_2")( + im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + ), + edge_k_domain, + ) + )( + # arg: `z_nabla2_e` + "z_nabla2_e", + # arg: `const_2` + 2.0 + ), + ), + ), + + # arg: `const_4` + 4.0, + ), domain=edge_k_domain, target=itir.SymRef(id="nab4"), ) + ], ) offset_provider = {} - N = np.random.rand(num_edges, num_k_levels) - T = np.random.rand(num_edges, num_k_levels) - nab4 = np.empty_like(N) + nabv_norm = np.random.rand(num_edges, num_k_levels) + nabv_tang = np.random.rand(num_edges, num_k_levels) + z_nabla2_e = np.random.rand(num_edges, num_k_levels) + nab4 = np.empty((num_edges, num_k_levels), dtype=nabv_norm.dtype) sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, offset_provider) call_args = dict( - T=T, - N=N, + nabv_norm=nabv_norm, + nabv_tang=nabv_tang, + z_nabla2_e=z_nabla2_e, nab4=nab4, num_edges=num_edges, num_k_levels=num_k_levels, @@ -149,29 +229,3 @@ def build_nambla4_gtir(): if "__main__" == __name__: build_nambla4_gtir() - - """ - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("NpT")( - im.multiplies_(im.deref("NpT"), 4.0) - ), - edge_k_domain, - ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("N", "T")( - im.plus(im.deref("N"), im.deref("T")) - ), - edge_k_domain, - ) - )("N", "T") - ), - domain=edge_k_domain, - target=itir.SymRef(id="nab4"), - ) - ], - """ From 58b8e58a757f349cde0bde36b3ec489c47795615 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 11:27:52 +0200 Subject: [PATCH 108/235] Now all the calculations are done. --- my_playground/nambla4.py | 130 ++++++++++++++++++++++++++++++--------- 1 file changed, 100 insertions(+), 30 deletions(-) diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py index 898ca89cba..e29788d338 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nambla4.py @@ -50,12 +50,22 @@ def nabla4_np( nabv_norm: Field[[Edge, KDim], wpfloat], nabv_tang: Field[[Edge, KDim], wpfloat], z_nabla2_e: Field[[Edge, KDim], wpfloat], + inv_vert_vert_length: Field[[Edge], wpfloat], + + inv_primal_edge_length: Field[[Edge], wpfloat], **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[Edge, KDim], wpfloat]: + N = nabv_norm - 2 * z_nabla2_e + ell_v2 = inv_vert_vert_length ** 2 + N_ellv2 = N * ell_v2.reshape((-1, 1)) + T = nabv_tang - 2 * z_nabla2_e - return 4 * (N + T) + ell_e2 = inv_primal_edge_length ** 2 + T_elle2 = T * ell_e2.reshape((-1, 1)) + + return 4 * (N_ellv2 + T_elle2) def dace_strides( @@ -101,6 +111,7 @@ def build_nambla4_gtir(): num_k_levels = 10 EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) + E_FTYPE = ts.FieldType(dims=[Edge], dtype=wpfloat) nabla4prog = itir.Program( id="nabla4_partial", @@ -109,7 +120,8 @@ def build_nambla4_gtir(): itir.Sym(id="nabv_norm", type=EK_FTYPE), itir.Sym(id="nabv_tang", type=EK_FTYPE), itir.Sym(id="z_nabla2_e", type=EK_FTYPE), - + itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), + itir.Sym(id="inv_primal_edge_length", type=E_FTYPE), itir.Sym(id="nab4", type=EK_FTYPE), itir.Sym(id="num_edges", type=SIZE_TYPE), itir.Sym(id="num_k_levels", type=SIZE_TYPE), @@ -128,68 +140,122 @@ def build_nambla4_gtir(): # arg: `NpT` im.call( im.call("as_fieldop")( - im.lambda_("N", "T")( - im.plus(im.deref("N"), im.deref("T")) + im.lambda_("N_ell2", "T_ell2")( + im.plus(im.deref("N_ell2"), im.deref("T_ell2")) ), edge_k_domain, ) )( - # arg: `N` + # arg: `N_ell2` im.call( im.call("as_fieldop")( - im.lambda_("xn", "z_nabla2_e2")( - im.minus(im.deref("xn"), im.deref("z_nabla2_e2")) + im.lambda_("ell_v2", "N")( + im.multiplies_(im.deref("N"), im.deref("ell_v2")) ), edge_k_domain, ) )( - # arg: `xn` - "nabv_norm", + # arg: `ell_v2` + im.call( + im.call("as_fieldop")( + im.lambda_("ell_v")( + im.multiplies_(im.deref("ell_v"), im.deref("ell_v")) + ), + edge_k_domain, + ) + )( + # arg: `ell_v` + "inv_vert_vert_length" + ), + # end arg: `ell_v2` - # arg: `z_nabla2_e2` + # arg: `N` im.call( im.call("as_fieldop")( - im.lambda_("z_nabla2_e", "const_2")( - im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + im.lambda_("xn", "z_nabla2_e2")( + im.minus(im.deref("xn"), im.deref("z_nabla2_e2")) ), edge_k_domain, ) )( - # arg: `z_nabla2_e` - "z_nabla2_e", - # arg: `const_2` - 2.0 + # arg: `xn` + "nabv_norm", + + # arg: `z_nabla2_e2` + im.call( + im.call("as_fieldop")( + im.lambda_("z_nabla2_e", "const_2")( + im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + ), + edge_k_domain, + ) + )( + # arg: `z_nabla2_e` + "z_nabla2_e", + # arg: `const_2` + 2.0 + ), + # end arg: `z_nabla2_e2` ), + # end arg: `N` ), + # end arg: `N_ell2` - # arg: `T` + # arg: `T_ell2` im.call( im.call("as_fieldop")( - im.lambda_("xt", "z_nabla2_e2")( - im.minus(im.deref("xt"), im.deref("z_nabla2_e2")) + im.lambda_("ell_e2", "T")( + im.multiplies_(im.deref("T"), im.deref("ell_e2")) ), - edge_k_domain + edge_k_domain, ) )( - # arg: `xt` - "nabv_tang", - - # arg: `z_nabla2_e2` + # arg: `ell_e2` im.call( im.call("as_fieldop")( - im.lambda_("z_nabla2_e", "const_2")( - im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + im.lambda_("ell_e")( + im.multiplies_(im.deref("ell_e"), im.deref("ell_e")) ), edge_k_domain, ) )( - # arg: `z_nabla2_e` - "z_nabla2_e", - # arg: `const_2` - 2.0 + # arg: `ell_e` + "inv_primal_edge_length" + ), + # end arg: `ell_e2` + + # arg: `T` + im.call( + im.call("as_fieldop")( + im.lambda_("xt", "z_nabla2_e2")( + im.minus(im.deref("xt"), im.deref("z_nabla2_e2")) + ), + edge_k_domain + ) + )( + # arg: `xt` + "nabv_tang", + + # arg: `z_nabla2_e2` + im.call( + im.call("as_fieldop")( + im.lambda_("z_nabla2_e", "const_2")( + im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + ), + edge_k_domain, + ) + )( + # arg: `z_nabla2_e` + "z_nabla2_e", + # arg: `const_2` + 2.0 + ), ), + # end arg: `T` ), + # end arg: `T_ell2` ), + # end arg: `NpT` # arg: `const_4` 4.0, @@ -206,6 +272,8 @@ def build_nambla4_gtir(): nabv_norm = np.random.rand(num_edges, num_k_levels) nabv_tang = np.random.rand(num_edges, num_k_levels) z_nabla2_e = np.random.rand(num_edges, num_k_levels) + inv_vert_vert_length = np.random.rand(num_edges) + inv_primal_edge_length = np.random.rand(num_edges) nab4 = np.empty((num_edges, num_k_levels), dtype=nabv_norm.dtype) sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, offset_provider) @@ -214,6 +282,8 @@ def build_nambla4_gtir(): nabv_norm=nabv_norm, nabv_tang=nabv_tang, z_nabla2_e=z_nabla2_e, + inv_vert_vert_length=inv_vert_vert_length, + inv_primal_edge_length=inv_primal_edge_length, nab4=nab4, num_edges=num_edges, num_k_levels=num_k_levels, From e898b31bde6411669ea29483e5f6528fc6c950ce Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 11:30:08 +0200 Subject: [PATCH 109/235] Formated a bit. --- my_playground/nambla4.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py index e29788d338..7bf8bc4490 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nambla4.py @@ -51,18 +51,15 @@ def nabla4_np( nabv_tang: Field[[Edge, KDim], wpfloat], z_nabla2_e: Field[[Edge, KDim], wpfloat], inv_vert_vert_length: Field[[Edge], wpfloat], - inv_primal_edge_length: Field[[Edge], wpfloat], - **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[Edge, KDim], wpfloat]: - N = nabv_norm - 2 * z_nabla2_e - ell_v2 = inv_vert_vert_length ** 2 + ell_v2 = inv_vert_vert_length**2 N_ellv2 = N * ell_v2.reshape((-1, 1)) T = nabv_tang - 2 * z_nabla2_e - ell_e2 = inv_primal_edge_length ** 2 + ell_e2 = inv_primal_edge_length**2 T_elle2 = T * ell_e2.reshape((-1, 1)) return 4 * (N_ellv2 + T_elle2) @@ -168,7 +165,6 @@ def build_nambla4_gtir(): "inv_vert_vert_length" ), # end arg: `ell_v2` - # arg: `N` im.call( im.call("as_fieldop")( @@ -180,27 +176,27 @@ def build_nambla4_gtir(): )( # arg: `xn` "nabv_norm", - # arg: `z_nabla2_e2` im.call( im.call("as_fieldop")( im.lambda_("z_nabla2_e", "const_2")( - im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + im.multiplies_( + im.deref("z_nabla2_e"), im.deref("const_2") + ) ), edge_k_domain, ) )( # arg: `z_nabla2_e` - "z_nabla2_e", + "z_nabla2_e", # arg: `const_2` - 2.0 + 2.0, ), # end arg: `z_nabla2_e2` ), # end arg: `N` ), # end arg: `N_ell2` - # arg: `T_ell2` im.call( im.call("as_fieldop")( @@ -223,32 +219,32 @@ def build_nambla4_gtir(): "inv_primal_edge_length" ), # end arg: `ell_e2` - # arg: `T` im.call( im.call("as_fieldop")( im.lambda_("xt", "z_nabla2_e2")( im.minus(im.deref("xt"), im.deref("z_nabla2_e2")) ), - edge_k_domain + edge_k_domain, ) )( # arg: `xt` "nabv_tang", - # arg: `z_nabla2_e2` im.call( im.call("as_fieldop")( im.lambda_("z_nabla2_e", "const_2")( - im.multiplies_(im.deref("z_nabla2_e"), im.deref("const_2")) + im.multiplies_( + im.deref("z_nabla2_e"), im.deref("const_2") + ) ), edge_k_domain, ) )( # arg: `z_nabla2_e` - "z_nabla2_e", + "z_nabla2_e", # arg: `const_2` - 2.0 + 2.0, ), ), # end arg: `T` @@ -256,14 +252,12 @@ def build_nambla4_gtir(): # end arg: `T_ell2` ), # end arg: `NpT` - # arg: `const_4` 4.0, ), domain=edge_k_domain, target=itir.SymRef(id="nab4"), ) - ], ) From eae968f824d312257dffae8f3ff75240f0b07846 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 5 Jul 2024 11:36:53 +0200 Subject: [PATCH 110/235] Refactored the code. --- my_playground/nambla4.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py index 7bf8bc4490..8c52dcc0b7 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nambla4.py @@ -96,7 +96,10 @@ def make_syms(**kwargs: np.ndarray) -> dict[str, int]: return SYMBS -def build_nambla4_gtir(): +def build_nambla4_gtir_fieldview( + num_edges: int, + num_k_levels: int, +) -> itir.Program: edge_k_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Edge.value, kind=Edge.kind), 0, "num_edges"), im.call("named_range")( @@ -104,9 +107,6 @@ def build_nambla4_gtir(): ), ) - num_edges = 27 - num_k_levels = 10 - EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) E_FTYPE = ts.FieldType(dims=[Edge], dtype=wpfloat) @@ -261,6 +261,29 @@ def build_nambla4_gtir(): ], ) + return nabla4prog + + + +def verify_nabla4( + version: str, +): + num_edges = 27 + num_k_levels = 10 + + if(version == "fieldview"): + nabla4prog = build_nambla4_gtir_fieldview( + num_edges=num_edges, + num_k_levels=num_k_levels, + ) + + elif(version == "inline"): + raise NotImplementedError("`inline` version is not yet implemented.") + + else: + raise ValueError(f"The version `{version}` is now known.") + + offset_provider = {} nabv_norm = np.random.rand(num_edges, num_k_levels) @@ -288,8 +311,8 @@ def build_nambla4_gtir(): ref = nabla4_np(**call_args) assert np.allclose(ref, nab4) - print(f"Test succeeded") + print(f"Version({version}): Succeeded") if "__main__" == __name__: - build_nambla4_gtir() + verify_nabla4("fieldview") From defb55d554e51a9a66b48a76b4f47336954c4648 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Jul 2024 18:36:08 +0200 Subject: [PATCH 111/235] Import changes from dace-fieldview-neighbors --- .../gtir_builtin_translators.py | 32 ++++++--- .../runners/dace_fieldview/gtir_to_sdfg.py | 2 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 71 +++++++++---------- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 6f3731a4af..69ae5f33d6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -116,15 +116,9 @@ def __call__( ) stencil_args.append(iterator_arg) - # create map range corresponding to the field operator domain - map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in field_domain - } - me, mx = state.add_map("field_op", map_ranges) - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, me, sdfg_builder.offset_provider) - output_expr = taskgen.visit(stencil_expr, args=stencil_args) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.offset_provider) + input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) output_desc = output_expr.node.desc(sdfg) @@ -134,9 +128,6 @@ def __call__( # the last transient node can be deleted last_node_connector = state.in_edges(output_expr.node)[0].src_conn state.remove_node(output_expr.node) - if len(last_node.in_connectors) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - state.add_nedge(me, last_node, dace.Memlet()) else: last_node = output_expr.node last_node_connector = None @@ -172,6 +163,25 @@ def __call__( assert set(output_desc.offset) == {0} output_subset.extend(f"0:{size}" for size in output_desc.shape) + # create map range corresponding to the field operator domain + map_ranges = { + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in field_domain + } + me, mx = state.add_map("field_op", map_ranges) + + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + state.add_nedge(me, last_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, + ) state.add_memlet_path( last_node, mx, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 130e3eb723..8a2b1bf8a3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -186,7 +186,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: # We store all connectivity tables as transient arrays here; later, while building # the field operator expressions, we change to non transient the tables # that are actually used. This way, we avoid adding SDFG arguments for - # the connectivity tabkes that are not used. + # the connectivity tables that are not used. self._add_storage( sdfg, dace_fieldview_util.connectivity_identifier(offset), type_, transient=True ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index f3e68716fb..ef86ed65ce 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -55,6 +55,14 @@ class ValueExpr: field_type: ts.FieldType | ts.ScalarType +# Define alias for the elements needed to setup input connections to a map scope +InputConnection: TypeAlias = tuple[ + dace.nodes.AccessNode, + sbs.Range, + dace.nodes.Node, + Optional[str], +] + IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr @@ -136,36 +144,29 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState offset_provider: dict[str, Connectivity | Dimension] + input_connections: list[InputConnection] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - map_entry: dace.nodes.MapEntry, offset_provider: dict[str, Connectivity | Dimension], ): self.sdfg = sdfg self.state = state - self.map_entry = map_entry self.offset_provider = offset_provider + self.input_connections = [] self.symbol_map = {} def _add_entry_memlet_path( self, - *path_nodes: dace.nodes.Node, - memlet: Optional[dace.Memlet] = None, - src_conn: Optional[str] = None, + src: dace.nodes.AccessNode, + src_subset: sbs.Range, + dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, ) -> None: - self.state.add_memlet_path( - path_nodes[0], - self.map_entry, - *path_nodes[1:], - memlet=memlet, - src_conn=src_conn, - dst_conn=dst_conn, - ) + self.input_connections.append((src, src_subset, dst_node, dst_conn)) def _get_tasklet_result( self, @@ -222,9 +223,9 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: # add new termination point for this field parameter self._add_entry_memlet_path( it.field, + sbs.Range.from_array(field_desc), deref_node, - dst_conn="field", - memlet=dace.Memlet.from_array(it.field.data, field_desc), + "field", ) for dim, index_expr in field_indices: @@ -232,9 +233,9 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: if isinstance(index_expr, MemletExpr): self._add_entry_memlet_path( index_expr.node, + index_expr.subset, deref_node, - dst_conn=deref_connector, - memlet=dace.Memlet(data=index_expr.node.data, subset=index_expr.subset), + deref_connector, ) elif isinstance(index_expr, ValueExpr): @@ -315,9 +316,9 @@ def _make_cartesian_shift( dtype = input_expr.node.desc(self.sdfg).dtype self._add_entry_memlet_path( input_expr.node, + input_expr.subset, dynamic_offset_tasklet, - dst_conn=input_connector, - memlet=dace.Memlet(data=input_expr.node.data, subset=input_expr.subset), + input_connector, ) elif isinstance(input_expr, ValueExpr): if input_connector == "index": @@ -364,18 +365,16 @@ def _make_dynamic_neighbor_offset( ) self._add_entry_memlet_path( offset_table_node, + sbs.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, - dst_conn="table", - memlet=dace.Memlet.from_array( - offset_table_node.data, offset_table_node.desc(self.sdfg) - ), + "table", ) if isinstance(offset_expr, MemletExpr): self._add_entry_memlet_path( offset_expr.node, + offset_expr.subset, tasklet_node, - dst_conn="offset", - memlet=dace.Memlet(data=offset_expr.node.data, subset=offset_expr.subset), + "offset", ) else: self.state.add_edge( @@ -520,9 +519,9 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value else: self._add_entry_memlet_path( arg_expr.node, + arg_expr.subset, tasklet_node, - dst_conn=connector, - memlet=dace.Memlet(data=arg_expr.node.data, subset=arg_expr.subset), + connector, ) # TODO: use type inference to determine the result type @@ -542,32 +541,30 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value def visit_Lambda( self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> ValueExpr: + ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr) if isinstance(output_expr, ValueExpr): - return output_expr + return self.input_connections, output_expr if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - dtype = self.sdfg.arrays[output_expr.node.data].dtype - scalar_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) - var, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) - result_node = self.state.add_access(var) + output_dtype = output_expr.node.desc(self.sdfg).dtype + tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_entry_memlet_path( output_expr.node, - result_node, - memlet=dace.Memlet(data=output_expr.node.data, subset=output_expr.subset), + output_expr.subset, + tasklet_node, + "__inp", ) - return ValueExpr(result_node, scalar_type) else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype tasklet_node = self.state.add_tasklet( "write", {}, {"__out"}, f"__out = {output_expr.value}" ) - return self._get_tasklet_result(output_dtype, tasklet_node, "__out") + return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) From fc9661caf0952ec458a8609d3016bd7a0cfa6682 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Jul 2024 18:39:16 +0200 Subject: [PATCH 112/235] Import changes from dace-fieldview-shifts --- .../gtir_builtin_translators.py | 32 +++++++----- .../runners/dace_fieldview/gtir_to_tasklet.py | 49 +++++++++---------- 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 6f3731a4af..69ae5f33d6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -116,15 +116,9 @@ def __call__( ) stencil_args.append(iterator_arg) - # create map range corresponding to the field operator domain - map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in field_domain - } - me, mx = state.add_map("field_op", map_ranges) - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, me, sdfg_builder.offset_provider) - output_expr = taskgen.visit(stencil_expr, args=stencil_args) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.offset_provider) + input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) output_desc = output_expr.node.desc(sdfg) @@ -134,9 +128,6 @@ def __call__( # the last transient node can be deleted last_node_connector = state.in_edges(output_expr.node)[0].src_conn state.remove_node(output_expr.node) - if len(last_node.in_connectors) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - state.add_nedge(me, last_node, dace.Memlet()) else: last_node = output_expr.node last_node_connector = None @@ -172,6 +163,25 @@ def __call__( assert set(output_desc.offset) == {0} output_subset.extend(f"0:{size}" for size in output_desc.shape) + # create map range corresponding to the field operator domain + map_ranges = { + DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in field_domain + } + me, mx = state.add_map("field_op", map_ranges) + + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + state.add_nedge(me, last_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, + ) state.add_memlet_path( last_node, mx, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index e8ffa61145..425ddff057 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -54,6 +54,14 @@ class ValueExpr: field_type: ts.FieldType | ts.ScalarType +# Define alias for the elements needed to setup input connections to a map scope +InputConnection: TypeAlias = tuple[ + dace.nodes.AccessNode, + sbs.Range, + dace.nodes.Node, + Optional[str], +] + IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr @@ -135,36 +143,29 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState offset_provider: dict[str, Connectivity | Dimension] + input_connections: list[InputConnection] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - map_entry: dace.nodes.MapEntry, offset_provider: dict[str, Connectivity | Dimension], ): self.sdfg = sdfg self.state = state - self.map_entry = map_entry self.offset_provider = offset_provider + self.input_connections = [] self.symbol_map = {} def _add_entry_memlet_path( self, - *path_nodes: dace.nodes.Node, - memlet: Optional[dace.Memlet] = None, - src_conn: Optional[str] = None, + src: dace.nodes.AccessNode, + src_subset: sbs.Range, + dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, ) -> None: - self.state.add_memlet_path( - path_nodes[0], - self.map_entry, - *path_nodes[1:], - memlet=memlet, - src_conn=src_conn, - dst_conn=dst_conn, - ) + self.input_connections.append((src, src_subset, dst_node, dst_conn)) def _get_tasklet_result( self, @@ -255,9 +256,9 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value else: self._add_entry_memlet_path( arg_expr.node, + arg_expr.subset, tasklet_node, - dst_conn=connector, - memlet=dace.Memlet(data=arg_expr.node.data, subset=arg_expr.subset), + connector, ) # TODO: use type inference to determine the result type @@ -277,32 +278,30 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value def visit_Lambda( self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> ValueExpr: + ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr) if isinstance(output_expr, ValueExpr): - return output_expr + return self.input_connections, output_expr if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - dtype = self.sdfg.arrays[output_expr.node.data].dtype - scalar_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) - var, _ = self.sdfg.add_scalar("var", dtype, find_new_name=True, transient=True) - result_node = self.state.add_access(var) + output_dtype = output_expr.node.desc(self.sdfg).dtype + tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_entry_memlet_path( output_expr.node, - result_node, - memlet=dace.Memlet(data=output_expr.node.data, subset=output_expr.subset), + output_expr.subset, + tasklet_node, + "__inp", ) - return ValueExpr(result_node, scalar_type) else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype tasklet_node = self.state.add_tasklet( "write", {}, {"__out"}, f"__out = {output_expr.value}" ) - return self._get_tasklet_result(output_dtype, tasklet_node, "__out") + return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: itir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) From e424d4ea2c1edb3543f4e8e22de9858a8304fbe3 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 5 Jul 2024 18:50:31 +0200 Subject: [PATCH 113/235] Minor edit --- .../runners/dace_fieldview/gtir_to_tasklet.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 425ddff057..524553d174 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -173,18 +173,18 @@ def _get_tasklet_result( src_node: dace.nodes.Tasklet, src_connector: str, ) -> ValueExpr: - var_name, _ = self.sdfg.add_scalar("var", dtype, transient=True, find_new_name=True) - var_subset = "0" + temp_name = self.sdfg.temp_data_name() + self.sdfg.add_scalar(temp_name, dtype, transient=True) data_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype())) - var_node = self.state.add_access(var_name) + temp_node = self.state.add_access(temp_name) self.state.add_edge( src_node, src_connector, - var_node, + temp_node, None, - dace.Memlet(data=var_node.data, subset=var_subset), + dace.Memlet(data=temp_name, subset="0"), ) - return ValueExpr(var_node, data_type) + return ValueExpr(temp_node, data_type) def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 From 563ee1a1fe14a98f1ec6babf18e810f66d335c81 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Sun, 7 Jul 2024 12:11:35 +0200 Subject: [PATCH 114/235] WIP: Working on accessing. --- my_playground/nambla4.py | 122 ++++++++++++++++-- .../transformations/__init__.py | 3 +- .../transformations/auto_opt.py | 38 +++++- 3 files changed, 149 insertions(+), 14 deletions(-) diff --git a/my_playground/nambla4.py b/my_playground/nambla4.py index 8c52dcc0b7..c14265d860 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nambla4.py @@ -96,10 +96,108 @@ def make_syms(**kwargs: np.ndarray) -> dict[str, int]: return SYMBS +def build_nambla4_gtir_inline( + num_edges: int, + num_k_levels: int, +) -> itir.Program: + """Creates the `nabla4` stencil where the computations are already inlined.""" + + edge_k_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value=Edge.value, kind=Edge.kind), 0, "num_edges"), + im.call("named_range")( + itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" + ), + ) + + EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) + E_FTYPE = ts.FieldType(dims=[Edge], dtype=wpfloat) + + nabla4prog = itir.Program( + id="nabla4_partial_inline", + function_definitions=[], + params=[ + itir.Sym(id="nabv_norm", type=EK_FTYPE), + itir.Sym(id="nabv_tang", type=EK_FTYPE), + itir.Sym(id="z_nabla2_e", type=EK_FTYPE), + itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), + itir.Sym(id="inv_primal_edge_length", type=E_FTYPE), + itir.Sym(id="nab4", type=EK_FTYPE), + itir.Sym(id="num_edges", type=SIZE_TYPE), + itir.Sym(id="num_k_levels", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_( + "z_nabla2_e2", + "const_4", + "nabv_norm", + "inv_vert_vert_length", + "nabv_tang", + "inv_primal_edge_length", + )( + im.multiplies_( + im.plus( + im.multiplies_( + im.minus( + im.deref("nabv_norm"), + im.deref("z_nabla2_e2"), + ), + im.multiplies_( + im.deref("inv_vert_vert_length"), + im.deref("inv_vert_vert_length"), + ), + ), + im.multiplies_( + im.minus( + im.deref("nabv_tang"), + im.deref("z_nabla2_e2"), + ), + im.multiplies_( + im.deref("inv_primal_edge_length"), + im.deref("inv_primal_edge_length"), + ), + ), + ), + im.deref("const_4"), + ) + ), + edge_k_domain, + ) + )( + # arg: `z_nabla2_e2` + im.call( + im.call("as_fieldop")( + im.lambda_("x", "const_2")( + im.multiplies_(im.deref("x"), im.deref("const_2")) + ), + edge_k_domain, + ) + )("z_nabla2_e", 2.0), + # end arg: `z_nabla2_e2` + # arg: `const_4` + 4.0, + # Same name as in the lambda and are argument to the program. + "nabv_norm", + "inv_vert_vert_length", + "nabv_tang", + "inv_primal_edge_length", + ), + domain=edge_k_domain, + target=itir.SymRef(id="nab4"), + ) + ], + ) + return nabla4prog + + def build_nambla4_gtir_fieldview( - num_edges: int, - num_k_levels: int, + num_edges: int, + num_k_levels: int, ) -> itir.Program: + """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" edge_k_domain = im.call("unstructured_domain")( im.call("named_range")(itir.AxisLiteral(value=Edge.value, kind=Edge.kind), 0, "num_edges"), im.call("named_range")( @@ -111,7 +209,7 @@ def build_nambla4_gtir_fieldview( E_FTYPE = ts.FieldType(dims=[Edge], dtype=wpfloat) nabla4prog = itir.Program( - id="nabla4_partial", + id="nabla4_partial_fieldview", function_definitions=[], params=[ itir.Sym(id="nabv_norm", type=EK_FTYPE), @@ -264,26 +362,27 @@ def build_nambla4_gtir_fieldview( return nabla4prog - def verify_nabla4( - version: str, + version: str, ): num_edges = 27 num_k_levels = 10 - if(version == "fieldview"): + if version == "fieldview": nabla4prog = build_nambla4_gtir_fieldview( - num_edges=num_edges, - num_k_levels=num_k_levels, + num_edges=num_edges, + num_k_levels=num_k_levels, ) - elif(version == "inline"): - raise NotImplementedError("`inline` version is not yet implemented.") + elif version == "inline": + nabla4prog = build_nambla4_gtir_inline( + num_edges=num_edges, + num_k_levels=num_k_levels, + ) else: raise ValueError(f"The version `{version}` is now known.") - offset_provider = {} nabv_norm = np.random.rand(num_edges, num_k_levels) @@ -315,4 +414,5 @@ def verify_nabla4( if "__main__" == __name__: + #verify_nabla4("inline") verify_nabla4("fieldview") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4090950e5f..46fe340027 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,9 +12,10 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from .auto_opt import dace_auto_optimize +from .auto_opt import dace_auto_optimize, gt_auto_optimize __all__ = [ "dace_auto_optimize", + "gt_auto_optimize", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 9a7cd5da33..09d2132e9f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -17,6 +17,8 @@ from typing import Any import dace +from dace.transformation.auto import auto_optimize as dace_aoptimize +from dace.transformation import dataflow as dace_dataflow def dace_auto_optimize( @@ -36,16 +38,48 @@ def dace_auto_optimize( device: the device for which optimizations should be done, defaults to CPU. kwargs: Are forwarded to the underlying auto optimized exposed by DaCe. """ - from dace.transformation.auto.auto_optimize import auto_optimize as _auto_optimize # Now put output storages everywhere to make auto optimizer less likely to fail. # sdfg.apply_transformations_repeated([InLocalStorage, OutLocalStorage]) # noqa: ERA001 [commented-out-code] # Now the optimization. - sdfg = _auto_optimize(sdfg, device=device, **kwargs) + sdfg = dace_aoptimize(sdfg, device=device, **kwargs) # Now the simplification step. # This should get rid of some of teh additional transients we have added. sdfg.simplify() + return sdfg + + +def gt_auto_optimize( + sdfg: dace.SDFG, + device: dace.DeviceType = dace.DeviceType.CPU, + **kwargs: Any, +) -> dace.SDFG: + """Performs GT4Py specific optimizations in place. + + Args: + sdfg: The SDFG that should ve optimized in place. + device: The device for which we should optimize. + """ + + # Initial cleaning + sdfg.simplify() + + + + # Due to the structure of the generated SDFG getting rid of Maps, + # i.e. fusing them, is the best we can currently do. + sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) + + # These are the part that we copy from DaCe built in auto optimization. + dace_aoptimize.set_fast_implementations(sdfg, device) + dace_aoptimize.make_transients_persistent(sdfg, device) + dace_aoptimize.move_small_arrays_to_stack(sdfg) + + sdfg.simplify() + + return sdfg + From f32fd38b72590563ea367b932609f925777beadf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 10:05:37 +0200 Subject: [PATCH 115/235] Now the shift works, at least the shift in the particular dimension. Now it is time to add it to the real nabla stuff. --- my_playground/my_stuff.py | 236 ++++++++++++++++++++++++++++++ my_playground/simple_icon_mesh.py | 125 ++++++++++++++++ 2 files changed, 361 insertions(+) create mode 100644 my_playground/my_stuff.py create mode 100644 my_playground/simple_icon_mesh.py diff --git a/my_playground/my_stuff.py b/my_playground/my_stuff.py new file mode 100644 index 0000000000..3728d3b1bc --- /dev/null +++ b/my_playground/my_stuff.py @@ -0,0 +1,236 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Test that ITIR can be lowered to SDFG. + +Note: this test module covers the fieldview flavour of ITIR. +""" + +import copy +from gt4py.next.common import NeighborTable +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.program_processors.runners import dace_fieldview as dace_backend +from gt4py.next.type_system import type_specifications as ts +from functools import reduce +import numpy as np + +from simple_icon_mesh import ( + IDim, # Dimensions + JDim, + KDim, + EdgeDim, + VertexDim, + CellDim, + ECVDim, + E2C2VDim, + NbCells, # Constants of the size + NbEdges, + NbVertices, + E2C2VDim, # Offsets + E2C2V, + SIZE_TYPE, # Type definitions + E2C2V_connectivity, + E2ECV_connectivity, + make_syms, # Helpers +) + +# For cartesian stuff. +N = 10 +IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +IJFTYPE = ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + + +###################### +# TESTS + + +def gtir_copy3(): + # We can not use the size symbols inside the domain + # Because the translator complains. + + # Input domain + input_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value, kind=IDim.kind), 0, "org_sizeI"), + im.call("named_range")(itir.AxisLiteral(value=JDim.value, kind=JDim.kind), 0, "org_sizeJ"), + ) + + # Domain for after we have processed the IDim. + first_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value, kind=IDim.kind), 0, "sizeI"), + im.call("named_range")(itir.AxisLiteral(value=JDim.value, kind=JDim.kind), 0, "org_sizeJ"), + ) + + # This is the final domain, or after we have removed the JDim + final_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value, kind=IDim.kind), 0, "sizeI"), + im.call("named_range")(itir.AxisLiteral(value=JDim.value, kind=JDim.kind), 0, "sizeJ"), + ) + + IOffset = 1 + JOffset = 2 + + testee = itir.Program( + id="gtir_copy", + function_definitions=[], + params=[ + itir.Sym(id="x", type=IJFTYPE), + itir.Sym(id="y", type=IJFTYPE), + itir.Sym(id="sizeI", type=SIZE_TYPE), + itir.Sym(id="sizeJ", type=SIZE_TYPE), + itir.Sym(id="org_sizeI", type=SIZE_TYPE), + itir.Sym(id="org_sizeJ", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + # This processed the `JDim`, it is first because its arguments are + # evaluated before, so there the first cutting is happening. + im.call("as_fieldop")( + im.lambda_("a")( + im.deref( + im.shift("JDim", JOffset)("a") # This does not work + ) + ), + final_domain, + ) + )( + # Now here we will process the `IDim` part. + im.call( + im.call("as_fieldop")( + im.lambda_("b")( + im.deref( + im.shift("IDim", IOffset)("b") + # "b" + ) + ), + first_domain, + ) + )("x"), + ), + domain=final_domain, + target=itir.SymRef(id="y"), + ) + ], + ) + + # We only need an offset provider for the translation. + offset_provider = { + "IDim": IDim, + "JDim": JDim, + } + + sdfg = dace_backend.build_sdfg_from_gtir( + testee, + offset_provider, + ) + + output_size_I, output_size_J = 10, 10 + input_size_I, input_size_J = 20, 20 + + a = np.random.rand(input_size_I, input_size_J) + b = np.empty((output_size_I, output_size_J), dtype=np.float64) + + SYMBS = make_syms(x=a, y=b) + + sdfg( + x=a, + y=b, + sizeI=output_size_I, + sizeJ=output_size_J, + org_sizeI=input_size_I, + org_sizeJ=input_size_J, + **SYMBS, + ) + + ref = a[IOffset : (IOffset + output_size_I), JOffset : (JOffset + output_size_J)] + + assert np.all(b == ref) + assert True + + +def gtir_ecv_shift(): + # EdgeDim, E2C2VDim + domain = im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "nedges" + ), + # im.call("named_range")(itir.AxisLiteral(value=E2C2VDim.value, kind=E2C2VDim.kind), 0, 4), + ) + + INPUT_FTYPE = ts.FieldType(dims=[ECVDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + OUTPUT_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + + testee = itir.Program( + id="gtir_shift", + function_definitions=[], + params=[ + itir.Sym(id="x", type=INPUT_FTYPE), + itir.Sym(id="y", type=OUTPUT_FTYPE), + itir.Sym(id="nedges", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + # This processed the `JDim`, it is first because its arguments are + # evaluated before, so there the first cutting is happening. + im.call("as_fieldop")( + im.lambda_("a")(im.deref(im.shift("E2ECV", 0)("a"))), + domain, + ) + )("x"), + domain=domain, + target=itir.SymRef(id="y"), + ) + ], + ) + + offset_provider = { + "E2C2V": E2C2V_connectivity, + "E2ECV": E2ECV_connectivity, + } + + sdfg = dace_backend.build_sdfg_from_gtir( + testee, + offset_provider, + ) + + a = np.random.rand(NbEdges * 4) + b = np.empty((NbEdges,), dtype=np.float64) + + call_args = { + "x": a, + "y": b, + "connectivity_E2C2V": E2C2V_connectivity.table.copy(), + "connectivity_E2ECV": E2ECV_connectivity.table.copy(), + } + + SYMBS = make_syms(**call_args) + + sdfg( + **call_args, + nedges=NbEdges, + **SYMBS, + ) + ref = a[E2ECV_connectivity.table[:, 0]] + + assert np.allclose(ref, b) + assert True + + +if "__main__" == __name__: + # gtir_copy3() + gtir_ecv_shift() diff --git a/my_playground/simple_icon_mesh.py b/my_playground/simple_icon_mesh.py new file mode 100644 index 0000000000..de4b8c2a35 --- /dev/null +++ b/my_playground/simple_icon_mesh.py @@ -0,0 +1,125 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Reimplementation of the simple icon grid for testing.""" + +import numpy as np + +from gt4py.next.common import DimensionKind +from gt4py.next.ffront.fbuiltins import Dimension, FieldOffset +from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.type_system import type_specifications as ts + +IDim = Dimension("IDim") +JDim = Dimension("JDim") + +KDim = Dimension("K", kind=DimensionKind.VERTICAL) +EdgeDim = Dimension("Edge") +CellDim = Dimension("Cell") +VertexDim = Dimension("Vertex") +ECVDim = Dimension("ECV") +E2C2VDim = Dimension("E2C2V", DimensionKind.LOCAL) + +E2ECV = FieldOffset("E2ECV", source=ECVDim, target=(EdgeDim, E2C2VDim)) +E2C2V = FieldOffset("E2C2V", source=VertexDim, target=(EdgeDim, E2C2VDim)) + +Koff = FieldOffset("Koff", source=KDim, target=(KDim,)) + +NbCells = 18 +NbEdges = 27 +NbVertices = 9 + +SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) + +e2c2v_table = np.asarray( + [ + [0, 1, 4, 6], # 0 + [0, 4, 1, 3], # 1 + [0, 3, 4, 2], # 2 + [1, 2, 5, 7], # 3 + [1, 5, 2, 4], # 4 + [1, 4, 5, 0], # 5 + [2, 0, 3, 8], # 6 + [2, 3, 5, 0], # 7 + [2, 5, 1, 3], # 8 + [3, 4, 0, 7], # 9 + [3, 7, 4, 6], # 10 + [3, 6, 7, 5], # 11 + [4, 5, 8, 1], # 12 + [4, 8, 7, 5], # 13 + [4, 7, 3, 8], # 14 + [5, 3, 6, 2], # 15 + [6, 5, 3, 8], # 16 + [8, 5, 6, 4], # 17 + [6, 7, 3, 1], # 18 + [6, 1, 7, 0], # 19 + [6, 0, 1, 8], # 20 + [7, 8, 2, 4], # 21 + [7, 2, 8, 1], # 22 + [7, 1, 2, 6], # 23 + [8, 6, 0, 5], # 24 + [8, 0, 6, 2], # 25 + [8, 2, 0, 6], # 26 + ] +) + +E2C2V_connectivity = NeighborTableOffsetProvider( + # I do not understand the ordering here? Why is `Edge` the source if you read + # it right to left? + e2c2v_table, + EdgeDim, + VertexDim, + e2c2v_table.shape[1], +) + + +def _make_E2ECV_connectivity(E2C2V_connectivity: NeighborTableOffsetProvider): + # Implementation is adapted from icon's `_get_offset_provider_for_sparse_fields()` + e2c2v_table = E2C2V_connectivity.table + t = np.arange(e2c2v_table.shape[0] * e2c2v_table.shape[1]).reshape(e2c2v_table.shape) + return NeighborTableOffsetProvider(t, EdgeDim, ECVDim, t.shape[1]) + + +E2ECV_connectivity = _make_E2ECV_connectivity(E2C2V_connectivity) + + +def dace_strides( + array: np.ndarray, + name: None | str = None, +) -> tuple[int, ...] | dict[str, int]: + if not hasattr(array, "strides"): + return {} + strides = array.strides + if hasattr(array, "itemsize"): + strides = tuple(stride // array.itemsize for stride in strides) + if name is not None: + strides = {f"__{name}_stride_{i}": stride for i, stride in enumerate(strides)} + return strides + + +def dace_shape( + array: np.ndarray, + name: str, +) -> dict[str, int]: + if not hasattr(array, "shape"): + return {} + return {f"__{name}_size_{i}": size for i, size in enumerate(array.shape)} + + +def make_syms(**kwargs: np.ndarray) -> dict[str, int]: + SYMBS = {} + for name, array in kwargs.items(): + SYMBS.update(**dace_shape(array, name)) + SYMBS.update(**dace_strides(array, name)) + return SYMBS From 538abff8f4f55e07207c405a0e79fa362ae95fc6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 10:07:17 +0200 Subject: [PATCH 116/235] Prepare to go to real input. --- my_playground/{nambla4.py => nabla4.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename my_playground/{nambla4.py => nabla4.py} (99%) diff --git a/my_playground/nambla4.py b/my_playground/nabla4.py similarity index 99% rename from my_playground/nambla4.py rename to my_playground/nabla4.py index c14265d860..e11e66b2c3 100644 --- a/my_playground/nambla4.py +++ b/my_playground/nabla4.py @@ -414,5 +414,5 @@ def verify_nabla4( if "__main__" == __name__: - #verify_nabla4("inline") + verify_nabla4("inline") verify_nabla4("fieldview") From a07fe81b098ae1971f165d7cdf0f19d9e8d80c26 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 10:32:25 +0200 Subject: [PATCH 117/235] nabla4 works now with the custom icon stuff. --- my_playground/nabla4.py | 54 ++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index e11e66b2c3..774798ea91 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -24,18 +24,28 @@ from gt4py.next.ffront.fbuiltins import Field from gt4py.next.program_processors.runners import dace_fieldview as dace_backend from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - KDim, - Cell, - Edge, - IDim, + + +from simple_icon_mesh import ( + IDim, # Dimensions JDim, - MeshDescriptor, - V2EDim, - Vertex, - simple_mesh, - skip_value_mesh, + KDim, + EdgeDim, + VertexDim, + CellDim, + ECVDim, + E2C2VDim, + NbCells, # Constants of the size + NbEdges, + NbVertices, + E2C2VDim, # Offsets + E2C2V, + SIZE_TYPE, # Type definitions + E2C2V_connectivity, + E2ECV_connectivity, + make_syms, # Helpers ) + from typing import Sequence, Any from functools import reduce import numpy as np @@ -47,13 +57,13 @@ def nabla4_np( - nabv_norm: Field[[Edge, KDim], wpfloat], - nabv_tang: Field[[Edge, KDim], wpfloat], - z_nabla2_e: Field[[Edge, KDim], wpfloat], - inv_vert_vert_length: Field[[Edge], wpfloat], - inv_primal_edge_length: Field[[Edge], wpfloat], + nabv_norm: Field[[EdgeDim, KDim], wpfloat], + nabv_tang: Field[[EdgeDim, KDim], wpfloat], + z_nabla2_e: Field[[EdgeDim, KDim], wpfloat], + inv_vert_vert_length: Field[[EdgeDim], wpfloat], + inv_primal_edge_length: Field[[EdgeDim], wpfloat], **kwargs, # Allows to use the same call argument object as for the SDFG -) -> Field[[Edge, KDim], wpfloat]: +) -> Field[[EdgeDim, KDim], wpfloat]: N = nabv_norm - 2 * z_nabla2_e ell_v2 = inv_vert_vert_length**2 N_ellv2 = N * ell_v2.reshape((-1, 1)) @@ -103,14 +113,14 @@ def build_nambla4_gtir_inline( """Creates the `nabla4` stencil where the computations are already inlined.""" edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=Edge.value, kind=Edge.kind), 0, "num_edges"), + im.call("named_range")(itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges"), im.call("named_range")( itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" ), ) - EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) - E_FTYPE = ts.FieldType(dims=[Edge], dtype=wpfloat) + EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) + E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) nabla4prog = itir.Program( id="nabla4_partial_inline", @@ -199,14 +209,14 @@ def build_nambla4_gtir_fieldview( ) -> itir.Program: """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=Edge.value, kind=Edge.kind), 0, "num_edges"), + im.call("named_range")(itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges"), im.call("named_range")( itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" ), ) - EK_FTYPE = ts.FieldType(dims=[Edge, KDim], dtype=wpfloat) - E_FTYPE = ts.FieldType(dims=[Edge], dtype=wpfloat) + EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) + E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) nabla4prog = itir.Program( id="nabla4_partial_fieldview", From fec054a793f1242d27cb2a5d9c525eec0452ab7a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 10:51:03 +0200 Subject: [PATCH 118/235] First step in shifting. --- my_playground/nabla4.py | 113 +++++++++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 41 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index 774798ea91..62b39a678a 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -25,7 +25,6 @@ from gt4py.next.program_processors.runners import dace_fieldview as dace_backend from gt4py.next.type_system import type_specifications as ts - from simple_icon_mesh import ( IDim, # Dimensions JDim, @@ -57,13 +56,20 @@ def nabla4_np( - nabv_norm: Field[[EdgeDim, KDim], wpfloat], + pnvertv1_0: Field[[EdgeDim], wpfloat], + u_vert_0: Field[[EdgeDim, KDim], wpfloat], + + xn_1: Field[[EdgeDim, KDim], wpfloat], + xn_2: Field[[EdgeDim, KDim], wpfloat], + xn_3: Field[[EdgeDim, KDim], wpfloat], nabv_tang: Field[[EdgeDim, KDim], wpfloat], z_nabla2_e: Field[[EdgeDim, KDim], wpfloat], inv_vert_vert_length: Field[[EdgeDim], wpfloat], inv_primal_edge_length: Field[[EdgeDim], wpfloat], **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[EdgeDim, KDim], wpfloat]: + nabv_norm = (u_vert_0 * pnvertv1_0.reshape((-1, 1))) + xn_1 + xn_2 + xn_3 + N = nabv_norm - 2 * z_nabla2_e ell_v2 = inv_vert_vert_length**2 N_ellv2 = N * ell_v2.reshape((-1, 1)) @@ -75,37 +81,6 @@ def nabla4_np( return 4 * (N_ellv2 + T_elle2) -def dace_strides( - array: np.ndarray, - name: None | str = None, -) -> tuple[int, ...] | dict[str, int]: - if not hasattr(array, "strides"): - return {} - strides = array.strides - if hasattr(array, "itemsize"): - strides = tuple(stride // array.itemsize for stride in strides) - if name is not None: - strides = {f"__{name}_stride_{i}": stride for i, stride in enumerate(strides)} - return strides - - -def dace_shape( - array: np.ndarray, - name: str, -) -> dict[str, int]: - if not hasattr(array, "shape"): - return {} - return {f"__{name}_size_{i}": size for i, size in enumerate(array.shape)} - - -def make_syms(**kwargs: np.ndarray) -> dict[str, int]: - SYMBS = {} - for name, array in kwargs.items(): - SYMBS.update(**dace_shape(array, name)) - SYMBS.update(**dace_strides(array, name)) - return SYMBS - - def build_nambla4_gtir_inline( num_edges: int, num_k_levels: int, @@ -113,7 +88,9 @@ def build_nambla4_gtir_inline( """Creates the `nabla4` stencil where the computations are already inlined.""" edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges"), + im.call("named_range")( + itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" + ), im.call("named_range")( itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" ), @@ -209,20 +186,31 @@ def build_nambla4_gtir_fieldview( ) -> itir.Program: """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges"), + im.call("named_range")( + itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" + ), im.call("named_range")( itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" ), ) + VK_TYPE = ts.FieldType(dims=[VertexDim, KDim], dtype=wpfloat) EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) + ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=ts.ScalarType(kind=wpfloat)) nabla4prog = itir.Program( id="nabla4_partial_fieldview", function_definitions=[], params=[ - itir.Sym(id="nabv_norm", type=EK_FTYPE), + # itir.Sym(id="u_vert", type=VK_FTYPE), + # itir.Sym(id="v_vert", type=VK_FTYPE), + itir.Sym(id="pnvertv1_0", type=E_FTYPE), + itir.Sym(id="u_vert_0", type=EK_FTYPE), + + itir.Sym(id="xn_1", type=EK_FTYPE), + itir.Sym(id="xn_2", type=EK_FTYPE), + itir.Sym(id="xn_3", type=EK_FTYPE), itir.Sym(id="nabv_tang", type=EK_FTYPE), itir.Sym(id="z_nabla2_e", type=EK_FTYPE), itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), @@ -283,7 +271,42 @@ def build_nambla4_gtir_fieldview( ) )( # arg: `xn` - "nabv_norm", + # u_vert(E2C2V[0]) * primal_normal_vert_v1(E2ECV[0]) || nx_0 + # + v_vert(E2C2V[0]) * primal_normal_vert_v2(E2ECV[0]) || xn_1 + # + u_vert(E2C2V[1]) * primal_normal_vert_v1(E2ECV[1]) || xn_2 + # + v_vert(E2C2V[1]) * primal_normal_vert_v2(E2ECV[1]) || xn_3 + im.call( + im.call("as_fieldop")( + im.lambda_("xn_0", "xn_1", "xn_2", "xn_3")( + im.plus( + im.plus(im.deref("xn_0"), im.deref("xn_1")), + im.plus(im.deref("xn_2"), im.deref("xn_3")), + ) + ), + edge_k_domain, + ) + )( + # arg: `xn_0` + im.call( + im.call("as_fieldop")( + im.lambda_("u_vert_0", "pnvertv1_0")( + im.multiplies_(im.deref("u_vert_0"), im.deref("pnvertv1_0")) + ), + edge_k_domain, + ) + )( + "u_vert_0", + "pnvertv1_0", + ), + # end arg: `xn_0` + + "xn_1", + + "xn_2", + + "xn_3", + ), + # arg: `z_nabla2_e2` im.call( im.call("as_fieldop")( @@ -395,17 +418,20 @@ def verify_nabla4( offset_provider = {} - nabv_norm = np.random.rand(num_edges, num_k_levels) + xn_args = {f"xn_{i}": np.random.rand(num_edges, num_k_levels) for i in range(1, 4)} + + pnvertv1_0=np.random.rand(num_edges) + u_vert_0=np.random.rand(num_edges, num_k_levels) + nabv_tang = np.random.rand(num_edges, num_k_levels) z_nabla2_e = np.random.rand(num_edges, num_k_levels) inv_vert_vert_length = np.random.rand(num_edges) inv_primal_edge_length = np.random.rand(num_edges) - nab4 = np.empty((num_edges, num_k_levels), dtype=nabv_norm.dtype) + nab4 = np.empty((num_edges, num_k_levels), dtype=np.float64) sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, offset_provider) call_args = dict( - nabv_norm=nabv_norm, nabv_tang=nabv_tang, z_nabla2_e=z_nabla2_e, inv_vert_vert_length=inv_vert_vert_length, @@ -413,7 +439,12 @@ def verify_nabla4( nab4=nab4, num_edges=num_edges, num_k_levels=num_k_levels, + + u_vert_0=u_vert_0, + pnvertv1_0=pnvertv1_0, ) + call_args.update(xn_args) + SYMBS = make_syms(**call_args) sdfg(**call_args, **SYMBS) @@ -424,5 +455,5 @@ def verify_nabla4( if "__main__" == __name__: - verify_nabla4("inline") + # verify_nabla4("inline") verify_nabla4("fieldview") From ea7bf6425a7085518ad0147560f02b8ed794ecca Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 11:29:57 +0200 Subject: [PATCH 119/235] Now we have one shifting. I think I need a helper function to continue. --- my_playground/nabla4.py | 77 ++++++++++++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index 62b39a678a..aaffb2b186 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -56,8 +56,8 @@ def nabla4_np( - pnvertv1_0: Field[[EdgeDim], wpfloat], - u_vert_0: Field[[EdgeDim, KDim], wpfloat], + u_vert: Field[[EdgeDim, KDim], wpfloat], + primal_normal_vert_v1: Field[[ECVDim], wpfloat], xn_1: Field[[EdgeDim, KDim], wpfloat], xn_2: Field[[EdgeDim, KDim], wpfloat], @@ -66,8 +66,18 @@ def nabla4_np( z_nabla2_e: Field[[EdgeDim, KDim], wpfloat], inv_vert_vert_length: Field[[EdgeDim], wpfloat], inv_primal_edge_length: Field[[EdgeDim], wpfloat], + + # These are the offset providers + E2C2V: NeighborTable, + **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[EdgeDim, KDim], wpfloat]: + + primal_normal_vert_v1 = primal_normal_vert_v1.reshape(E2C2V.table.shape) + + u_vert_0 = u_vert[E2C2V.table[:, 0]] + pnvertv1_0 = primal_normal_vert_v1[:, 0] + nabv_norm = (u_vert_0 * pnvertv1_0.reshape((-1, 1))) + xn_1 + xn_2 + xn_3 N = nabv_norm - 2 * z_nabla2_e @@ -185,6 +195,13 @@ def build_nambla4_gtir_fieldview( num_k_levels: int, ) -> itir.Program: """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" + + + edge_domain = im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" + ), + ) edge_k_domain = im.call("unstructured_domain")( im.call("named_range")( itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" @@ -194,19 +211,18 @@ def build_nambla4_gtir_fieldview( ), ) - VK_TYPE = ts.FieldType(dims=[VertexDim, KDim], dtype=wpfloat) + VK_FTYPE = ts.FieldType(dims=[VertexDim, KDim], dtype=wpfloat) EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) - ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=ts.ScalarType(kind=wpfloat)) + ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=wpfloat) nabla4prog = itir.Program( id="nabla4_partial_fieldview", function_definitions=[], params=[ - # itir.Sym(id="u_vert", type=VK_FTYPE), + itir.Sym(id="u_vert", type=VK_FTYPE), # itir.Sym(id="v_vert", type=VK_FTYPE), - itir.Sym(id="pnvertv1_0", type=E_FTYPE), - itir.Sym(id="u_vert_0", type=EK_FTYPE), + itir.Sym(id="primal_normal_vert_v1", type=ECV_FTYPE), itir.Sym(id="xn_1", type=EK_FTYPE), itir.Sym(id="xn_2", type=EK_FTYPE), @@ -295,8 +311,31 @@ def build_nambla4_gtir_fieldview( edge_k_domain, ) )( - "u_vert_0", - "pnvertv1_0", + # arg: `u_vert_0` + im.call( + im.call("as_fieldop")( + im.lambda_("u_vert_no_shifted")( + im.deref(im.shift("E2C2V", 0)("u_vert_no_shifted")) + ), + edge_k_domain, + ) + )( + "u_vert" # arg: `u_vert_no_shifted` + ), + # end arg: `u_vert_0` + + # arg: `pnvertv1_0` + im.call( + im.call("as_fieldop")( + im.lambda_("primal_normal_vert_v1_no_shifted")( + im.deref(im.shift("E2ECV", 0)("primal_normal_vert_v1_no_shifted")) + ), + edge_domain, + ) + )( + "primal_normal_vert_v1" # arg: `primal_normal_vert_v1_no_shifted` + ), + # end arg: `pnvertv1_0` ), # end arg: `xn_0` @@ -398,7 +437,8 @@ def build_nambla4_gtir_fieldview( def verify_nabla4( version: str, ): - num_edges = 27 + num_edges = NbEdges + num_vertices = NbVertices num_k_levels = 10 if version == "fieldview": @@ -416,12 +456,15 @@ def verify_nabla4( else: raise ValueError(f"The version `{version}` is now known.") - offset_provider = {} + offset_provider = { + "E2C2V": E2C2V_connectivity, + "E2ECV": E2ECV_connectivity, + } xn_args = {f"xn_{i}": np.random.rand(num_edges, num_k_levels) for i in range(1, 4)} - pnvertv1_0=np.random.rand(num_edges) - u_vert_0=np.random.rand(num_edges, num_k_levels) + u_vert = np.random.rand(num_vertices, num_k_levels) + primal_normal_vert_v1 = np.random.rand(num_edges * 4) nabv_tang = np.random.rand(num_edges, num_k_levels) z_nabla2_e = np.random.rand(num_edges, num_k_levels) @@ -440,15 +483,17 @@ def verify_nabla4( num_edges=num_edges, num_k_levels=num_k_levels, - u_vert_0=u_vert_0, - pnvertv1_0=pnvertv1_0, + u_vert=u_vert, + primal_normal_vert_v1=primal_normal_vert_v1, ) call_args.update(xn_args) + call_args.update({f"connectivity_{k}": v.table.copy() for k, v in offset_provider.items()}) + SYMBS = make_syms(**call_args) sdfg(**call_args, **SYMBS) - ref = nabla4_np(**call_args) + ref = nabla4_np(**call_args, **offset_provider) assert np.allclose(ref, nab4) print(f"Version({version}): Succeeded") From b291152ac5bf38cc589607be47e68006bf9b096b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 11:56:37 +0200 Subject: [PATCH 120/235] The helper function works. --- my_playground/nabla4.py | 84 ++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index aaffb2b186..f9e793766f 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -216,6 +216,55 @@ def build_nambla4_gtir_fieldview( E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=wpfloat) + + def shift_builder( + vert: str, + vert_idx: int, + primal: str, + primal_idx: int, + ) -> itir.FunCall: + """Used to construct the shifting calculations. + + This function generates the IR for the expression: + ``` + vert[E2C2V[:, vert_idx]] * primal[E2ECV[:, primal_idx]] + ``` + """ + return im.call( + im.call("as_fieldop")( + im.lambda_("vert_shifted", "primal_shifted")( + im.multiplies_(im.deref("vert_shifted"), im.deref("primal_shifted")) + ), + edge_k_domain, + ) + )( + # arg: `vert_shifted` + im.call( + im.call("as_fieldop")( + im.lambda_("vert_no_shifted")( + im.deref(im.shift("E2C2V", vert_idx)("vert_no_shifted")) + ), + edge_k_domain, + ) + )( + vert, # arg: `vert_no_shifted` + ), + # end arg: `vert_shifted` + + # arg: `primal_shifted` + im.call( + im.call("as_fieldop")( + im.lambda_("primal_no_shifted")( + im.deref(im.shift("E2ECV", primal_idx)("primal_no_shifted")) + ), + edge_domain, + ) + )( + primal, # arg: `primal_no_shifted` + ), + # end arg: `primal_shifted` + ) + nabla4prog = itir.Program( id="nabla4_partial_fieldview", function_definitions=[], @@ -303,40 +352,7 @@ def build_nambla4_gtir_fieldview( ) )( # arg: `xn_0` - im.call( - im.call("as_fieldop")( - im.lambda_("u_vert_0", "pnvertv1_0")( - im.multiplies_(im.deref("u_vert_0"), im.deref("pnvertv1_0")) - ), - edge_k_domain, - ) - )( - # arg: `u_vert_0` - im.call( - im.call("as_fieldop")( - im.lambda_("u_vert_no_shifted")( - im.deref(im.shift("E2C2V", 0)("u_vert_no_shifted")) - ), - edge_k_domain, - ) - )( - "u_vert" # arg: `u_vert_no_shifted` - ), - # end arg: `u_vert_0` - - # arg: `pnvertv1_0` - im.call( - im.call("as_fieldop")( - im.lambda_("primal_normal_vert_v1_no_shifted")( - im.deref(im.shift("E2ECV", 0)("primal_normal_vert_v1_no_shifted")) - ), - edge_domain, - ) - )( - "primal_normal_vert_v1" # arg: `primal_normal_vert_v1_no_shifted` - ), - # end arg: `pnvertv1_0` - ), + shift_builder("u_vert", 0, "primal_normal_vert_v1", 0), # end arg: `xn_0` "xn_1", From 008209dd55b9dfd7332d88cd86fb3325fe126d04 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 13:00:55 +0200 Subject: [PATCH 121/235] It now works with the normal shiftuing stuff. --- my_playground/nabla4.py | 113 +++++++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 49 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index f9e793766f..f83e181d2a 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -57,28 +57,27 @@ def nabla4_np( u_vert: Field[[EdgeDim, KDim], wpfloat], + v_vert: Field[[EdgeDim, KDim], wpfloat], primal_normal_vert_v1: Field[[ECVDim], wpfloat], - - xn_1: Field[[EdgeDim, KDim], wpfloat], - xn_2: Field[[EdgeDim, KDim], wpfloat], - xn_3: Field[[EdgeDim, KDim], wpfloat], - nabv_tang: Field[[EdgeDim, KDim], wpfloat], + primal_normal_vert_v2: Field[[ECVDim], wpfloat], z_nabla2_e: Field[[EdgeDim, KDim], wpfloat], inv_vert_vert_length: Field[[EdgeDim], wpfloat], inv_primal_edge_length: Field[[EdgeDim], wpfloat], - # These are the offset providers E2C2V: NeighborTable, - + nabv_tang: Field[[EdgeDim, KDim], wpfloat], # FAKE **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[EdgeDim, KDim], wpfloat]: - primal_normal_vert_v1 = primal_normal_vert_v1.reshape(E2C2V.table.shape) + primal_normal_vert_v2 = primal_normal_vert_v2.reshape(E2C2V.table.shape) + u_vert_e2c2v = u_vert[E2C2V.table] + v_vert_e2c2v = v_vert[E2C2V.table] - u_vert_0 = u_vert[E2C2V.table[:, 0]] - pnvertv1_0 = primal_normal_vert_v1[:, 0] - - nabv_norm = (u_vert_0 * pnvertv1_0.reshape((-1, 1))) + xn_1 + xn_2 + xn_3 + xn_0 = u_vert_e2c2v[:, 2] * primal_normal_vert_v1[:, 2].reshape((-1, 1)) + xn_1 = v_vert_e2c2v[:, 2] * primal_normal_vert_v2[:, 2].reshape((-1, 1)) + xn_2 = u_vert_e2c2v[:, 3] * primal_normal_vert_v1[:, 3].reshape((-1, 1)) + xn_3 = v_vert_e2c2v[:, 3] * primal_normal_vert_v2[:, 3].reshape((-1, 1)) + nabv_norm = xn_0 + xn_1 + xn_2 + xn_3 N = nabv_norm - 2 * z_nabla2_e ell_v2 = inv_vert_vert_length**2 @@ -196,7 +195,6 @@ def build_nambla4_gtir_fieldview( ) -> itir.Program: """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" - edge_domain = im.call("unstructured_domain")( im.call("named_range")( itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" @@ -216,12 +214,11 @@ def build_nambla4_gtir_fieldview( E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=wpfloat) - def shift_builder( - vert: str, - vert_idx: int, - primal: str, - primal_idx: int, + vert: str, + vert_idx: int, + primal: str, + primal_idx: int, ) -> itir.FunCall: """Used to construct the shifting calculations. @@ -250,7 +247,6 @@ def shift_builder( vert, # arg: `vert_no_shifted` ), # end arg: `vert_shifted` - # arg: `primal_shifted` im.call( im.call("as_fieldop")( @@ -270,13 +266,10 @@ def shift_builder( function_definitions=[], params=[ itir.Sym(id="u_vert", type=VK_FTYPE), - # itir.Sym(id="v_vert", type=VK_FTYPE), + itir.Sym(id="v_vert", type=VK_FTYPE), itir.Sym(id="primal_normal_vert_v1", type=ECV_FTYPE), - - itir.Sym(id="xn_1", type=EK_FTYPE), - itir.Sym(id="xn_2", type=EK_FTYPE), - itir.Sym(id="xn_3", type=EK_FTYPE), - itir.Sym(id="nabv_tang", type=EK_FTYPE), + itir.Sym(id="primal_normal_vert_v2", type=ECV_FTYPE), + itir.Sym(id="nabv_tang", type=EK_FTYPE), # FAKE itir.Sym(id="z_nabla2_e", type=EK_FTYPE), itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), itir.Sym(id="inv_primal_edge_length", type=E_FTYPE), @@ -336,32 +329,53 @@ def shift_builder( ) )( # arg: `xn` - # u_vert(E2C2V[0]) * primal_normal_vert_v1(E2ECV[0]) || nx_0 - # + v_vert(E2C2V[0]) * primal_normal_vert_v2(E2ECV[0]) || xn_1 - # + u_vert(E2C2V[1]) * primal_normal_vert_v1(E2ECV[1]) || xn_2 - # + v_vert(E2C2V[1]) * primal_normal_vert_v2(E2ECV[1]) || xn_3 + # u_vert(E2C2V[2]) * primal_normal_vert_v1(E2ECV[2]) || nx_0 + # + v_vert(E2C2V[2]) * primal_normal_vert_v2(E2ECV[2]) || xn_1 + # + u_vert(E2C2V[3]) * primal_normal_vert_v1(E2ECV[3]) || xn_2 + # + v_vert(E2C2V[3]) * primal_normal_vert_v2(E2ECV[3]) || xn_3 im.call( im.call("as_fieldop")( - im.lambda_("xn_0", "xn_1", "xn_2", "xn_3")( - im.plus( - im.plus(im.deref("xn_0"), im.deref("xn_1")), - im.plus(im.deref("xn_2"), im.deref("xn_3")), - ) + im.lambda_("xn_0_p_1", "xn_2_p_3")( + im.plus(im.deref("xn_0_p_1"), im.deref("xn_2_p_3")) ), edge_k_domain, ) )( - # arg: `xn_0` - shift_builder("u_vert", 0, "primal_normal_vert_v1", 0), - # end arg: `xn_0` - - "xn_1", - - "xn_2", - - "xn_3", + # arg: `xn_0_p_1` + im.call( + im.call("as_fieldop")( + im.lambda_("xn_0", "xn_1")( + im.plus(im.deref("xn_0"), im.deref("xn_1")) + ), + edge_k_domain, + ) + )( + shift_builder( # arg: `xn_0` + "u_vert", 2, "primal_normal_vert_v1", 2 + ), + shift_builder( # arg: `xn_1` + "v_vert", 2, "primal_normal_vert_v2", 2 + ), + ), + # end arg: `xn_0_p_1` + # arg: `xn_2_p_3` + im.call( + im.call("as_fieldop")( + im.lambda_("xn_2", "xn_3")( + im.plus(im.deref("xn_2"), im.deref("xn_3")) + ), + edge_k_domain, + ) + )( + shift_builder( # arg: `xn_2` + "u_vert", 3, "primal_normal_vert_v1", 3 + ), + shift_builder( # arg: `xn_3` + "v_vert", 3, "primal_normal_vert_v2", 3 + ), + ), + # end arg: `xn_2_p_3` ), - # arg: `z_nabla2_e2` im.call( im.call("as_fieldop")( @@ -477,12 +491,14 @@ def verify_nabla4( "E2ECV": E2ECV_connectivity, } - xn_args = {f"xn_{i}": np.random.rand(num_edges, num_k_levels) for i in range(1, 4)} + # This is not yet computed + nabv_tang = np.random.rand(num_edges, num_k_levels) u_vert = np.random.rand(num_vertices, num_k_levels) + v_vert = np.random.rand(num_vertices, num_k_levels) primal_normal_vert_v1 = np.random.rand(num_edges * 4) + primal_normal_vert_v2 = np.random.rand(num_edges * 4) - nabv_tang = np.random.rand(num_edges, num_k_levels) z_nabla2_e = np.random.rand(num_edges, num_k_levels) inv_vert_vert_length = np.random.rand(num_edges) inv_primal_edge_length = np.random.rand(num_edges) @@ -498,14 +514,13 @@ def verify_nabla4( nab4=nab4, num_edges=num_edges, num_k_levels=num_k_levels, - u_vert=u_vert, + v_vert=v_vert, primal_normal_vert_v1=primal_normal_vert_v1, + primal_normal_vert_v2=primal_normal_vert_v2, ) - call_args.update(xn_args) call_args.update({f"connectivity_{k}": v.table.copy() for k, v in offset_provider.items()}) - SYMBS = make_syms(**call_args) sdfg(**call_args, **SYMBS) From b832acaa9310f99b80f7537d278e73fb7c065275 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 13:07:05 +0200 Subject: [PATCH 122/235] Now the full nabla4 should be ported. --- my_playground/nabla4.py | 62 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index f83e181d2a..965615953f 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -65,7 +65,6 @@ def nabla4_np( inv_primal_edge_length: Field[[EdgeDim], wpfloat], # These are the offset providers E2C2V: NeighborTable, - nabv_tang: Field[[EdgeDim, KDim], wpfloat], # FAKE **kwargs, # Allows to use the same call argument object as for the SDFG ) -> Field[[EdgeDim, KDim], wpfloat]: primal_normal_vert_v1 = primal_normal_vert_v1.reshape(E2C2V.table.shape) @@ -83,6 +82,12 @@ def nabla4_np( ell_v2 = inv_vert_vert_length**2 N_ellv2 = N * ell_v2.reshape((-1, 1)) + xt_0 = u_vert_e2c2v[:, 0] * primal_normal_vert_v1[:, 0].reshape((-1, 1)) + xt_1 = v_vert_e2c2v[:, 0] * primal_normal_vert_v2[:, 0].reshape((-1, 1)) + xt_2 = u_vert_e2c2v[:, 1] * primal_normal_vert_v1[:, 1].reshape((-1, 1)) + xt_3 = v_vert_e2c2v[:, 1] * primal_normal_vert_v2[:, 1].reshape((-1, 1)) + nabv_tang = xt_0 + xt_1 + xt_2 + xt_3 + T = nabv_tang - 2 * z_nabla2_e ell_e2 = inv_primal_edge_length**2 T_elle2 = T * ell_e2.reshape((-1, 1)) @@ -269,7 +274,6 @@ def shift_builder( itir.Sym(id="v_vert", type=VK_FTYPE), itir.Sym(id="primal_normal_vert_v1", type=ECV_FTYPE), itir.Sym(id="primal_normal_vert_v2", type=ECV_FTYPE), - itir.Sym(id="nabv_tang", type=EK_FTYPE), # FAKE itir.Sym(id="z_nabla2_e", type=EK_FTYPE), itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), itir.Sym(id="inv_primal_edge_length", type=E_FTYPE), @@ -376,6 +380,7 @@ def shift_builder( ), # end arg: `xn_2_p_3` ), + # end arg: `xn` # arg: `z_nabla2_e2` im.call( im.call("as_fieldop")( @@ -429,7 +434,54 @@ def shift_builder( ) )( # arg: `xt` - "nabv_tang", + # u_vert(E2C2V[0]) * primal_normal_vert_v1(E2ECV[0]) || nx_0 + # + v_vert(E2C2V[0]) * primal_normal_vert_v2(E2ECV[0]) || xt_1 + # + u_vert(E2C2V[1]) * primal_normal_vert_v1(E2ECV[1]) || xt_2 + # + v_vert(E2C2V[1]) * primal_normal_vert_v2(E2ECV[1]) || xt_3 + im.call( + im.call("as_fieldop")( + im.lambda_("xt_0_p_1", "xn_2_p_3")( + im.plus(im.deref("xt_0_p_1"), im.deref("xn_2_p_3")) + ), + edge_k_domain, + ) + )( + # arg: `xt_0_p_1` + im.call( + im.call("as_fieldop")( + im.lambda_("xt_0", "xn_1")( + im.plus(im.deref("xt_0"), im.deref("xn_1")) + ), + edge_k_domain, + ) + )( + shift_builder( # arg: `xt_0` + "u_vert", 0, "primal_normal_vert_v1", 0 + ), + shift_builder( # arg: `xt_1` + "v_vert", 0, "primal_normal_vert_v2", 0 + ), + ), + # end arg: `xt_0_p_1` + # arg: `xt_2_p_3` + im.call( + im.call("as_fieldop")( + im.lambda_("xt_2", "xn_3")( + im.plus(im.deref("xt_2"), im.deref("xn_3")) + ), + edge_k_domain, + ) + )( + shift_builder( # arg: `xt_2` + "u_vert", 1, "primal_normal_vert_v1", 1 + ), + shift_builder( # arg: `xt_3` + "v_vert", 1, "primal_normal_vert_v2", 1 + ), + ), + # end arg: `xt_2_p_3` + ), + # end arg: `xt` # arg: `z_nabla2_e2` im.call( im.call("as_fieldop")( @@ -491,9 +543,6 @@ def verify_nabla4( "E2ECV": E2ECV_connectivity, } - # This is not yet computed - nabv_tang = np.random.rand(num_edges, num_k_levels) - u_vert = np.random.rand(num_vertices, num_k_levels) v_vert = np.random.rand(num_vertices, num_k_levels) primal_normal_vert_v1 = np.random.rand(num_edges * 4) @@ -507,7 +556,6 @@ def verify_nabla4( sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, offset_provider) call_args = dict( - nabv_tang=nabv_tang, z_nabla2_e=z_nabla2_e, inv_vert_vert_length=inv_vert_vert_length, inv_primal_edge_length=inv_primal_edge_length, From 94ab9d75d7e7da1cc682f62d017c595648378de2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 13:20:46 +0200 Subject: [PATCH 123/235] Restructured the code and removed the inline version. --- my_playground/nabla4.py | 223 +++++++++++----------------------------- 1 file changed, 60 insertions(+), 163 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index 965615953f..a497bda6a4 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -53,6 +53,10 @@ wpfloat = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) +VK_FTYPE = ts.FieldType(dims=[VertexDim, KDim], dtype=wpfloat) +EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) +E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) +ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=wpfloat) def nabla4_np( @@ -95,103 +99,66 @@ def nabla4_np( return 4 * (N_ellv2 + T_elle2) -def build_nambla4_gtir_inline( - num_edges: int, - num_k_levels: int, -) -> itir.Program: - """Creates the `nabla4` stencil where the computations are already inlined.""" - - edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" - ), - im.call("named_range")( - itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" - ), - ) +# Dimension we operate on. +edge_k_domain = im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" + ), + im.call("named_range")(itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels"), +) +edge_domain = im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" + ), +) - EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) - E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) - nabla4prog = itir.Program( - id="nabla4_partial_inline", - function_definitions=[], - params=[ - itir.Sym(id="nabv_norm", type=EK_FTYPE), - itir.Sym(id="nabv_tang", type=EK_FTYPE), - itir.Sym(id="z_nabla2_e", type=EK_FTYPE), - itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), - itir.Sym(id="inv_primal_edge_length", type=E_FTYPE), - itir.Sym(id="nab4", type=EK_FTYPE), - itir.Sym(id="num_edges", type=SIZE_TYPE), - itir.Sym(id="num_k_levels", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_( - "z_nabla2_e2", - "const_4", - "nabv_norm", - "inv_vert_vert_length", - "nabv_tang", - "inv_primal_edge_length", - )( - im.multiplies_( - im.plus( - im.multiplies_( - im.minus( - im.deref("nabv_norm"), - im.deref("z_nabla2_e2"), - ), - im.multiplies_( - im.deref("inv_vert_vert_length"), - im.deref("inv_vert_vert_length"), - ), - ), - im.multiplies_( - im.minus( - im.deref("nabv_tang"), - im.deref("z_nabla2_e2"), - ), - im.multiplies_( - im.deref("inv_primal_edge_length"), - im.deref("inv_primal_edge_length"), - ), - ), - ), - im.deref("const_4"), - ) - ), - edge_k_domain, - ) - )( - # arg: `z_nabla2_e2` - im.call( - im.call("as_fieldop")( - im.lambda_("x", "const_2")( - im.multiplies_(im.deref("x"), im.deref("const_2")) - ), - edge_k_domain, - ) - )("z_nabla2_e", 2.0), - # end arg: `z_nabla2_e2` - # arg: `const_4` - 4.0, - # Same name as in the lambda and are argument to the program. - "nabv_norm", - "inv_vert_vert_length", - "nabv_tang", - "inv_primal_edge_length", +def shift_builder( + vert: str, + vert_idx: int, + primal: str, + primal_idx: int, +) -> itir.FunCall: + """Used to construct the shifting calculations. + + This function generates the IR for the expression: + ``` + vert[E2C2V[:, vert_idx]] * primal[E2ECV[:, primal_idx]] + ``` + """ + return im.call( + im.call("as_fieldop")( + im.lambda_("vert_shifted", "primal_shifted")( + im.multiplies_(im.deref("vert_shifted"), im.deref("primal_shifted")) + ), + edge_k_domain, + ) + )( + # arg: `vert_shifted` + im.call( + im.call("as_fieldop")( + im.lambda_("vert_no_shifted")( + im.deref(im.shift("E2C2V", vert_idx)("vert_no_shifted")) ), - domain=edge_k_domain, - target=itir.SymRef(id="nab4"), + edge_k_domain, ) - ], + )( + vert, # arg: `vert_no_shifted` + ), + # end arg: `vert_shifted` + # arg: `primal_shifted` + im.call( + im.call("as_fieldop")( + im.lambda_("primal_no_shifted")( + im.deref(im.shift("E2ECV", primal_idx)("primal_no_shifted")) + ), + edge_domain, + ) + )( + primal, # arg: `primal_no_shifted` + ), + # end arg: `primal_shifted` ) - return nabla4prog def build_nambla4_gtir_fieldview( @@ -200,72 +167,6 @@ def build_nambla4_gtir_fieldview( ) -> itir.Program: """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" - edge_domain = im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" - ), - ) - edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" - ), - im.call("named_range")( - itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels" - ), - ) - - VK_FTYPE = ts.FieldType(dims=[VertexDim, KDim], dtype=wpfloat) - EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) - E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) - ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=wpfloat) - - def shift_builder( - vert: str, - vert_idx: int, - primal: str, - primal_idx: int, - ) -> itir.FunCall: - """Used to construct the shifting calculations. - - This function generates the IR for the expression: - ``` - vert[E2C2V[:, vert_idx]] * primal[E2ECV[:, primal_idx]] - ``` - """ - return im.call( - im.call("as_fieldop")( - im.lambda_("vert_shifted", "primal_shifted")( - im.multiplies_(im.deref("vert_shifted"), im.deref("primal_shifted")) - ), - edge_k_domain, - ) - )( - # arg: `vert_shifted` - im.call( - im.call("as_fieldop")( - im.lambda_("vert_no_shifted")( - im.deref(im.shift("E2C2V", vert_idx)("vert_no_shifted")) - ), - edge_k_domain, - ) - )( - vert, # arg: `vert_no_shifted` - ), - # end arg: `vert_shifted` - # arg: `primal_shifted` - im.call( - im.call("as_fieldop")( - im.lambda_("primal_no_shifted")( - im.deref(im.shift("E2ECV", primal_idx)("primal_no_shifted")) - ), - edge_domain, - ) - )( - primal, # arg: `primal_no_shifted` - ), - # end arg: `primal_shifted` - ) - nabla4prog = itir.Program( id="nabla4_partial_fieldview", function_definitions=[], @@ -530,10 +431,7 @@ def verify_nabla4( ) elif version == "inline": - nabla4prog = build_nambla4_gtir_inline( - num_edges=num_edges, - num_k_levels=num_k_levels, - ) + raise NotImplementedError("Inline version is no longer supported.") else: raise ValueError(f"The version `{version}` is now known.") @@ -579,5 +477,4 @@ def verify_nabla4( if "__main__" == __name__: - # verify_nabla4("inline") verify_nabla4("fieldview") From 04cde84b445247a554c97269718cdc6cbde677fa Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 8 Jul 2024 13:26:54 +0200 Subject: [PATCH 124/235] Made some small update. --- .../runners/dace_fieldview/transformations/auto_opt.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 09d2132e9f..444e5506b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -17,8 +17,8 @@ from typing import Any import dace -from dace.transformation.auto import auto_optimize as dace_aoptimize from dace.transformation import dataflow as dace_dataflow +from dace.transformation.auto import auto_optimize as dace_aoptimize def dace_auto_optimize( @@ -49,7 +49,6 @@ def dace_auto_optimize( # This should get rid of some of teh additional transients we have added. sdfg.simplify() - return sdfg @@ -68,8 +67,6 @@ def gt_auto_optimize( # Initial cleaning sdfg.simplify() - - # Due to the structure of the generated SDFG getting rid of Maps, # i.e. fusing them, is the best we can currently do. sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) @@ -82,4 +79,3 @@ def gt_auto_optimize( sdfg.simplify() return sdfg - From 3dd08600461d08a22b6f79cff4800922481df555 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 9 Jul 2024 16:26:21 +0200 Subject: [PATCH 125/235] This is the base of all fusion operations. --- .../transformations/map_fusion_helper.py | 571 ++++++++++++++++++ 1 file changed, 571 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py new file mode 100644 index 0000000000..e9fb7e077c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -0,0 +1,571 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Implements Helper functionaliyies for map fusion""" + +from typing import Any, Iterable, Optional, Sequence, Union + +from dace import subsets +from dace.sdfg import ( + SDFG, + SDFGState, + data, + nodes, + properties, + transformation as dace_transformation, +) +from dace.transformation import helpers + + +@properties.make_properties +class MapFusionHelper(dace_transformation.SingleStateTransformation): + """ + Contains common part of the map fusion for parallel and serial map fusion. + + See also [this HackMD document](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + about the underlying assumption this transformation makes. + + After every transformation that manipulates the state machine, you shouls recreate + the transformation. + """ + + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are on the top level.", + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + shared_transients = properties.DictProperty( + key_type=SDFG, + value_type=set[str], + default=None, + allow_none=True, + desc="Maps SDFGs to the set of array transients that can not be removed. " + "The variable acts as a cache, and is managed by 'can_transient_be_removed()'.", + ) + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = bool(only_toplevel_maps) + if only_inner_maps is not None: + self.only_inner_maps = bool(only_inner_maps) + self.shared_transients = {} + + @classmethod + def expressions(cls) -> bool: + raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + + def relocate_nodes( + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_nodes` node. + + Note: + - This function dos not remove the `from_node` but it will have degree + zero and have no connectors. + - If this function fails, the SDFG is in an invalid state. + - Usually this function should be called twice per Map scope, once for the + entry node and once for the exit node. + """ + + # Now we relocate empty Memlets, from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine if which connections we have to migrate + # We only consider the in edges, for Map exits it does not matter, but for + # Map entries, we need it for the dynamic map range feature. + for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used, + # inside the Map scope to define a variable. We handle it directly. + assert isinstance(from_node, nodes.MapEntry) + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused. + if dmr_symbol in to_node.in_connectors: + raise NotImplementedError( + f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented." + ) + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." + ) + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) + + # There is no other edge that we have to consider, so we just end here + continue + + # We have a Passthrough connection, i.e. there exists a `OUT_` connector + # thus we now have to migrate the two edges. + + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + assert ( + state.in_degree(from_node) == 0 + ), f"After moving source node '{from_node}' still has an input degree of {state.in_degree(from_node)}" + assert ( + state.out_degree(from_node) == 0 + ), f"After moving source node '{from_node}' still has an output degree of {state.in_degree(from_node)}" + + def map_parameter_compatible( + self, + map_1: nodes.Map, + map_2: nodes.Map, + state: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> bool: + """Checks if `map_1` is compatible with `map_1`. + + The check follows the following rules: + - The names of the map variables must be the same, i.e. no renaming + is performed. + - The ranges must be the same. + """ + range_1: subsets.Range = map_1.range + params_1: Sequence[str] = map_1.params + range_2: subsets.Range = map_2.range + params_2: Sequence[str] = map_2.params + + # The maps are only fuseable if we have an exact match in the parameter names + # this is because we do not do any renaming. + if set(params_1) != set(params_2): + return False + + # Maps the name of a parameter to the dimension index + param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} + param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + + # To fuse the two maps the ranges must have the same ranges + for pname in params_1: + idx_1 = param_dim_map_1[pname] + idx_2 = param_dim_map_2[pname] + # TODO(phimuell): do we need to call simplify + if range_1[idx_1] != range_2[idx_2]: + return False + + return True + + def can_transient_be_removed( + self, + transient: Union[str, nodes.AccessNode], + sdfg: SDFG, + ) -> bool: + """Can `transient` be removed. + + Essentially the function tests if the transient `transient` is needed to + transmit information from one state to the other. The function will first + look consult `self.shared_transients`, if the SDFG is not known the function + will compute the set of transients that have to be kept alive. + + If `transient` refers to a scalar the function will return `False`, as + a scalar can not be removed. + + Args: + transient: The transient that should be checked. + sdfg: The SDFG containing the array. + """ + + if sdfg not in self.shared_transients: + # SDFG is not known, so we have to compute the set of all transients that + # have to be kept alive. This set is given by all transients that are + # source nodes; We currently ignore scalars. + shared_sdfg_transients: set[str] = set() + for state in sdfg.states(): + for acnode in filter( + lambda node: isinstance(node, nodes.AccessNode), state.sink_nodes() + ): + desc = sdfg.arrays[acnode.data] + if desc.transient and isinstance(desc, data.Array): + shared_sdfg_transients.add(acnode.data) + self.shared_transients[sdfg] = shared_sdfg_transients + + if isinstance(transient, nodes.AccessNode): + transient = transient.data + + desc: data.Data = sdfg.arrays[transient] # type: ignore[no-redef] + if isinstance(desc, data.View): + return False + if isinstance(desc, data.Scalar): + return False + return transient not in self.shared_transients[sdfg] + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + ) -> Union[ + tuple[ + set[nodes.MultiConnectorEdge], + set[nodes.MultiConnectorEdge], + set[nodes.MultiConnectorEdge], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit and enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not needed anywhere else, thus it will + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and are recreated + as output of the second map. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + serially, the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: set[nodes.MultiConnectorEdge] = set() + exclusive_outputs: set[nodes.MultiConnectorEdge] = set() + shared_outputs: set[nodes.MultiConnectorEdge] = set() + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: set[nodes.Node] = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # Empty Memlets are currently not supported. + if out_edge.data.is_empty(): + return None + + # Now let's look at all nodes that are downstream of the intermediate node. + # This, among other thing will tell us, how we have to handle this node. + downstream_nodes = self.all_nodes_between( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ) + + # If `downstream_nodes` is `None` it means that `map_entry_2` was never + # reached, thus `intermediate_node` does not enter the second map and + # the node is a pure output node. + if downstream_nodes is None: + pure_outputs.add(out_edge) + continue + # + + # The following tests, before we start handle intermediate nodes, are + # _after_ the pure node test for a reason, because this allows us to + # handle more exotic cases for these nodes. + + # In case the intermediate has more than one entry, all must come from the + # first map, otherwise we can not fuse them. + if state.in_degree(intermediate_node) != 1: + # TODO(phimuell): In some cases it can be possible to fuse such + # nodes, but we will not handle them for the time being. + return None + + # It happens can be that multiple edges at the `IN_` connector of the + # first exit map converges, but there is only one edge leaving the exit. + # TODO(phimuell): Handle this case properly. + inner_collector_edges = state.in_edges_by_connector( + intermediate_node, "IN_" + out_edge.src_conn[3:] + ) + if len(inner_collector_edges) > 1: + return None + + # For us an intermediate node must always be an access node, pointing to a + # transient array, since it is the only thing that we know how to handle. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + intermediate_desc: data.Data = intermediate_node.desc(sdfg) + if not intermediate_desc.transient: + return None + if isinstance(intermediate_desc, data.View): + return None + + # There are two restrictions we have on the intermediate output sets. + # First, we do not allow that they are involved in WCR (as they are + # currently not handled by the implementation) and second, that the + # "producer" generate only one element, this is actual crucial, as we + # assume that we can freely recreate them, a simples example consider + # that a Tasklet outputs "rows" then we can not handle the rest in + # columns. For that reason we check the generating Memlets. + for _, produce_edge in self.find_upstream_producers(state, out_edge): + if produce_edge.data.wcr is not None: + return None + if produce_edge.data.num_elements() != 1: + return None + # TODO(phimuell): Check that the producing is only pointwise. + + if len(downstream_nodes) == 0: + # There is nothing between intermediate node and the entry of the + # second map, thus the edge belongs either in `\mathbb{S}` or + # `\mathbb{E}`, to which one depends on how it is used. + + # This is a very special situation, i.e. the access node has many + # different connections to the second map entry, this is a special + # case that we do not handle, instead simplify should be called. + if state.out_degree(intermediate_node) != 1: + return None + + # There are certain nodes, for example Tasklets, that needs the whole + # array as input. Thus it can not be removed, because the node might + # need the whole array. + # TODO(phimuell): This is true for JaCe but also for GT4Py? + for _, feed_edge in self.find_downstream_consumers( + state=state, begin=intermediate_node + ): + if feed_edge.data.num_elements() != 1: + return None + # TODO(phimuell): Check that the consuming is only pointwise. + + if self.can_transient_be_removed(intermediate_node, sdfg): + # The transient can be removed, thus it is exclusive. + exclusive_outputs.add(out_edge) + else: + # The transient can not be removed, to it must be shared. + shared_outputs.add(out_edge) + continue + + else: + # These is no single connection from the intermediate node to the + # second map, but many. For now we will only handle a very special + # case that makes the node to a shared intermediate node: + # All output connections of the intermediate node either lead: + # - directly to the second map entry node and does not define a + # dynamic map range, and can actually be removed. + # - have no connection to the second map entry at all. + for edge in state.out_edges(intermediate_node): + if edge.dst is map_entry_2: + # The edge immediately leads to the second map. + for consumer_node, feed_edge in self.find_downstream_consumers( + state=state, begin=edge + ): + # Consumer needs the whole array. + if feed_edge.data.num_elements() != 1: + return None + # Defines a dynamic map range + if consumer_node is map_entry_2: + return None + else: + # Ensure that there is no path that leads to the second map. + if ( + self.all_nodes_between(graph=state, begin=edge.dst, end=map_entry_2) + is not None + ): + return None + + # If we are here, then we know that the node is a shared output + shared_outputs.add(out_edge) + continue + + assert exclusive_outputs or shared_outputs or pure_outputs + return (pure_outputs, exclusive_outputs, shared_outputs) + + def all_nodes_between( + self, + graph: SDFG | SDFGState, + begin: nodes.Node, + end: nodes.Node, + reverse: bool = False, + ) -> set[nodes.Node] | None: + """Returns all nodes that are reachable from `begin` but bound by `end`. + + What the function does is, that it starts a DFS starting at `begin`, which is + not part of the returned set, every edge that goes to `end` will be considered + to not exists. + In case `end` is never found the function will return `None`. + + If `reverse` is set to `True` the function will start exploring at `end` and + follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The terminator node of the DFS. + reverse: Perform a backward DFS. + + Notes: + - The returned set will never contain the node `begin`. + - The returned set will also contain the nodes of path that starts at + `begin` and ends at a node that is not `end`. + """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + if reverse: + begin, end = end, begin + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.src for edge in graph.in_edges(node)) + + to_visit: list[nodes.Node] = [begin] + seen: set[nodes.Node] = set() + found_end: bool = False + + while len(to_visit) > 0: + n: nodes.Node = to_visit.pop() + if n == end: + found_end = True + continue + elif n in seen: + continue + seen.add(n) + to_visit.extend(next_nodes(n)) + + if not found_end: + return None + + seen.discard(begin) + return seen + + def find_downstream_consumers( + self, + state: SDFGState, + begin: nodes.Node | nodes.MultiConnectorEdge, + only_tasklets: bool = False, + reverse: bool = False, + ) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: + """Find all downstream connectors of `begin`. + + A consumer, in this sense, is any node that is neither an entry nor an exit + node. The function returns a set storing the pairs, the first element is the + node that acts as consumer and the second is the edge that leads to it. + By setting `only_tasklets` the nodes the function finds are only Tasklets. + + To find this set the function starts a search at `begin`, however, it is also + possible to pass an edge as `begin`. + If `reverse` is `True` the function essentially finds the producers that are + upstream. + + Args: + state: The state in which to look for the consumers. + begin: The initial node that from which the search starts. + only_tasklets: Return only Tasklets. + reverse: Follow the reverse direction. + """ + if isinstance(begin, nodes.MultiConnectorEdge): + to_visit: list[nodes.MultiConnectorEdge] = [begin] + elif reverse: + to_visit = list(state.in_edges(begin)) + else: + to_visit = list(state.out_edges(begin)) + seen: set[nodes.MultiConnectorEdge] = set() + found: set[tuple[nodes.Node, nodes.MultiConnectorEdge]] = set() + + while len(to_visit) != 0: + curr_edge: nodes.MultiConnectorEdge = to_visit.pop() + next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst + + if curr_edge in seen: + continue + seen.add(curr_edge) + + if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): + if reverse: + target_conn = curr_edge.src_conn[4:] + new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) + else: + # In forward mode a Map entry could also mean the definition of a + # dynamic map range. + if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( + next_node, nodes.MapEntry + ): + # This edge defines a dynamic map range, which is a consumer + if not only_tasklets: + found.add((next_node, curr_edge)) + continue + target_conn = curr_edge.dst_conn[3:] + new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) + to_visit.extend(new_edges) + else: + if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): + continue + found.add((next_node, curr_edge)) + + return found + + def find_upstream_producers( + self, + state: SDFGState, + begin: nodes.Node | nodes.MultiConnectorEdge, + only_tasklets: bool = False, + ) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: + """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" + return self.find_downstream_consumers( + state=state, + begin=begin, + only_tasklets=only_tasklets, + reverse=True, + ) From 3178b71d4af9ddc56a203adf6aa45a5dd4527618 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 10 Jul 2024 08:41:38 +0200 Subject: [PATCH 126/235] Reworked some parts. --- .../transformations/map_fusion_helper.py | 167 ++++++++++++------ .../dace_fieldview/transformations/util.py | 31 ++++ 2 files changed, 146 insertions(+), 52 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index e9fb7e077c..0b7b9f6760 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -14,8 +14,10 @@ """Implements Helper functionaliyies for map fusion""" +import functools from typing import Any, Iterable, Optional, Sequence, Union +import dace from dace import subsets from dace.sdfg import ( SDFG, @@ -27,6 +29,8 @@ ) from dace.transformation import helpers +from . import util + @properties.make_properties class MapFusionHelper(dace_transformation.SingleStateTransformation): @@ -78,6 +82,54 @@ def __init__( def expressions(cls) -> bool: raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + def can_be_applied( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + graph: Union[SDFGState, SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Performs some checks if the maps can be fused. + + This function does not follow the standard interface of DaCe transformations. + Instead it checks if the two maps can be fused, by comparing: + - The scope of the maps. + - The scheduling of the maps. + - The map parameters. + + However, for performance reasons, the function does not compute the node + partition. + """ + + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + + # ensure that both have the same schedule + if map_entry_1.map.schedule != map_entry_2.map.schedule: + return False + + # Fusing is only possible if our two entries are in the same scope. + scope = graph.scope_dict() + if scope[map_entry_1] != scope[map_entry_2]: + return False + elif self.only_inner_maps: + if scope[map_entry_1] is None: + return False + elif self.only_toplevel_maps: + if scope[map_entry_1] is not None: + return False + elif util.is_nested_sdfg(sdfg): + return False + + # We will now check if there exists a remapping that we can use. + if not self.map_parameter_compatible( + map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg + ): + return False + + return True + def relocate_nodes( self, from_node: Union[nodes.MapExit, nodes.MapEntry], @@ -304,18 +356,19 @@ def partition_first_outputs( processed_inter_nodes.add(intermediate_node) # Empty Memlets are currently not supported. + # However, they are much more important in entry nodes. if out_edge.data.is_empty(): return None # Now let's look at all nodes that are downstream of the intermediate node. - # This, among other thing will tell us, how we have to handle this node. + # This, among other things, will tell us, how we have to handle this node. downstream_nodes = self.all_nodes_between( graph=state, begin=intermediate_node, end=map_entry_2, ) - # If `downstream_nodes` is `None` it means that `map_entry_2` was never + # If `downstream_nodes` is `None` this means that `map_entry_2` was never # reached, thus `intermediate_node` does not enter the second map and # the node is a pure output node. if downstream_nodes is None: @@ -323,19 +376,20 @@ def partition_first_outputs( continue # - # The following tests, before we start handle intermediate nodes, are - # _after_ the pure node test for a reason, because this allows us to - # handle more exotic cases for these nodes. + # The following tests are _after_ we have determined if we have a pure + # output node for a reason, as this allows us to handle more exotic + # pure node cases, as handling them is essentially rerouting an edge. # In case the intermediate has more than one entry, all must come from the - # first map, otherwise we can not fuse them. + # first map, otherwise we can not fuse them. Currently we restrict this + # even further by saying that it has only one incoming Memlet. if state.in_degree(intermediate_node) != 1: - # TODO(phimuell): In some cases it can be possible to fuse such - # nodes, but we will not handle them for the time being. + # TODO(phimuell): handle this case. return None - # It happens can be that multiple edges at the `IN_` connector of the - # first exit map converges, but there is only one edge leaving the exit. + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. # TODO(phimuell): Handle this case properly. inner_collector_edges = state.in_edges_by_connector( intermediate_node, "IN_" + out_edge.src_conn[3:] @@ -344,7 +398,7 @@ def partition_first_outputs( return None # For us an intermediate node must always be an access node, pointing to a - # transient array, since it is the only thing that we know how to handle. + # transient value, since it is the only thing that we know how to handle. if not isinstance(intermediate_node, nodes.AccessNode): return None intermediate_desc: data.Data = intermediate_node.desc(sdfg) @@ -353,19 +407,24 @@ def partition_first_outputs( if isinstance(intermediate_desc, data.View): return None - # There are two restrictions we have on the intermediate output sets. - # First, we do not allow that they are involved in WCR (as they are - # currently not handled by the implementation) and second, that the - # "producer" generate only one element, this is actual crucial, as we - # assume that we can freely recreate them, a simples example consider - # that a Tasklet outputs "rows" then we can not handle the rest in - # columns. For that reason we check the generating Memlets. + # There are some restrictions we have on intermediate nodes. The first one + # is that we do not allow WCR, this is because they need special handling + # which is currently not implement (the DaCe transformation has this + # restriction as well). The second one is that we can reduce the + # intermediate node and only feed a part into the second map, consider + # the case `b = a + 1; return b + 2`, where we have arrays. In this + # example only a single element must be available to the second map. + # However, this is hard to check so we will make a simplification. + # First we will not check it at the producer, but at the consumer point. + # There we assume if the consumer does _not consume the whole_ + # intermediate array, then we can decompose the intermediate, by setting + # the map iteration index to zero and recover the shape, see + # implementation in the actual fusion routine. + # This is an assumption that is in most cases correct, but not always. + # However, doing it correctly is extremely complex. for _, produce_edge in self.find_upstream_producers(state, out_edge): if produce_edge.data.wcr is not None: return None - if produce_edge.data.num_elements() != 1: - return None - # TODO(phimuell): Check that the producing is only pointwise. if len(downstream_nodes) == 0: # There is nothing between intermediate node and the entry of the @@ -378,53 +437,57 @@ def partition_first_outputs( if state.out_degree(intermediate_node) != 1: return None - # There are certain nodes, for example Tasklets, that needs the whole - # array as input. Thus it can not be removed, because the node might - # need the whole array. - # TODO(phimuell): This is true for JaCe but also for GT4Py? - for _, feed_edge in self.find_downstream_consumers( - state=state, begin=intermediate_node - ): - if feed_edge.data.num_elements() != 1: + # Certain nodes need more than one element as input. As explained + # above, in this situation we assume that we can naturally decompose + # them iff the node does not consume that whole intermediate. + # Furthermore, it can not be a dynamic map range. + intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) + consumers = self.find_downstream_consumers(state=state, begin=intermediate_node) + for consumer_node, feed_edge in consumers: + # TODO(phimuell): Improve this approximation. + if feed_edge.data.num_elements() == intermediate_size: + return None + if consumer_node is map_entry_2: # Dynamic map range. return None - # TODO(phimuell): Check that the consuming is only pointwise. + # Note that "remove" has a special meaning here, regardless of the + # output of the check function, from within the second map we remove + # the intermediate, it has more the meaning of "do we need to + # reconstruct it after the second map again?". if self.can_transient_be_removed(intermediate_node, sdfg): - # The transient can be removed, thus it is exclusive. exclusive_outputs.add(out_edge) else: - # The transient can not be removed, to it must be shared. shared_outputs.add(out_edge) continue else: - # These is no single connection from the intermediate node to the - # second map, but many. For now we will only handle a very special - # case that makes the node to a shared intermediate node: - # All output connections of the intermediate node either lead: - # - directly to the second map entry node and does not define a - # dynamic map range, and can actually be removed. - # - have no connection to the second map entry at all. + # These is not only a single connection from the intermediate node to + # the second map, but the intermediate has more connection, thus + # the node might belong to the shared outputs. Of the many different + # possibilities, we only consider a single case: + # - The intermediate has a single connection to the second map, that + # fulfills the restriction outlined above. + # - All other connections have no connection to the second map. + found_second_entry = False + intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) for edge in state.out_edges(intermediate_node): if edge.dst is map_entry_2: - # The edge immediately leads to the second map. - for consumer_node, feed_edge in self.find_downstream_consumers( - state=state, begin=edge - ): - # Consumer needs the whole array. - if feed_edge.data.num_elements() != 1: + if found_second_entry: # The second map was found again. + return None + found_second_entry = True + consumers = self.find_downstream_consumers(state=state, begin=edge) + for consumer_node, feed_edge in consumers: + if feed_edge.data.num_elements() == intermediate_size: return None - # Defines a dynamic map range - if consumer_node is map_entry_2: + if consumer_node is map_entry_2: # Dynamic map range return None else: # Ensure that there is no path that leads to the second map. - if ( - self.all_nodes_between(graph=state, begin=edge.dst, end=map_entry_2) - is not None - ): + after_intermdiate_node = self.all_nodes_between( + graph=state, begin=edge.dst, end=map_entry_2 + ) + if after_intermdiate_node is not None: return None - # If we are here, then we know that the node is a shared output shared_outputs.add(out_edge) continue diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py new file mode 100644 index 0000000000..b549e2e76e --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -0,0 +1,31 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Common functionality for the transformations.""" + +import dace + + +def is_nested_sdfg(sdfg: dace.SDFG) -> bool: + """Tests if `sdfg` is a neseted sdfg.""" + if isinstance(sdfg, dace.SDFGState): + sdfg = sdfg.parent + if isinstance(sdfg, dace.nodes.NestedSDFG): + return True + elif isinstance(sdfg, dace.SDFG): + if sdfg.parent_nsdfg_node is not None: + return True + return False + else: + raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") From 66c5fcdc095276c5b219f7fcb8a035bb3ff532d4 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 10 Jul 2024 11:46:38 +0200 Subject: [PATCH 127/235] Address review comments --- src/gt4py/next/iterator/ir.py | 2 +- .../gtir_builtin_translators.py | 260 +++++++++++------- .../dace_fieldview/gtir_dace_backend.py | 23 +- .../dace_fieldview/gtir_python_codegen.py | 126 +++++++++ .../runners/dace_fieldview/gtir_to_sdfg.py | 70 +++-- .../runners/dace_fieldview/gtir_to_tasklet.py | 131 ++------- .../runners/dace_fieldview/sdfg_builder.py | 29 -- .../runners/dace_fieldview/utility.py | 40 +-- .../runners_tests/test_dace_fieldview.py | 142 +++++----- 9 files changed, 456 insertions(+), 367 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index d62c65f6f9..623106f303 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -194,7 +194,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib GTIR_BUILTINS = { *BUILTINS, "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) - "select", # `select(cond, field_a, field_b)` creates the field on one branch or the other + "cond", # `cond(expr, field_a, field_b)` creates the field on one branch or the other } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69ae5f33d6..1a95fb3d01 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -14,110 +14,193 @@ import abc -from dataclasses import dataclass -from typing import Optional, TypeAlias +from typing import Any, Final, Optional, Protocol, TypeAlias import dace import dace.subsets as sbs -from gt4py.next.common import Dimension -from gt4py.next.iterator import ir as itir +from gt4py.eve import concepts +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_python_codegen, gtir_to_tasklet, utility as dace_fieldview_util, ) -from gt4py.next.program_processors.runners.dace_fieldview.sdfg_builder import SDFGBuilder from gt4py.next.type_system import type_specifications as ts # Define aliases for return types -SDFGField: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] +TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] -DIMENSION_INDEX_FMT = "i_{dim}" -ITERATOR_INDEX_DTYPE = dace.int32 # type of iterator indexes +IteratorIndexFmt: Final[str] = "i_{dim}" +IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes -@dataclass(frozen=True) -class PrimitiveTranslator(abc.ABC): +class SDFGBuilder(Protocol): + """Visitor interface available to GTIR-primitive translators.""" + + @abc.abstractmethod + def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: + pass + + @abc.abstractmethod + def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: + pass + + @abc.abstractmethod + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + pass + + +class PrimitiveTranslator(Protocol): @abc.abstractmethod def __call__( self, - node: itir.Node, + node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: SDFGBuilder, - ) -> list[SDFGField]: # keep only call, same interface for all primitives - """Creates the dataflow subgraph representing a GTIR builtin function. + ) -> list[TemporaryData]: + """Creates the dataflow subgraph representing a GTIR primitive function. This method is used by derived classes to build a specialized subgraph - for a specific builtin function. - - Returns a list of SDFG nodes and the associated GT4Py data type. - - The GT4Py data type is useful in the case of fields, because it provides - information on the field domain (e.g. order of dimensions, types of dimensions). + for a specific GTIR primitive function. + + Arguments: + node: The GTIR node describing the primitive to be lowered + sdfg: The SDFG where the primitive subgraph should be instantiated + state: The SDFG state where the result of the primitive function should be made available + sdfg_builder: The object responsible for visiting child nodes of the primitive node. + reduce_identity: The value of the reduction identity, in case the primitive node + is visited in the context of a reduction expression. This value is used + by the `neighbors` primitive to provide the value of skip neighbors. + + Returns: + A list of data access nodes and the associated GT4Py data type, which provide + access to the result of the primitive subgraph. The GT4Py data type is useful + in the case the returned data is an array, because the type provdes the domain + information (e.g. order of dimensions, dimension types). """ class AsFieldOp(PrimitiveTranslator): """Generates the dataflow subgraph for the `as_field_op` builtin function.""" - callable_args: list[itir.Expr] - - def __init__(self, node_args: list[itir.Expr]): - self.callable_args = node_args - - def __call__( - self, - node: itir.Node, + @classmethod + def _parse_node_args( + cls, + args: list[gtir.Expr], sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: SDFGBuilder, - ) -> list[SDFGField]: - assert cpm.is_call_to(node, "as_fieldop") - assert len(node.args) == 2 - stencil_expr, domain_expr = node.args - # expect stencil (represented as a lambda function) as first argument - assert isinstance(stencil_expr, itir.Lambda) - # the domain of the field operator is passed as second argument - assert isinstance(domain_expr, itir.FunCall) - - # add local storage to compute the field operator over the given domain - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - field_domain = dace_fieldview_util.get_domain(domain_expr) - - # first visit the list of arguments and build a symbol map + domain: list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] + ], + ) -> list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr]: stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in self.callable_args: - fields: list[SDFGField] = sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) + for arg in args: + fields: list[TemporaryData] = sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) assert len(fields) == 1 data_node, arg_type = fields[0] # require all argument nodes to be data access nodes (no symbols) assert isinstance(data_node, dace.nodes.AccessNode) + arg_definition: gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr if isinstance(arg_type, ts.ScalarType): - scalar_arg = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) - stencil_args.append(scalar_arg) + arg_definition = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) else: assert isinstance(arg_type, ts.FieldType) - indices: dict[Dimension, gtir_to_tasklet.IteratorIndexExpr] = { + indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = { dim: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(DIMENSION_INDEX_FMT.format(dim=dim.value)), - ITERATOR_INDEX_DTYPE, + dace.symbolic.SymExpr(IteratorIndexFmt.format(dim=dim.value)), + IteratorIndexDType, ) - for dim, _, _ in field_domain + for dim, _, _ in domain } - iterator_arg = gtir_to_tasklet.IteratorExpr( + arg_definition = gtir_to_tasklet.IteratorExpr( data_node, arg_type.dims, indices, ) - stencil_args.append(iterator_arg) + stencil_args.append(arg_definition) + + return stencil_args + + @classmethod + def _create_temporary_field( + cls, + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] + ], + node_type: ts.ScalarType, + output_desc: dace.data.Data, + output_field_type: ts.DataType, + ) -> tuple[dace.nodes.AccessNode, ts.FieldType]: + domain_dims, domain_lbs, domain_ubs = zip(*domain) + field_dims = list(domain_dims) + field_shape = [ + # diff between upper and lower bound + (ub - lb) + for lb, ub in zip(domain_lbs, domain_ubs) + ] + field_offset: Optional[list[dace.symbolic.SymbolicType]] = None + if any(domain_lbs): + field_offset = [-lb for lb in domain_lbs] + + if isinstance(output_desc, dace.data.Array): + # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) + assert isinstance(output_field_type, ts.FieldType) + # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype` + node_type = output_field_type.dtype + field_dims.extend(output_field_type.dims) + field_shape.extend(output_desc.shape) + else: + assert isinstance(output_desc, dace.data.Scalar) + assert isinstance(output_field_type, ts.ScalarType) + # TODO: enable `assert output_field_type == node_type`, remove variable `dtype` + node_type = output_field_type + + # allocate local temporary storage for the result field + temp_name, _ = sdfg.add_temp_transient( + field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset + ) + field_node = state.add_access(temp_name) + field_type = ts.FieldType(field_dims, node_type) + + return field_node, field_type + + def __call__( + self, + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: SDFGBuilder, + ) -> list[TemporaryData]: + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + + fun_node = node.fun + assert len(fun_node.args) == 2 + stencil_expr, domain_expr = fun_node.args + # expect stencil (represented as a lambda function) as first argument + assert isinstance(stencil_expr, gtir.Lambda) + # the domain of the field operator is passed as second argument + assert isinstance(domain_expr, gtir.FunCall) + + # add local storage to compute the field operator over the given domain + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + domain = dace_fieldview_util.get_domain(domain_expr) + + # first visit the list of arguments and build a symbol map + stencil_args = self._parse_node_args(node.args, sdfg, state, sdfg_builder, domain) # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.offset_provider) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) output_desc = output_expr.node.desc(sdfg) @@ -133,31 +216,12 @@ def __call__( last_node_connector = None # allocate local temporary storage for the result field - field_dims = [dim for dim, _, _ in field_domain] - field_shape = [ - # diff between upper and lower bound - (ub - lb) - for _, lb, ub in field_domain - ] - field_offset: Optional[list[dace.symbolic.SymbolicType]] = None - if any(lb != 0 for _, lb, _ in field_domain): - field_offset = [lb for _, lb, _ in field_domain] - if isinstance(output_desc, dace.data.Array): - raise NotImplementedError - else: - assert isinstance(output_expr.field_type, ts.ScalarType) - # TODO: enable `assert output_expr.field_type == node_type`, remove variable `dtype` - node_type = output_expr.field_type - - # TODO: use `field_type` directly, without passing through `dtype` - field_type = ts.FieldType(field_dims, node_type) - temp_name, _ = sdfg.add_temp_transient( - field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset + field_node, field_type = self._create_temporary_field( + sdfg, state, domain, node_type, output_desc, output_expr.field_type ) - field_node = state.add_access(temp_name) # assume tasklet with single output - output_subset = [DIMENSION_INDEX_FMT.format(dim=dim.value) for dim, _, _ in field_domain] + output_subset = [IteratorIndexFmt.format(dim=dim.value) for dim, _, _ in domain] if isinstance(output_desc, dace.data.Array): # additional local dimension for neighbors assert set(output_desc.offset) == {0} @@ -165,7 +229,7 @@ def __call__( # create map range corresponding to the field operator domain map_ranges = { - DIMENSION_INDEX_FMT.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in field_domain + IteratorIndexFmt.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in domain } me, mx = state.add_map("field_op", map_ranges) @@ -193,28 +257,32 @@ def __call__( return [(field_node, field_type)] -class Select(PrimitiveTranslator): - """Generates the dataflow subgraph for the `select` builtin function.""" +class Cond(PrimitiveTranslator): + """Generates the dataflow subgraph for the `cond` builtin function.""" def __call__( self, - node: itir.Node, + node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: SDFGBuilder, - ) -> list[SDFGField]: - assert cpm.is_call_to(node, "select") - assert len(node.args) == 3 - cond_expr, true_expr, false_expr = node.args + ) -> list[TemporaryData]: + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "cond") + assert len(node.args) == 0 + + fun_node = node.fun + assert len(fun_node.args) == 3 + cond_expr, true_expr, false_expr = fun_node.args # expect condition as first argument - cond = dace_fieldview_util.get_symbolic_expr(cond_expr) + cond = gtir_python_codegen.get_source(cond_expr) # use current head state to terminate the dataflow, and add a entry state # to connect the true/false branch states as follows: # # ------------ - # === | select | === + # === | cond | === # || ------------ || # \/ \/ # ------------ ------------- @@ -225,19 +293,17 @@ def __call__( # ==> | head | <== # ------------ # - select_state = sdfg.add_state_before(state, state.label + "_select") - sdfg.remove_edge(sdfg.out_edges(select_state)[0]) + cond_state = sdfg.add_state_before(state, state.label + "_cond") + sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) # expect true branch as second argument true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) + sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) sdfg.add_edge(true_state, state, dace.InterstateEdge()) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge( - select_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")) - ) + sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})"))) sdfg.add_edge(false_state, state, dace.InterstateEdge()) true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) @@ -276,22 +342,22 @@ class SymbolRef(PrimitiveTranslator): def __call__( self, - node: itir.Node, + node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: SDFGBuilder, - ) -> list[SDFGField]: - assert isinstance(node, (itir.Literal, itir.SymRef)) + ) -> list[TemporaryData]: + assert isinstance(node, (gtir.Literal, gtir.SymRef)) data_type: ts.FieldType | ts.ScalarType - if isinstance(node, itir.Literal): + if isinstance(node, gtir.Literal): sym_value = node.value data_type = node.type tasklet_name = "get_literal" else: sym_value = str(node.id) - assert sym_value in sdfg_builder.symbol_types - data_type = sdfg_builder.symbol_types[sym_value] + assert sym_value in sdfg_builder.get_symbol_types() + data_type = sdfg_builder.get_symbol_types()[sym_value] tasklet_name = f"get_{sym_value}" if isinstance(data_type, ts.FieldType): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py index bdf6d401c1..3d99225d81 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py @@ -14,22 +14,33 @@ import dace -from gt4py.next.common import Connectivity, Dimension -from gt4py.next.iterator import ir as itir +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_to_sdfg as gtir_dace_translator, ) def build_sdfg_from_gtir( - program: itir.Program, - offset_provider: dict[str, Connectivity | Dimension], + program: gtir.Program, + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], ) -> dace.SDFG: """ - TODO: enable type inference - program = itir_type_inference.infer(program, offset_provider=offset_provider) + Receives a GTIR program and lowers it to a DaCe SDFG. + + The lowering to SDFG requires that the program node is type-annotated, therefore this function + runs type ineference as first step. + As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form. + + Arguments: + program: The GTIR program node to be lowered to SDFG + offset_provider: The definitions of offset providers used by the program node + + Returns: + An SDFG in the DaCe canonical form (simplified) """ sdfg_genenerator = gtir_dace_translator.GTIRToSDFG(offset_provider) + # TODO: run type inference on the `program` node before passing it to `GTIRToSDFG` sdfg = sdfg_genenerator.visit(program) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py new file mode 100644 index 0000000000..fe98d8a98a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -0,0 +1,126 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any + +import numpy as np + +from gt4py.eve import codegen +from gt4py.eve.codegen import FormatTemplate as as_fmt +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm + + +MATH_BUILTINS_MAPPING = { + "abs": "abs({})", + "sin": "math.sin({})", + "cos": "math.cos({})", + "tan": "math.tan({})", + "arcsin": "asin({})", + "arccos": "acos({})", + "arctan": "atan({})", + "sinh": "math.sinh({})", + "cosh": "math.cosh({})", + "tanh": "math.tanh({})", + "arcsinh": "asinh({})", + "arccosh": "acosh({})", + "arctanh": "atanh({})", + "sqrt": "math.sqrt({})", + "exp": "math.exp({})", + "log": "math.log({})", + "gamma": "tgamma({})", + "cbrt": "cbrt({})", + "isfinite": "isfinite({})", + "isinf": "isinf({})", + "isnan": "isnan({})", + "floor": "math.ifloor({})", + "ceil": "ceil({})", + "trunc": "trunc({})", + "minimum": "min({}, {})", + "maximum": "max({}, {})", + "fmod": "fmod({}, {})", + "power": "math.pow({}, {})", + "float": "dace.float64({})", + "float32": "dace.float32({})", + "float64": "dace.float64({})", + "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", + "int32": "dace.int32({})", + "int64": "dace.int64({})", + "bool": "dace.bool_({})", + "plus": "({} + {})", + "minus": "({} - {})", + "multiplies": "({} * {})", + "divides": "({} / {})", + "floordiv": "({} // {})", + "eq": "({} == {})", + "not_eq": "({} != {})", + "less": "({} < {})", + "less_equal": "({} <= {})", + "greater": "({} > {})", + "greater_equal": "({} >= {})", + "and_": "({} & {})", + "or_": "({} | {})", + "xor_": "({} ^ {})", + "mod": "({} % {})", + "not_": "(not {})", # ~ is not bitwise in numpy +} + + +def format_builtin(bultin: str, *args: Any) -> str: + if bultin in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[bultin] + else: + raise NotImplementedError(f"'{bultin}' not implemented.") + return fmt.format(*args) + + +class PythonCodegen(codegen.TemplatedGenerator): + """Helper class to visit a symbolic expression and translate it to Python code. + + The generated Python code can be use either as the body of a tasklet node or, + as in the case of field domain definitions, for sybolic array shape and map range. + """ + + SymRef = as_fmt("{id}") + Literal = as_fmt("{value}") + + def _visit_deref(self, node: gtir.FunCall) -> str: + assert len(node.args) == 1 + if isinstance(node.args[0], gtir.SymRef): + return self.visit(node.args[0]) + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + + def _visit_numeric_builtin(self, node: gtir.FunCall) -> str: + assert isinstance(node.fun, gtir.SymRef) + fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] + args = self.visit(node.args) + return fmt.format(*args) + + def visit_FunCall(self, node: gtir.FunCall) -> str: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + elif isinstance(node.fun, gtir.SymRef): + args = self.visit(node.args) + builtin_name = str(node.fun.id) + return format_builtin(builtin_name, *args) + raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + + +""" +Specialized visit method for symbolic expressions. + +Returns: + A string containing the Python code corresponding to a symbolic expression +""" +get_source = PythonCodegen.apply diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 7cd5c43f8a..f94bbc6953 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -21,18 +21,18 @@ import dace -from gt4py.next.common import Connectivity, Dimension, DimensionKind -from gt4py.next.iterator import ir as itir +from gt4py import eve +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_builtin_translators, utility as dace_fieldview_util, ) -from gt4py.next.program_processors.runners.dace_fieldview.sdfg_builder import SDFGBuilder from gt4py.next.type_system import type_specifications as ts -class GTIRToSDFG(SDFGBuilder): +class GTIRToSDFG(eve.NodeVisitor, gtir_builtin_translators.SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation @@ -45,14 +45,24 @@ class GTIRToSDFG(SDFGBuilder): from where to continue building the SDFG. """ + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] + symbol_types: dict[str, ts.FieldType | ts.ScalarType] + def __init__( self, - offset_provider: dict[str, Connectivity | Dimension], + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], ): - super().__init__(offset_provider, symbol_types={}) + self.offset_provider = offset_provider + self.symbol_types = {} + + def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: + return self.offset_provider + + def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: + return self.symbol_types def _make_array_shape_and_strides( - self, name: str, dims: Sequence[Dimension] + self, name: str, dims: Sequence[gtx_common.Dimension] ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -68,7 +78,7 @@ def _make_array_shape_and_strides( shape = [ ( neighbor_tables[dim.value].max_neighbors - if dim.kind == DimensionKind.LOCAL + if dim.kind == gtx_common.DimensionKind.LOCAL # we reuse the same symbol for field size passed as scalar argument to the gt4py program else dace.symbol(f"__{name}_size_{i}", dtype) ) @@ -100,7 +110,7 @@ def _add_storage(self, sdfg: dace.SDFG, name: str, symbol_type: ts.DataType) -> assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) self.symbol_types[name] = symbol_type - def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> dict[str, str]: + def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: """ Add temporary storage (aka transient) for data containers used as GTIR temporaries. @@ -109,7 +119,7 @@ def _add_storage_for_temporary(self, temp_decl: itir.Temporary) -> dict[str, str raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") def _visit_expression( - self, node: itir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState + self, node: gtir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState ) -> list[dace.nodes.AccessNode]: """ Specialized visit method for fieldview expressions. @@ -123,7 +133,7 @@ def _visit_expression( in case the transient arrays containing the expression result are not guaranteed to have the same memory layout as the target array. """ - results: list[gtir_builtin_translators.SDFGField] = self.visit( + results: list[gtir_builtin_translators.TemporaryData] = self.visit( node, sdfg=sdfg, head_state=head_state ) @@ -140,7 +150,7 @@ def _visit_expression( return field_nodes - def visit_Program(self, node: itir.Program) -> dace.SDFG: + def visit_Program(self, node: gtir.Program) -> dace.SDFG: """Translates `ir.Program` to `dace.SDFG`. First, it will allocate field and scalar storage for global data. The storage @@ -153,7 +163,7 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: raise NotImplementedError("Functions expected to be inlined as lambda calls.") sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_fieldview_util.debuginfo(node, sdfg.debuginfo) + sdfg.debuginfo = dace_fieldview_util.debug_info(node, default=sdfg.debuginfo) entry_state = sdfg.add_state("program_entry", is_start_block=True) # declarations of temporaries result in transient array definitions in the SDFG @@ -176,13 +186,13 @@ def visit_Program(self, node: itir.Program) -> dace.SDFG: for i, stmt in enumerate(node.body): # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy head_state = sdfg.add_state_after(head_state, f"stmt_{i}") - head_state._debuginfo = dace_fieldview_util.debuginfo(stmt, sdfg.debuginfo) + head_state._debuginfo = dace_fieldview_util.debug_info(stmt, default=sdfg.debuginfo) self.visit(stmt, sdfg=sdfg, state=head_state) sdfg.validate() return sdfg - def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + def visit_SetAt(self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) -> None: """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. @@ -218,18 +228,20 @@ def visit_SetAt(self, stmt: itir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) ) def visit_FunCall( - self, node: itir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> list[gtir_builtin_translators.SDFGField]: + self, + node: gtir.FunCall, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> list[gtir_builtin_translators.TemporaryData]: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtin_translators.AsFieldOp(node.args)(node.fun, sdfg, head_state, self) - elif cpm.is_call_to(node.fun, "select"): - assert len(node.args) == 0 - return gtir_builtin_translators.Select()(node.fun, sdfg, head_state, self) + return gtir_builtin_translators.AsFieldOp()(node, sdfg, head_state, self) + elif cpm.is_call_to(node.fun, "cond"): + return gtir_builtin_translators.Cond()(node, sdfg, head_state, self) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - def visit_Lambda(self, node: itir.Lambda) -> Any: + def visit_Lambda(self, node: gtir.Lambda) -> Any: """ This visitor class should never encounter `itir.Lambda` expressions because a lambda represents a stencil, which operates from iterator to values. @@ -238,11 +250,17 @@ def visit_Lambda(self, node: itir.Lambda) -> Any: raise RuntimeError("Unexpected 'itir.Lambda' node encountered in GTIR.") def visit_Literal( - self, node: itir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> list[gtir_builtin_translators.SDFGField]: + self, + node: gtir.Literal, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> list[gtir_builtin_translators.TemporaryData]: return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) def visit_SymRef( - self, node: itir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState - ) -> list[gtir_builtin_translators.SDFGField]: + self, + node: gtir.SymRef, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> list[gtir_builtin_translators.TemporaryData]: return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 524553d174..0010bbb71e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -18,15 +18,15 @@ import dace import dace.subsets as sbs -import numpy as np from gt4py import eve -from gt4py.eve import codegen -from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.common import Connectivity, Dimension -from gt4py.next.iterator import ir as itir +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_python_codegen, + utility as dace_fieldview_util, +) from gt4py.next.type_system import type_specifications as ts @@ -70,68 +70,13 @@ class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" field: dace.nodes.AccessNode - dimensions: list[Dimension] - indices: dict[Dimension, IteratorIndexExpr] + dimensions: list[gtx_common.Dimension] + indices: dict[gtx_common.Dimension, IteratorIndexExpr] INDEX_CONNECTOR_FMT = "__index_{dim}" -MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - class LambdaToTasklet(eve.NodeVisitor): """Translates an `ir.Lambda` expression to a dataflow graph. @@ -142,7 +87,7 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - offset_provider: dict[str, Connectivity | Dimension] + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] input_connections: list[InputConnection] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -150,7 +95,7 @@ def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - offset_provider: dict[str, Connectivity | Dimension], + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], ): self.sdfg = sdfg self.state = state @@ -186,7 +131,7 @@ def _get_tasklet_result( ) return ValueExpr(temp_node, data_type) - def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: + def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr: assert len(node.args) == 1 it = self.visit(node.args[0]) @@ -205,19 +150,12 @@ def _visit_deref(self, node: itir.FunCall) -> MemletExpr | ValueExpr: assert isinstance(it, MemletExpr) return it - def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr: + def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) else: - assert isinstance(node.fun, itir.SymRef) - - # create a tasklet node implementing the builtin function - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[builtin_name] - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") + assert isinstance(node.fun, gtir.SymRef) node_internals = [] node_connections: dict[str, MemletExpr | ValueExpr] = {} @@ -234,7 +172,8 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value node_internals.append(arg_expr.value) # use tasklet connectors as expression arguments - code = fmt.format(*node_internals) + builtin_name = str(node.fun.id) + code = gtir_python_codegen.format_builtin(builtin_name, *node_internals) out_connector = "result" tasklet_node = self.state.add_tasklet( @@ -277,7 +216,7 @@ def visit_FunCall(self, node: itir.FunCall) -> IteratorExpr | MemletExpr | Value return self._get_tasklet_result(dtype, tasklet_node, "result") def visit_Lambda( - self, node: itir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] + self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] ) -> tuple[list[InputConnection], ValueExpr]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg @@ -303,45 +242,11 @@ def visit_Lambda( ) return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") - def visit_Literal(self, node: itir.Literal) -> SymbolExpr: + def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dtype = dace_fieldview_util.as_dace_type(node.type) return SymbolExpr(node.value, dtype) - def visit_SymRef(self, node: itir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: + def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: param = str(node.id) assert param in self.symbol_map return self.symbol_map[param] - - -class PythonCodegen(codegen.TemplatedGenerator): - """Helper class to visit a symbolic expression and translate it to Python code. - - The generated Python code can be use either as the body of a tasklet node or, - as in the case of field domain definitions, for sybolic array shape and map range. - """ - - SymRef = as_fmt("{id}") - Literal = as_fmt("{value}") - - def _visit_deref(self, node: itir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], itir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_numeric_builtin(self, node: itir.FunCall) -> str: - assert isinstance(node.fun, itir.SymRef) - fmt = MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = self.visit(node.args) - return fmt.format(*args) - - def visit_FunCall(self, node: itir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - elif isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py b/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py deleted file mode 100644 index dafddbcd1a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/sdfg_builder.py +++ /dev/null @@ -1,29 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later -""" -Visitor interface to build an SDFG dataflow. - -""" - -from dataclasses import dataclass - -from gt4py import eve -from gt4py.next.common import Connectivity, Dimension -from gt4py.next.type_system import type_specifications as ts - - -@dataclass(frozen=True) -class SDFGBuilder(eve.NodeVisitor): - offset_provider: dict[str, Connectivity | Dimension] - symbol_types: dict[str, ts.FieldType | ts.ScalarType] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 9a20d06487..e829feaf2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,10 +16,10 @@ import dace -from gt4py.next.common import Connectivity, Dimension -from gt4py.next.iterator import ir as itir +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet +from gt4py.next.program_processors.runners.dace_fieldview import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts @@ -49,8 +49,8 @@ def as_scalar_type(typestr: str) -> ts.ScalarType: return ts.ScalarType(kind) -def debuginfo( - node: itir.Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None +def debug_info( + node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None ) -> Optional[dace.dtypes.DebugInfo]: location = node.location if location: @@ -61,10 +61,10 @@ def debuginfo( end_column=location.end_column if location.end_column else 0, filename=location.filename, ) - return debuginfo + return default -def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: +def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, gtx_common.Connectivity]: """ Filter offset providers of type `Connectivity`. @@ -74,13 +74,13 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne return { offset: table for offset, table in offset_provider.items() - if isinstance(table, Connectivity) + if isinstance(table, gtx_common.Connectivity) } def get_domain( - node: itir.Expr, -) -> list[tuple[Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: + node: gtir.Expr, +) -> list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Specialized visit method for domain expressions. @@ -93,34 +93,24 @@ def get_domain( assert cpm.is_call_to(named_range, "named_range") assert len(named_range.args) == 3 axis = named_range.args[0] - assert isinstance(axis, itir.AxisLiteral) + assert isinstance(axis, gtir.AxisLiteral) bounds = [] for arg in named_range.args[1:3]: - sym_str = get_symbolic_expr(arg) + sym_str = gtir_python_codegen.get_source(arg) sym_val = dace.symbolic.SymExpr(sym_str) bounds.append(sym_val) - dim = Dimension(axis.value, axis.kind) + dim = gtx_common.Dimension(axis.value, axis.kind) domain.append((dim, bounds[0], bounds[1])) return domain def get_domain_ranges( - node: itir.Expr, -) -> dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: + node: gtir.Expr, +) -> dict[gtx_common.Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: """ Returns domain represented in dictionary form. """ domain = get_domain(node) return {dim: (lb, ub) for dim, lb, ub in domain} - - -def get_symbolic_expr(node: itir.Expr) -> str: - """ - Specialized visit method for symbolic expressions. - - Returns a string containong the corresponding Python code, which as tasklet body - or symbolic array shape. - """ - return gtir_to_tasklet.PythonCodegen().visit(node) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 03f71e08fb..d674e71f0b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -17,7 +17,9 @@ Note: this test module covers the fieldview flavour of ITIR. """ -from gt4py.next.iterator import ir as itir +import copy +from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.runners import dace_fieldview as dace_backend from gt4py.next.type_system import type_specifications as ts @@ -47,19 +49,19 @@ def test_gtir_copy(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( + testee = gtir.Program( id="gtir_copy", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="y", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=im.call( im.call("as_fieldop")( im.lambda_("a")(im.deref("a")), @@ -67,7 +69,7 @@ def test_gtir_copy(): ) )("x"), domain=domain, - target=itir.SymRef(id="y"), + target=gtir.SymRef(id="y"), ) ], ) @@ -83,7 +85,7 @@ def test_gtir_copy(): def test_gtir_update(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) stencil1 = im.call( im.call("as_fieldop")( @@ -99,19 +101,19 @@ def test_gtir_update(): )("x", 1.0) for i, stencil in enumerate([stencil1, stencil2]): - testee = itir.Program( + testee = gtir.Program( id=f"gtir_update_{i}", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=stencil, domain=domain, - target=itir.SymRef(id="x"), + target=gtir.SymRef(id="x"), ) ], ) @@ -126,20 +128,20 @@ def test_gtir_update(): def test_gtir_sum2(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( + testee = gtir.Program( id="sum_2fields", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="y", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=im.call( im.call("as_fieldop")( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), @@ -147,7 +149,7 @@ def test_gtir_sum2(): ) )("x", "y"), domain=domain, - target=itir.SymRef(id="z"), + target=gtir.SymRef(id="z"), ) ], ) @@ -164,19 +166,19 @@ def test_gtir_sum2(): def test_gtir_sum2_sym(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( + testee = gtir.Program( id="sum_2fields_sym", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=im.call( im.call("as_fieldop")( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), @@ -184,7 +186,7 @@ def test_gtir_sum2_sym(): ) )("x", "x"), domain=domain, - target=itir.SymRef(id="z"), + target=gtir.SymRef(id="z"), ) ], ) @@ -200,7 +202,7 @@ def test_gtir_sum2_sym(): def test_gtir_sum3(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) stencil1 = im.call( im.call("as_fieldop")( @@ -230,22 +232,22 @@ def test_gtir_sum3(): c = np.random.rand(N) for i, stencil in enumerate([stencil1, stencil2]): - testee = itir.Program( + testee = gtir.Program( id=f"sum_3fields_{i}", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="y", type=IFTYPE), - itir.Sym(id="w", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="w", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=stencil, domain=domain, - target=itir.SymRef(id="z"), + target=gtir.SymRef(id="z"), ) ], ) @@ -258,25 +260,25 @@ def test_gtir_sum3(): assert np.allclose(d, (a + b + c)) -def test_gtir_select(): +def test_gtir_cond(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( - id="select_2sums", + testee = gtir.Program( + id="cond_2sums", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="y", type=IFTYPE), - itir.Sym(id="w", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="cond", type=ts.ScalarType(ts.ScalarKind.BOOL)), - itir.Sym(id="scalar", type=ts.ScalarType(ts.ScalarKind.FLOAT64)), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="w", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="scalar", type=ts.ScalarType(ts.ScalarKind.FLOAT64)), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=im.call( im.call("as_fieldop")( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), @@ -285,8 +287,8 @@ def test_gtir_select(): )( "x", im.call( - im.call("select")( - im.deref("cond"), + im.call("cond")( + im.deref("pred"), im.call( im.call("as_fieldop")( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), @@ -303,7 +305,7 @@ def test_gtir_select(): )(), ), domain=domain, - target=itir.SymRef(id="z"), + target=gtir.SymRef(id="z"), ) ], ) @@ -316,30 +318,30 @@ def test_gtir_select(): for s in [False, True]: d = np.empty_like(a) - sdfg(cond=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS) + sdfg(pred=np.bool_(s), scalar=1.0, x=a, y=b, w=c, z=d, **FSYMBOLS) assert np.allclose(d, (a + b + 1) if s else (a + c + 1)) -def test_gtir_select_nested(): +def test_gtir_cond_nested(): domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - testee = itir.Program( - id="select_nested", + testee = gtir.Program( + id="cond_nested", function_definitions=[], params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="cond_1", type=ts.ScalarType(ts.ScalarKind.BOOL)), - itir.Sym(id="cond_2", type=ts.ScalarType(ts.ScalarKind.BOOL)), - itir.Sym(id="size", type=SIZE_TYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="z", type=IFTYPE), + gtir.Sym(id="pred_1", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="pred_2", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ - itir.SetAt( + gtir.SetAt( expr=im.call( - im.call("select")( - im.deref("cond_1"), + im.call("cond")( + im.deref("pred_1"), im.call( im.call("as_fieldop")( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), @@ -347,8 +349,8 @@ def test_gtir_select_nested(): ) )("x", 1), im.call( - im.call("select")( - im.deref("cond_2"), + im.call("cond")( + im.deref("pred_2"), im.call( im.call("as_fieldop")( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), @@ -366,7 +368,7 @@ def test_gtir_select_nested(): ) )(), domain=domain, - target=itir.SymRef(id="z"), + target=gtir.SymRef(id="z"), ) ], ) @@ -378,5 +380,5 @@ def test_gtir_select_nested(): for s1 in [False, True]: for s2 in [False, True]: b = np.empty_like(a) - sdfg(cond_1=np.bool_(s1), cond_2=np.bool_(s2), x=a, z=b, **FSYMBOLS) + sdfg(pred_1=np.bool_(s1), pred_2=np.bool_(s2), x=a, z=b, **FSYMBOLS) assert np.allclose(b, (a + 1.0) if s1 else (a + 2.0) if s2 else (a + 3.0)) From 1df1bc3bbeec41b49366a9de838e7d4636305862 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 10 Jul 2024 12:02:00 +0200 Subject: [PATCH 128/235] Apply convention for map variables --- .../dace_fieldview/gtir_builtin_translators.py | 13 +++++-------- .../runners/dace_fieldview/gtir_to_tasklet.py | 3 --- .../runners/dace_fieldview/utility.py | 8 ++++++++ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 1a95fb3d01..db40786644 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -14,7 +14,7 @@ import abc -from typing import Any, Final, Optional, Protocol, TypeAlias +from typing import Any, Optional, Protocol, TypeAlias import dace import dace.subsets as sbs @@ -31,11 +31,8 @@ from gt4py.next.type_system import type_specifications as ts -# Define aliases for return types -TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] - -IteratorIndexFmt: Final[str] = "i_{dim}" IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes +TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] class SDFGBuilder(Protocol): @@ -114,7 +111,7 @@ def _parse_node_args( assert isinstance(arg_type, ts.FieldType) indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = { dim: gtir_to_tasklet.SymbolExpr( - dace.symbolic.SymExpr(IteratorIndexFmt.format(dim=dim.value)), + dace_fieldview_util.get_map_variable(dim), IteratorIndexDType, ) for dim, _, _ in domain @@ -221,7 +218,7 @@ def __call__( ) # assume tasklet with single output - output_subset = [IteratorIndexFmt.format(dim=dim.value) for dim, _, _ in domain] + output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] if isinstance(output_desc, dace.data.Array): # additional local dimension for neighbors assert set(output_desc.offset) == {0} @@ -229,7 +226,7 @@ def __call__( # create map range corresponding to the field operator domain map_ranges = { - IteratorIndexFmt.format(dim=dim.value): f"{lb}:{ub}" for dim, lb, ub in domain + dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain } me, mx = state.add_map("field_op", map_ranges) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 0010bbb71e..25a80892d0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -74,9 +74,6 @@ class IteratorExpr: indices: dict[gtx_common.Dimension, IteratorIndexExpr] -INDEX_CONNECTOR_FMT = "__index_{dim}" - - class LambdaToTasklet(eve.NodeVisitor): """Translates an `ir.Lambda` expression to a dataflow graph. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index e829feaf2d..7e9bec2545 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -114,3 +114,11 @@ def get_domain_ranges( domain = get_domain(node) return {dim: (lb, ub) for dim, lb, ub in domain} + + +def get_map_variable(dim: gtx_common.Dimension) -> str: + """ + Format map variable name based on the naming convention for application-specific SDFG transformations. + """ + suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" + return f"i_{dim.value}_gtx_{dim.kind}{suffix}" From 6394243548814eed29846f19794c4b19baddc01e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 10 Jul 2024 13:32:29 +0200 Subject: [PATCH 129/235] Updated and fixed a big in the `is_interstate_transient()` function. There was a bug before, as the set was computed wrongly, it sued the sink instead of the source nodes. However, I have updated and renamed the function, it now works differently, i.e. not if we can remove it, but if we have to keep it. Also after the discussion with Hannes and Edoardo I refined the definition of the set. --- .../transformations/map_fusion_helper.py | 231 +++++------------- .../dace_fieldview/transformations/util.py | 140 +++++++++++ 2 files changed, 195 insertions(+), 176 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 0b7b9f6760..65b92fc3b7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -15,7 +15,8 @@ """Implements Helper functionaliyies for map fusion""" import functools -from typing import Any, Iterable, Optional, Sequence, Union +import itertools +from typing import Any, Optional, Sequence, Union import dace from dace import subsets @@ -42,6 +43,10 @@ class MapFusionHelper(dace_transformation.SingleStateTransformation): After every transformation that manipulates the state machine, you shouls recreate the transformation. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. """ only_toplevel_maps = properties.Property( @@ -62,7 +67,7 @@ class MapFusionHelper(dace_transformation.SingleStateTransformation): default=None, allow_none=True, desc="Maps SDFGs to the set of array transients that can not be removed. " - "The variable acts as a cache, and is managed by 'can_transient_be_removed()'.", + "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", ) def __init__( @@ -82,24 +87,24 @@ def __init__( def expressions(cls) -> bool: raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") - def can_be_applied( + def can_be_fused( self, map_entry_1: nodes.MapEntry, map_entry_2: nodes.MapEntry, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, ) -> bool: """Performs some checks if the maps can be fused. - This function does not follow the standard interface of DaCe transformations. - Instead it checks if the two maps can be fused, by comparing: + Essentially, this function only checks constrains that does not depend if + a serial or a parallel map fusion happens. Thus it tests: - The scope of the maps. - The scheduling of the maps. - The map parameters. - However, for performance reasons, the function does not compute the node - partition. + However, for performance reasons, the function does not check if the node + decomposition exists. """ if self.only_inner_maps and self.only_toplevel_maps: @@ -249,49 +254,61 @@ def map_parameter_compatible( return True - def can_transient_be_removed( + def is_interstate_transient( self, transient: Union[str, nodes.AccessNode], - sdfg: SDFG, + sdfg: dace.SDFG, ) -> bool: - """Can `transient` be removed. + """Tests if `transient` is an interstate transient, an can not be removed. - Essentially the function tests if the transient `transient` is needed to - transmit information from one state to the other. The function will first - look consult `self.shared_transients`, if the SDFG is not known the function - will compute the set of transients that have to be kept alive. - - If `transient` refers to a scalar the function will return `False`, as - a scalar can not be removed. + Essentially this function checks if a transient is needed in a + different state in the SDFG, because it transmit information from + one state to the other. However, this function only checks if the + transient is needed for transmitting information between states. + It does _not_ check if the transient is needed multiple times within + the state. This case can be detected by checking the number of outgoing + edges. Args: transient: The transient that should be checked. sdfg: The SDFG containing the array. """ - if sdfg not in self.shared_transients: - # SDFG is not known, so we have to compute the set of all transients that - # have to be kept alive. This set is given by all transients that are - # source nodes; We currently ignore scalars. - shared_sdfg_transients: set[str] = set() + # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + # the set of such transients is partially given by all source access nodes. + # Because of rule 3 we also include all scalars in this set, as an over + # approximation. Furthermore, because simplify might violate rule 3, + # we also include the sink nodes. + + # See if we have already computed the set + if sdfg in self.shared_transients: + shared_sdfg_transients: set[str] = self.shared_transients[sdfg] + + else: + # SDFG is not known so we have to compute it. + # If a scalar is not a source node then it is not included in this set. + # Thus we do not have to look for it, instead we will check for them + # explicitly. + shared_sdfg_transients = set() for state in sdfg.states(): - for acnode in filter( - lambda node: isinstance(node, nodes.AccessNode), state.sink_nodes() - ): - desc = sdfg.arrays[acnode.data] - if desc.transient and isinstance(desc, data.Array): - shared_sdfg_transients.add(acnode.data) + shared_sdfg_transients.update( + filter( + lambda node: isinstance(node, nodes.AccessNode) + and sdfg.arrays[node.data].transient, + itertools.chain(state.source_nodes(), state.sink_nodes()), + ) + ) self.shared_transients[sdfg] = shared_sdfg_transients if isinstance(transient, nodes.AccessNode): transient = transient.data + desc: data.Data = sdfg.arrays[transient] - desc: data.Data = sdfg.arrays[transient] # type: ignore[no-redef] - if isinstance(desc, data.View): + if not desc.transient: return False if isinstance(desc, data.Scalar): return False - return transient not in self.shared_transients[sdfg] + return transient not in shared_sdfg_transients def partition_first_outputs( self, @@ -422,7 +439,7 @@ def partition_first_outputs( # implementation in the actual fusion routine. # This is an assumption that is in most cases correct, but not always. # However, doing it correctly is extremely complex. - for _, produce_edge in self.find_upstream_producers(state, out_edge): + for _, produce_edge in util.find_upstream_producers(state, out_edge): if produce_edge.data.wcr is not None: return None @@ -442,7 +459,7 @@ def partition_first_outputs( # them iff the node does not consume that whole intermediate. # Furthermore, it can not be a dynamic map range. intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - consumers = self.find_downstream_consumers(state=state, begin=intermediate_node) + consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) for consumer_node, feed_edge in consumers: # TODO(phimuell): Improve this approximation. if feed_edge.data.num_elements() == intermediate_size: @@ -454,10 +471,10 @@ def partition_first_outputs( # output of the check function, from within the second map we remove # the intermediate, it has more the meaning of "do we need to # reconstruct it after the second map again?". - if self.can_transient_be_removed(intermediate_node, sdfg): - exclusive_outputs.add(out_edge) - else: + if self.is_interstate_transient(intermediate_node, sdfg): shared_outputs.add(out_edge) + else: + exclusive_outputs.add(out_edge) continue else: @@ -475,7 +492,7 @@ def partition_first_outputs( if found_second_entry: # The second map was found again. return None found_second_entry = True - consumers = self.find_downstream_consumers(state=state, begin=edge) + consumers = util.find_downstream_consumers(state=state, begin=edge) for consumer_node, feed_edge in consumers: if feed_edge.data.num_elements() == intermediate_size: return None @@ -494,141 +511,3 @@ def partition_first_outputs( assert exclusive_outputs or shared_outputs or pure_outputs return (pure_outputs, exclusive_outputs, shared_outputs) - - def all_nodes_between( - self, - graph: SDFG | SDFGState, - begin: nodes.Node, - end: nodes.Node, - reverse: bool = False, - ) -> set[nodes.Node] | None: - """Returns all nodes that are reachable from `begin` but bound by `end`. - - What the function does is, that it starts a DFS starting at `begin`, which is - not part of the returned set, every edge that goes to `end` will be considered - to not exists. - In case `end` is never found the function will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. - - Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will never contain the node `begin`. - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. - """ - - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: - return (edge.dst for edge in graph.out_edges(node)) - - if reverse: - begin, end = end, begin - - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: - return (edge.src for edge in graph.in_edges(node)) - - to_visit: list[nodes.Node] = [begin] - seen: set[nodes.Node] = set() - found_end: bool = False - - while len(to_visit) > 0: - n: nodes.Node = to_visit.pop() - if n == end: - found_end = True - continue - elif n in seen: - continue - seen.add(n) - to_visit.extend(next_nodes(n)) - - if not found_end: - return None - - seen.discard(begin) - return seen - - def find_downstream_consumers( - self, - state: SDFGState, - begin: nodes.Node | nodes.MultiConnectorEdge, - only_tasklets: bool = False, - reverse: bool = False, - ) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: - """Find all downstream connectors of `begin`. - - A consumer, in this sense, is any node that is neither an entry nor an exit - node. The function returns a set storing the pairs, the first element is the - node that acts as consumer and the second is the edge that leads to it. - By setting `only_tasklets` the nodes the function finds are only Tasklets. - - To find this set the function starts a search at `begin`, however, it is also - possible to pass an edge as `begin`. - If `reverse` is `True` the function essentially finds the producers that are - upstream. - - Args: - state: The state in which to look for the consumers. - begin: The initial node that from which the search starts. - only_tasklets: Return only Tasklets. - reverse: Follow the reverse direction. - """ - if isinstance(begin, nodes.MultiConnectorEdge): - to_visit: list[nodes.MultiConnectorEdge] = [begin] - elif reverse: - to_visit = list(state.in_edges(begin)) - else: - to_visit = list(state.out_edges(begin)) - seen: set[nodes.MultiConnectorEdge] = set() - found: set[tuple[nodes.Node, nodes.MultiConnectorEdge]] = set() - - while len(to_visit) != 0: - curr_edge: nodes.MultiConnectorEdge = to_visit.pop() - next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst - - if curr_edge in seen: - continue - seen.add(curr_edge) - - if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): - if reverse: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - else: - # In forward mode a Map entry could also mean the definition of a - # dynamic map range. - if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( - next_node, nodes.MapEntry - ): - # This edge defines a dynamic map range, which is a consumer - if not only_tasklets: - found.add((next_node, curr_edge)) - continue - target_conn = curr_edge.dst_conn[3:] - new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) - to_visit.extend(new_edges) - else: - if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): - continue - found.add((next_node, curr_edge)) - - return found - - def find_upstream_producers( - self, - state: SDFGState, - begin: nodes.Node | nodes.MultiConnectorEdge, - only_tasklets: bool = False, - ) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: - """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" - return self.find_downstream_consumers( - state=state, - begin=begin, - only_tasklets=only_tasklets, - reverse=True, - ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index b549e2e76e..aa714142c3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -14,7 +14,10 @@ """Common functionality for the transformations.""" +from typing import Iterable + import dace +from dace.sdfg import nodes def is_nested_sdfg(sdfg: dace.SDFG) -> bool: @@ -29,3 +32,140 @@ def is_nested_sdfg(sdfg: dace.SDFG) -> bool: return False else: raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") + + +def all_nodes_between( + graph: dace.SDFG | dace.SDFGState, + begin: nodes.Node, + end: nodes.Node, + reverse: bool = False, +) -> set[nodes.Node] | None: + """Returns all nodes that are reachable from `begin` but bound by `end`. + + What the function does is, that it starts a DFS starting at `begin`, which is + not part of the returned set, every edge that goes to `end` will be considered + to not exists. + In case `end` is never found the function will return `None`. + + If `reverse` is set to `True` the function will start exploring at `end` and + follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The terminator node of the DFS. + reverse: Perform a backward DFS. + + Notes: + - The returned set will never contain the node `begin`. + - The returned set will also contain the nodes of path that starts at + `begin` and ends at a node that is not `end`. + """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + if reverse: + return (edge.src for edge in graph.in_edges(node)) + return (edge.dst for edge in graph.out_edges(node)) + + if reverse: + begin, end = end, begin + + to_visit: list[nodes.Node] = [begin] + seen: set[nodes.Node] = set() + found_end: bool = False + + while len(to_visit) > 0: + n: nodes.Node = to_visit.pop() + if n == end: + found_end = True + continue + elif n in seen: + continue + seen.add(n) + to_visit.extend(next_nodes(n)) + + if not found_end: + return None + + seen.discard(begin) + return seen + + +def find_downstream_consumers( + state: dace.SDFGState, + begin: nodes.Node | nodes.MultiConnectorEdge, + only_tasklets: bool = False, + reverse: bool = False, +) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: + """Find all downstream connectors of `begin`. + + A consumer, in this sense, is any node that is neither an entry nor an exit + node. The function returns a set storing the pairs, the first element is the + node that acts as consumer and the second is the edge that leads to it. + By setting `only_tasklets` the nodes the function finds are only Tasklets. + + To find this set the function starts a search at `begin`, however, it is also + possible to pass an edge as `begin`. + If `reverse` is `True` the function essentially finds the producers that are + upstream. + + Args: + state: The state in which to look for the consumers. + begin: The initial node that from which the search starts. + only_tasklets: Return only Tasklets. + reverse: Follow the reverse direction. + """ + if isinstance(begin, nodes.MultiConnectorEdge): + to_visit: list[nodes.MultiConnectorEdge] = [begin] + elif reverse: + to_visit = list(state.in_edges(begin)) + else: + to_visit = list(state.out_edges(begin)) + seen: set[nodes.MultiConnectorEdge] = set() + found: set[tuple[nodes.Node, nodes.MultiConnectorEdge]] = set() + + while len(to_visit) != 0: + curr_edge: nodes.MultiConnectorEdge = to_visit.pop() + next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst + + if curr_edge in seen: + continue + seen.add(curr_edge) + + if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): + if reverse: + target_conn = curr_edge.src_conn[4:] + new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) + else: + # In forward mode a Map entry could also mean the definition of a + # dynamic map range. + if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( + next_node, nodes.MapEntry + ): + # This edge defines a dynamic map range, which is a consumer + if not only_tasklets: + found.add((next_node, curr_edge)) + continue + target_conn = curr_edge.dst_conn[3:] + new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) + to_visit.extend(new_edges) + else: + if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): + continue + found.add((next_node, curr_edge)) + + return found + + +def find_upstream_producers( + state: dace.SDFGState, + begin: nodes.Node | nodes.MultiConnectorEdge, + only_tasklets: bool = False, +) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: + """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" + return find_downstream_consumers( + state=state, + begin=begin, + only_tasklets=only_tasklets, + reverse=True, + ) From fcd8ee354244da2cf756a7e691af5fd304e160b6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 09:01:03 +0200 Subject: [PATCH 130/235] Small corrections and format improvements. --- .../transformations/map_fusion_helper.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 65b92fc3b7..b5aa3a0071 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -19,22 +19,15 @@ from typing import Any, Optional, Sequence, Union import dace -from dace import subsets -from dace.sdfg import ( - SDFG, - SDFGState, - data, - nodes, - properties, - transformation as dace_transformation, -) +from dace import data, properties, subsets, transformation +from dace.sdfg import SDFG, SDFGState, nodes from dace.transformation import helpers from . import util @properties.make_properties -class MapFusionHelper(dace_transformation.SingleStateTransformation): +class MapFusionHelper(transformation.SingleStateTransformation): """ Contains common part of the map fusion for parallel and serial map fusion. @@ -471,6 +464,7 @@ def partition_first_outputs( # output of the check function, from within the second map we remove # the intermediate, it has more the meaning of "do we need to # reconstruct it after the second map again?". + # NOTE: The case "used in this state" is handled above!! if self.is_interstate_transient(intermediate_node, sdfg): shared_outputs.add(out_edge) else: From a25a6a476d662ee70d68b1e8fad758542601a287 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 09:01:54 +0200 Subject: [PATCH 131/235] Fixed some missing include. --- .../next/program_processors/runners/dace_fieldview/utility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index f3ce3d5dd2..bdcd268f66 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,7 +16,7 @@ import dace -from gt4py.next.common import Connectivity, Dimension +from gt4py.next.common import Connectivity, Dimension, DimensionKind from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_tasklet From a57e108cbaacbe2ddb8adf80cd0ae70b4b387899 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 09:23:24 +0200 Subject: [PATCH 132/235] Added a first and mostly untested version of the serial fusion transformation. --- .../transformations/map_seriall_fusion.py | 470 ++++++++++++++++++ 1 file changed, 470 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py new file mode 100644 index 0000000000..3eaebab634 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -0,0 +1,470 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Implements the seriall map fusing transformation.""" + +import copy +from typing import Any, Union + +import dace +from dace import dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, nodes + +from . import map_fusion_helper + + +@properties.make_properties +class SerialMapFusion(map_fusion_helper.MapFusionHelper): + """Specialized replacement for the map fusion transformation that is provided by DaCe. + + As its name is indicating this transformation is only able to handle Maps that + are in sequence. Compared to the native DaCe transformation, this one is able + to handle more complex cases of connection between the maps. In that sense, it + is much more similar to DaCe's `SubgraphFusion` transformation. + + Things that are improved, compared to the native DaCe implementation: + - Nested Maps. + - Temporary arrays and the correct propagation of their Memelts. + - Top Maps that have multiple outputs. + + Conceptually this transformation removes the exit of the first or upper map + and the entry of the lower or second map and then rewriting the connections + appropriate. + + This transformation assumes that an SDFG obeys the structure that is outlined + [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that + reason it is not true replacement of the native DaCe transformation. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + + Notes: + This transformation modifies more nodes than it matches! + """ + + map_exit1 = transformation.PatternNode(nodes.MapExit) + access_node = transformation.PatternNode(nodes.AccessNode) + map_entry2 = transformation.PatternNode(nodes.MapEntry) + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on these nodes, + but more or less anything that has an outgoing connection of the first Map. + """ + return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + + # end def: expressions + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - The `can_be_fused()` of the base succeed, which checks some basic constrains. + - The decomposition exists and at least one of the intermediate sets + is not empty. + """ + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) + map_entry_2: nodes.MapEntry = self.map_entry2 + + if not self.can_be_fused( + map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # there at least one of the intermediate output sets is not empty. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=self.map_exit1, + map_entry_2=self.map_entry2, + ) + if output_partition is None: + return False + if not (output_partition[1] or output_partition[2]): + return False + + return True + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + """ + assert isinstance(graph, dace.SDFGState) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=self.map_exit1, + map_entry_2=self.map_entry2, + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + # Must be here to prevent errors + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) + + # Handling the outputs + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=self.map_exit1, + map_entry_2=self.map_entry2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=self.map_exit1, + map_entry_2=self.map_entry2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(self.map_exit1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=self.map_exit1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=self.map_entry2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [self.map_exit1, self.map_entry2]: + assert graph.degree(node_to_remove) + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + def handle_intermediate_set( + self, + intermediate_outputs: set[nodes.MultiConnectorEdge], + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + is_exclusive_set: bool, + ) -> None: + """Handle the intermediate output sets. + + The function is able to handle both the shared and exclusive intermediate + output set that was computed by `partition_first_outputs()`. Which one is + handled is indicated by `is_exclusive_set`. + + The main difference is that in exclusive mode is that the intermediate node + will be fully removed from the SDFG. However, in shared mode, the intermediate + + Args: + intermediateOutputSet: The set of edges that are intermediate outputs of the first Map. + graph: The graph we operate on. + sdfg: The SDFG we operate on. + mapExit1: The node that serves as exit node of the first Map. + mapEntry2: The node that serves as entry node of the second Map. + mapExit2: The node that serves as exit node of the second Map. + isExclusiveSet: If `True` process the exclusive set. + """ + + # Essentially this function removes the AccessNode between the two maps. + # However, we still need some temporary memory that we can use, which is + # just much smaller, i.e. a scalar. But all Memlets inside the second map + # assumes that the intermediate memory has the bigger shape. + # To fix that we will create this replacement dict that will replace all + # occurrences of the iteration variables of the second map with zero. + # Note that this is still not enough as the dimensionality might be different. + memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} + + map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + + # Now we will iterate over all intermediate edge and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, depending on the mode, we want + # to get rid off, in shared mode it materialize, but now at the + # exit of the second Map, in exclusive mode it will be removed. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate, which has some + # issue. The size of this temporary is given by the Memlet that + # goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + assert pre_exit_edges[0].dst_conn is None + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # That are known to cause some troubles, so we will now remove them. + squeezed_dims: list[ + int + ] = [] # Dimensions of the original intermediate we squeezed away. + new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + # Order of checks is important. + if ( + full_dim_size == 1 + ): # The original array had dimension size 1, so we have to keep it. + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + isScalar: bool = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # New we will reroute the output Memlet, thus it will no longer going + # through the Map exit but through the newly created intermediate. + # we will delete the previous edge later. + pre_exit_memlet: dace.Memlet = pre_exit_edge.data + new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) + + # We might operate on a different array, but the check below, ensures + # that we do not change the direction of the Memlet. + assert pre_exit_memlet.data == inter_name + new_pre_exit_memlet.data = new_inter_name + + # Now we have to fix the subset of the Memlet. + # Before the subset of the Memlet dependent on the Map variables, + # however, this is no longer the case, as we removed them. This change + # has to be reflected in the Memlet. + # NOTE: Assert above ensures that the bellow is correct. + new_pre_exit_memlet.replace(memlet_repl) + if is_scalar: + new_pre_exit_memlet.subset = "0" + new_pre_exit_memlet.other_subset = None + else: + new_pre_exit_memlet.subset.pop(squeezed_dims) + + # Now we create the new edge between the producer and the new output + # (the new intermediate node). We will remove the old edge further down. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We just have handled the last Memlet, but we must actually handle the + # whole producer side, i.e. the scope of the top Map. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): + producer_edge = producer_tree.edge + + # Ensure the correctness of the rerouting below. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + + # Will not change the direction, because of test above! + producer_edge.data.data = new_inter_name + producer_edge.data.replace(memlet_repl) + if is_scalar: + producer_edge.data.dst_subset = "0" + else: + if producer_edge.data.dst_subset is not None: + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the temporary + # in the second map. We do this by finding the input connectors on the + # map entry, such that we know where we have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: list[str] = [] + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.append(inter_node_out_edge.dst_conn) + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # The create the first Memlet to transmit information, within + # the second map, we do this again by copying and modifying + # the original Memlet. + # NOTE: Test above is important to ensure the direction of the + # Memlet and the correctness of the code below. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.replace(memlet_repl) + new_inner_memlet.data = new_inter_name # Because of the assert above, this will not chenge the direction. + + # Now remove the old edge, that started the second map entry. + # Also add the new edge that started at the new intermediate. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now we do subset modification to ensure that nothing failed. + if isScalar: + new_inner_memlet.src_subset = "0" + else: + if new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now clean the Memlets of that tree to use the new intermediate node. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): + consumer_edge = consumer_tree.edge + assert consumer_edge.data.data == inter_name + consumer_edge.data.data = new_inter_name + if isScalar: + consumer_edge.data.src_subset = "0" + else: + if consumer_edge.data.subset is not None: + consumer_edge.data.subset.pop(squeezed_dims) + + # Now remove the inner edge, as it is no longer needed. + state.remove_edge(inner_edge) + + # Now remove the connectors from the map entry. Note that the + # edge ending in `out_conn_name` is still connected to the node. + map_entry_2.remove_out_connector(out_conn_name) + map_entry_2.remove_in_connector(in_conn_name) + + # TODO: Apply this modification to Memlets + # for neighbor in state.all_edges(local_node): + # for e in state.memlet_tree(neighbor): + # if e.data.data == local_name: + # continue # noqa: ERA001 + # e.data.data = local_name # noqa: ERA001 + # e.data.subset.offset(old_edge.data.subset, negative=True) # noqa: ERA001 + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate temporary node to the Map output. + # This will essentially restore or preserve the output for the intermediate node. + # It is important that we use the data that `preExitEdge` was used. + # On CPU it works but for some reasons not on GPU. + new_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert new_exit_memlet.data == inter_name + new_exit_memlet.subset = pre_exit_edge.data.dst_subset + new_exit_memlet.other_subset = ( + "0" if is_scalar else subsets.Range.from_array(inter_desc) + ) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + new_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) From 62ad165b193d758cad44e337991940056967de0c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 11:56:43 +0200 Subject: [PATCH 133/235] Started debugin, very strange bug. --- .../transformations/map_fusion_helper.py | 31 ++++--- .../transformations/map_seriall_fusion.py | 89 ++++++++++--------- .../dace_fieldview/transformations/util.py | 20 ++--- 3 files changed, 74 insertions(+), 66 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index b5aa3a0071..32ef6a6be5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -20,7 +20,7 @@ import dace from dace import data, properties, subsets, transformation -from dace.sdfg import SDFG, SDFGState, nodes +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes from dace.transformation import helpers from . import util @@ -128,8 +128,8 @@ def can_be_fused( return True + @staticmethod def relocate_nodes( - self, from_node: Union[nodes.MapExit, nodes.MapEntry], to_node: Union[nodes.MapExit, nodes.MapEntry], state: SDFGState, @@ -191,14 +191,13 @@ def relocate_nodes( # We have a Passthrough connection, i.e. there exists a `OUT_` connector # thus we now have to migrate the two edges. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix new_conn = to_node.next_connector(old_conn) for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="OUT_" + new_conn) + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) from_node.remove_in_connector("IN_" + old_conn) from_node.remove_out_connector("OUT_" + old_conn) @@ -207,7 +206,7 @@ def relocate_nodes( ), f"After moving source node '{from_node}' still has an input degree of {state.in_degree(from_node)}" assert ( state.out_degree(from_node) == 0 - ), f"After moving source node '{from_node}' still has an output degree of {state.in_degree(from_node)}" + ), f"After moving source node '{from_node}' still has an output degree of {state.out_degree(from_node)}" def map_parameter_compatible( self, @@ -301,7 +300,7 @@ def is_interstate_transient( return False if isinstance(desc, data.Scalar): return False - return transient not in shared_sdfg_transients + return transient in shared_sdfg_transients def partition_first_outputs( self, @@ -311,9 +310,9 @@ def partition_first_outputs( map_entry_2: nodes.MapEntry, ) -> Union[ tuple[ - set[nodes.MultiConnectorEdge], - set[nodes.MultiConnectorEdge], - set[nodes.MultiConnectorEdge], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], ], None, ]: @@ -348,9 +347,9 @@ def partition_first_outputs( map_entry_2: The entry node of the second map. """ # The three outputs set. - pure_outputs: set[nodes.MultiConnectorEdge] = set() - exclusive_outputs: set[nodes.MultiConnectorEdge] = set() - shared_outputs: set[nodes.MultiConnectorEdge] = set() + pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() # Set of intermediate nodes that we have already processed. processed_inter_nodes: set[nodes.Node] = set() @@ -372,7 +371,7 @@ def partition_first_outputs( # Now let's look at all nodes that are downstream of the intermediate node. # This, among other things, will tell us, how we have to handle this node. - downstream_nodes = self.all_nodes_between( + downstream_nodes = util.all_nodes_between( graph=state, begin=intermediate_node, end=map_entry_2, @@ -401,8 +400,8 @@ def partition_first_outputs( # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. # TODO(phimuell): Handle this case properly. - inner_collector_edges = state.in_edges_by_connector( - intermediate_node, "IN_" + out_edge.src_conn[3:] + inner_collector_edges = list( + state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) ) if len(inner_collector_edges) > 1: return None @@ -494,7 +493,7 @@ def partition_first_outputs( return None else: # Ensure that there is no path that leads to the second map. - after_intermdiate_node = self.all_nodes_between( + after_intermdiate_node = util.all_nodes_between( graph=state, begin=edge.dst, end=map_entry_2 ) if after_intermdiate_node is not None: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py index 3eaebab634..ac83441e90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -19,7 +19,7 @@ import dace from dace import dtypes, properties, subsets, symbolic, transformation -from dace.sdfg import SDFG, SDFGState, nodes +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes from . import map_fusion_helper @@ -54,9 +54,9 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): This transformation modifies more nodes than it matches! """ - map_exit1 = transformation.PatternNode(nodes.MapExit) - access_node = transformation.PatternNode(nodes.AccessNode) - map_entry2 = transformation.PatternNode(nodes.MapEntry) + map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) + access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) def __init__( self, @@ -91,6 +91,8 @@ def can_be_applied( - The decomposition exists and at least one of the intermediate sets is not empty. """ + assert isinstance(self.map_exit1, nodes.MapExit) + assert isinstance(self.map_entry2, nodes.MapEntry) map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) map_entry_2: nodes.MapEntry = self.map_entry2 @@ -111,7 +113,8 @@ def can_be_applied( return False if not (output_partition[1] or output_partition[2]): return False - + assert isinstance(self.map_exit1, nodes.MapExit) + assert isinstance(self.map_entry2, nodes.MapEntry) return True def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: @@ -124,28 +127,34 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non By assumption we do not have to rename anything. """ assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit1, nodes.MapExit) + assert isinstance(self.map_entry2, nodes.MapEntry) + + # From here on forward we can no longer use `self.map_*`!! + # For some reason they are not stable and change. + map_exit_1: nodes.MapExit = self.map_exit1 + map_entry_2: nodes.MapEntry = self.map_entry2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, ) assert output_partition is not None # Make MyPy happy. pure_outputs, exclusive_outputs, shared_outputs = output_partition - # Must be here to prevent errors - map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) - # Handling the outputs if len(exclusive_outputs) != 0: self.handle_intermediate_set( intermediate_outputs=exclusive_outputs, state=graph, sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, is_exclusive_set=True, ) if len(shared_outputs) != 0: @@ -153,14 +162,15 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non intermediate_outputs=shared_outputs, state=graph, sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, is_exclusive_set=False, ) - assert pure_outputs == set(graph.out_edges(self.map_exit1)) + assert pure_outputs == set(graph.out_edges(map_exit_1)) if len(pure_outputs) != 0: self.relocate_nodes( - from_node=self.map_exit1, + from_node=map_exit_1, to_node=map_exit_2, state=graph, sdfg=sdfg, @@ -170,14 +180,14 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # to the first map, now we must move the output of the first map # to the second one, as this one is used. self.relocate_nodes( - from_node=self.map_entry2, + from_node=map_entry_2, to_node=map_entry_1, state=graph, sdfg=sdfg, ) - for node_to_remove in [self.map_exit1, self.map_entry2]: - assert graph.degree(node_to_remove) + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 graph.remove_node(node_to_remove) # Now turn the second output node into the output node of the first Map. @@ -185,11 +195,12 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non def handle_intermediate_set( self, - intermediate_outputs: set[nodes.MultiConnectorEdge], + intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], state: SDFGState, sdfg: SDFG, map_exit_1: nodes.MapExit, map_entry_2: nodes.MapEntry, + map_exit_2: nodes.MapExit, is_exclusive_set: bool, ) -> None: """Handle the intermediate output sets. @@ -220,8 +231,6 @@ def handle_intermediate_set( # Note that this is still not enough as the dimensionality might be different. memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) - # Now we will iterate over all intermediate edge and process them. # If not stated otherwise the comments assume that we run in exclusive mode. for out_edge in intermediate_outputs: @@ -241,23 +250,18 @@ def handle_intermediate_set( ) if len(pre_exit_edges) != 1: raise NotImplementedError() - assert pre_exit_edges[0].dst_conn is None pre_exit_edge = pre_exit_edges[0] new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) # Over approximation will leave us with some unneeded size one dimensions. # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[ - int - ] = [] # Dimensions of the original intermediate we squeezed away. + squeezed_dims: list[int] = [] new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. for dim, (proposed_dim_size, full_dim_size) in enumerate( zip(new_inter_shape_raw, inter_shape) ): - # Order of checks is important. - if ( - full_dim_size == 1 - ): # The original array had dimension size 1, so we have to keep it. + # Order of checks is important! + if full_dim_size == 1: # Must be kept! new_inter_shape.append(proposed_dim_size) elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. squeezed_dims.append(dim) @@ -351,11 +355,16 @@ def handle_intermediate_set( # in the second map. We do this by finding the input connectors on the # map entry, such that we know where we have to reroute inside the Map. # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: list[str] = [] + conn_names: set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): if inter_node_out_edge.dst == map_entry_2: assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.append(inter_node_out_edge.dst_conn) + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set # Now we will reroute the connections inside the second map, i.e. # instead of consuming the old intermediate node, they will now @@ -387,7 +396,7 @@ def handle_intermediate_set( ) # Now we do subset modification to ensure that nothing failed. - if isScalar: + if is_scalar: new_inner_memlet.src_subset = "0" else: if new_inner_memlet.src_subset is not None: @@ -404,13 +413,13 @@ def handle_intermediate_set( if consumer_edge.data.subset is not None: consumer_edge.data.subset.pop(squeezed_dims) - # Now remove the inner edge, as it is no longer needed. - state.remove_edge(inner_edge) - - # Now remove the connectors from the map entry. Note that the - # edge ending in `out_conn_name` is still connected to the node. - map_entry_2.remove_out_connector(out_conn_name) + # The edge that leaves the second map entry was already deleted. + # We will now delete the edges that brought the data. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) # TODO: Apply this modification to Memlets # for neighbor in state.all_edges(local_node): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index aa714142c3..f8596106c7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -17,7 +17,7 @@ from typing import Iterable import dace -from dace.sdfg import nodes +from dace.sdfg import graph as dace_graph, nodes def is_nested_sdfg(sdfg: dace.SDFG) -> bool: @@ -93,10 +93,10 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: def find_downstream_consumers( state: dace.SDFGState, - begin: nodes.Node | nodes.MultiConnectorEdge, + begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], only_tasklets: bool = False, reverse: bool = False, -) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: +) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Find all downstream connectors of `begin`. A consumer, in this sense, is any node that is neither an entry nor an exit @@ -115,17 +115,17 @@ def find_downstream_consumers( only_tasklets: Return only Tasklets. reverse: Follow the reverse direction. """ - if isinstance(begin, nodes.MultiConnectorEdge): - to_visit: list[nodes.MultiConnectorEdge] = [begin] + if isinstance(begin, dace_graph.MultiConnectorEdge): + to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] elif reverse: to_visit = list(state.in_edges(begin)) else: to_visit = list(state.out_edges(begin)) - seen: set[nodes.MultiConnectorEdge] = set() - found: set[tuple[nodes.Node, nodes.MultiConnectorEdge]] = set() + seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + found: set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() while len(to_visit) != 0: - curr_edge: nodes.MultiConnectorEdge = to_visit.pop() + curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst if curr_edge in seen: @@ -159,9 +159,9 @@ def find_downstream_consumers( def find_upstream_producers( state: dace.SDFGState, - begin: nodes.Node | nodes.MultiConnectorEdge, + begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], only_tasklets: bool = False, -) -> set[tuple[nodes.Node, nodes.MultiConnectorEdge]]: +) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" return find_downstream_consumers( state=state, From 7f72794c724ffb639681668fbe55681b192e9f67 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 11 Jul 2024 12:26:48 +0200 Subject: [PATCH 134/235] Import changes from dace-fieldview-neighbors --- .../runners/dace_fieldview/__init__.py | 4 +- .../gtir_builtin_translators.py | 597 +++++++++--------- .../dace_fieldview/gtir_python_codegen.py | 2 + .../runners/dace_fieldview/gtir_to_sdfg.py | 34 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 12 +- .../runners/dace_fieldview/utility.py | 14 +- .../{gtir_dace_backend.py => workflow.py} | 7 + 7 files changed, 339 insertions(+), 331 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_dace_backend.py => workflow.py} (88%) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py index 18a753a17c..c39c832e88 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -13,9 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.program_processors.runners.dace_fieldview.gtir_dace_backend import ( - build_sdfg_from_gtir, -) +from gt4py.next.program_processors.runners.dace_fieldview.workflow import build_sdfg_from_gtir __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index db40786644..266cbbff1a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,13 +13,14 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + import abc -from typing import Any, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Optional, Protocol, TypeAlias import dace import dace.subsets as sbs -from gt4py.eve import concepts from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -31,24 +32,12 @@ from gt4py.next.type_system import type_specifications as ts -IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes -TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] - - -class SDFGBuilder(Protocol): - """Visitor interface available to GTIR-primitive translators.""" - - @abc.abstractmethod - def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: - pass +if TYPE_CHECKING: + from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg - @abc.abstractmethod - def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: - pass - @abc.abstractmethod - def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: - pass +IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes +TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] class PrimitiveTranslator(Protocol): @@ -58,7 +47,7 @@ def __call__( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: SDFGBuilder, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, ) -> list[TemporaryData]: """Creates the dataflow subgraph representing a GTIR primitive function. @@ -70,9 +59,6 @@ def __call__( sdfg: The SDFG where the primitive subgraph should be instantiated state: The SDFG state where the result of the primitive function should be made available sdfg_builder: The object responsible for visiting child nodes of the primitive node. - reduce_identity: The value of the reduction identity, in case the primitive node - is visited in the context of a reduction expression. This value is used - by the `neighbors` primitive to provide the value of skip neighbors. Returns: A list of data access nodes and the associated GT4Py data type, which provide @@ -82,304 +68,293 @@ def __call__( """ -class AsFieldOp(PrimitiveTranslator): - """Generates the dataflow subgraph for the `as_field_op` builtin function.""" - - @classmethod - def _parse_node_args( - cls, - args: list[gtir.Expr], - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], - ) -> list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr]: - stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in args: - fields: list[TemporaryData] = sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) - assert len(fields) == 1 - data_node, arg_type = fields[0] - # require all argument nodes to be data access nodes (no symbols) - assert isinstance(data_node, dace.nodes.AccessNode) - - arg_definition: gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr - if isinstance(arg_type, ts.ScalarType): - arg_definition = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) - else: - assert isinstance(arg_type, ts.FieldType) - indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = { - dim: gtir_to_tasklet.SymbolExpr( - dace_fieldview_util.get_map_variable(dim), - IteratorIndexDType, - ) - for dim, _, _ in domain - } - arg_definition = gtir_to_tasklet.IteratorExpr( - data_node, - arg_type.dims, - indices, - ) - stencil_args.append(arg_definition) - - return stencil_args - - @classmethod - def _create_temporary_field( - cls, - sdfg: dace.SDFG, - state: dace.SDFGState, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], - node_type: ts.ScalarType, - output_desc: dace.data.Data, - output_field_type: ts.DataType, - ) -> tuple[dace.nodes.AccessNode, ts.FieldType]: - domain_dims, domain_lbs, domain_ubs = zip(*domain) - field_dims = list(domain_dims) - field_shape = [ - # diff between upper and lower bound - (ub - lb) - for lb, ub in zip(domain_lbs, domain_ubs) - ] - field_offset: Optional[list[dace.symbolic.SymbolicType]] = None - if any(domain_lbs): - field_offset = [-lb for lb in domain_lbs] - - if isinstance(output_desc, dace.data.Array): - # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) - assert isinstance(output_field_type, ts.FieldType) - # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype` - node_type = output_field_type.dtype - field_dims.extend(output_field_type.dims) - field_shape.extend(output_desc.shape) - else: - assert isinstance(output_desc, dace.data.Scalar) - assert isinstance(output_field_type, ts.ScalarType) - # TODO: enable `assert output_field_type == node_type`, remove variable `dtype` - node_type = output_field_type - - # allocate local temporary storage for the result field - temp_name, _ = sdfg.add_temp_transient( - field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset - ) - field_node = state.add_access(temp_name) - field_type = ts.FieldType(field_dims, node_type) - - return field_node, field_type - - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - ) -> list[TemporaryData]: - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - # expect stencil (represented as a lambda function) as first argument - assert isinstance(stencil_expr, gtir.Lambda) - # the domain of the field operator is passed as second argument - assert isinstance(domain_expr, gtir.FunCall) - - # add local storage to compute the field operator over the given domain - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - domain = dace_fieldview_util.get_domain(domain_expr) - - # first visit the list of arguments and build a symbol map - stencil_args = self._parse_node_args(node.args, sdfg, state, sdfg_builder, domain) - - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) - input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) - assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) - output_desc = output_expr.node.desc(sdfg) - - # retrieve the tasklet node which writes the result - last_node = state.in_edges(output_expr.node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): - # the last transient node can be deleted - last_node_connector = state.in_edges(output_expr.node)[0].src_conn - state.remove_node(output_expr.node) - else: - last_node = output_expr.node - last_node_connector = None - - # allocate local temporary storage for the result field - field_node, field_type = self._create_temporary_field( - sdfg, state, domain, node_type, output_desc, output_expr.field_type - ) - - # assume tasklet with single output - output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] - if isinstance(output_desc, dace.data.Array): - # additional local dimension for neighbors - assert set(output_desc.offset) == {0} - output_subset.extend(f"0:{size}" for size in output_desc.shape) - - # create map range corresponding to the field operator domain - map_ranges = { - dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain +def _parse_arg_expr( + node: gtir.Expr, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, + domain: list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] + ], +) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr: + fields: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) + + assert len(fields) == 1 + data_node, arg_type = fields[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(data_node, dace.nodes.AccessNode) + + if isinstance(arg_type, ts.ScalarType): + return gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) + else: + assert isinstance(arg_type, ts.FieldType) + indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = { + dim: gtir_to_tasklet.SymbolExpr( + dace_fieldview_util.get_map_variable(dim), + IteratorIndexDType, + ) + for dim, _, _ in domain } - me, mx = state.add_map("field_op", map_ranges) - - if len(input_connections) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - state.add_nedge(me, last_node, dace.Memlet()) - else: - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) - state.add_memlet_path( - last_node, - mx, - field_node, - src_conn=last_node_connector, - memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), + return gtir_to_tasklet.IteratorExpr( + data_node, + arg_type.dims, + indices, ) - return [(field_node, field_type)] - -class Cond(PrimitiveTranslator): - """Generates the dataflow subgraph for the `cond` builtin function.""" - - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - ) -> list[TemporaryData]: - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "cond") - assert len(node.args) == 0 - - fun_node = node.fun - assert len(fun_node.args) == 3 - cond_expr, true_expr, false_expr = fun_node.args - - # expect condition as first argument - cond = gtir_python_codegen.get_source(cond_expr) - - # use current head state to terminate the dataflow, and add a entry state - # to connect the true/false branch states as follows: - # - # ------------ - # === | cond | === - # || ------------ || - # \/ \/ - # ------------ ------------- - # | true | | false | - # ------------ ------------- - # || || - # || ------------ || - # ==> | head | <== - # ------------ - # - cond_state = sdfg.add_state_before(state, state.label + "_cond") - sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) - - # expect true branch as second argument - true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) - sdfg.add_edge(true_state, state, dace.InterstateEdge()) - - # and false branch as third argument - false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})"))) - sdfg.add_edge(false_state, state, dace.InterstateEdge()) - - true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) - false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state) - - output_nodes = [] - for true_br, false_br in zip(true_br_args, false_br_args, strict=True): - true_br_node, true_br_type = true_br - assert isinstance(true_br_node, dace.nodes.AccessNode) - false_br_node, _ = false_br - assert isinstance(false_br_node, dace.nodes.AccessNode) - desc = true_br_node.desc(sdfg) - assert false_br_node.desc(sdfg) == desc - data_name, _ = sdfg.add_temp_transient_like(desc) - output_nodes.append((state.add_access(data_name), true_br_type)) - - true_br_output_node = true_state.add_access(data_name) - true_state.add_nedge( - true_br_node, - true_br_output_node, - dace.Memlet.from_array(data_name, desc), +def _create_temporary_field( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] + ], + node_type: ts.ScalarType, + output_desc: dace.data.Data, + output_field_type: ts.DataType, +) -> tuple[dace.nodes.AccessNode, ts.FieldType]: + domain_dims, domain_lbs, domain_ubs = zip(*domain) + field_dims = list(domain_dims) + field_shape = [ + # diff between upper and lower bound + (ub - lb) + for lb, ub in zip(domain_lbs, domain_ubs) + ] + field_offset: Optional[list[dace.symbolic.SymbolicType]] = None + if any(domain_lbs): + field_offset = [-lb for lb in domain_lbs] + + if isinstance(output_desc, dace.data.Array): + # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) + assert isinstance(output_field_type, ts.FieldType) + # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype` + node_type = output_field_type.dtype + field_dims.extend(output_field_type.dims) + field_shape.extend(output_desc.shape) + else: + assert isinstance(output_desc, dace.data.Scalar) + assert isinstance(output_field_type, ts.ScalarType) + # TODO: enable `assert output_field_type == node_type`, remove variable `dtype` + node_type = output_field_type + + # allocate local temporary storage for the result field + temp_name, _ = sdfg.add_temp_transient( + field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset + ) + field_node = state.add_access(temp_name) + field_type = ts.FieldType(field_dims, node_type) + + return field_node, field_type + + +def visit_AsFieldOp( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> list[TemporaryData]: + """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + + fun_node = node.fun + assert len(fun_node.args) == 2 + stencil_expr, domain_expr = fun_node.args + # expect stencil (represented as a lambda function) as first argument + assert isinstance(stencil_expr, gtir.Lambda) + # the domain of the field operator is passed as second argument + assert isinstance(domain_expr, gtir.FunCall) + + # add local storage to compute the field operator over the given domain + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + domain = dace_fieldview_util.get_domain(domain_expr) + + # first visit the list of arguments and build a symbol map + stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + + # represent the field operator as a mapped tasklet graph, which will range over the field domain + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) + input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + output_desc = output_expr.node.desc(sdfg) + + # retrieve the tasklet node which writes the result + last_node = state.in_edges(output_expr.node)[0].src + if isinstance(last_node, dace.nodes.Tasklet): + # the last transient node can be deleted + last_node_connector = state.in_edges(output_expr.node)[0].src_conn + state.remove_node(output_expr.node) + else: + last_node = output_expr.node + last_node_connector = None + + # allocate local temporary storage for the result field + field_node, field_type = _create_temporary_field( + sdfg, state, domain, node_type, output_desc, output_expr.field_type + ) + + # assume tasklet with single output + output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] + if isinstance(output_desc, dace.data.Array): + # additional local dimension for neighbors + assert set(output_desc.offset) == {0} + output_subset.extend(f"0:{size}" for size in output_desc.shape) + + # create map range corresponding to the field operator domain + map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} + me, mx = state.add_map("field_op", map_ranges) + + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + state.add_nedge(me, last_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, ) + state.add_memlet_path( + last_node, + mx, + field_node, + src_conn=last_node_connector, + memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), + ) + + return [(field_node, field_type)] + + +def visit_Cond( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> list[TemporaryData]: + """Generates the dataflow subgraph for the `cond` builtin function.""" + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "cond") + assert len(node.args) == 0 + + fun_node = node.fun + assert len(fun_node.args) == 3 + cond_expr, true_expr, false_expr = fun_node.args + + # expect condition as first argument + cond = gtir_python_codegen.get_source(cond_expr) + + # use current head state to terminate the dataflow, and add a entry state + # to connect the true/false branch states as follows: + # + # ------------ + # === | cond | === + # || ------------ || + # \/ \/ + # ------------ ------------- + # | true | | false | + # ------------ ------------- + # || || + # || ------------ || + # ==> | head | <== + # ------------ + # + cond_state = sdfg.add_state_before(state, state.label + "_cond") + sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) + + # expect true branch as second argument + true_state = sdfg.add_state(state.label + "_true_branch") + sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) + sdfg.add_edge(true_state, state, dace.InterstateEdge()) + + # and false branch as third argument + false_state = sdfg.add_state(state.label + "_false_branch") + sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})"))) + sdfg.add_edge(false_state, state, dace.InterstateEdge()) + + true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) + false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state) + + output_nodes = [] + for true_br, false_br in zip(true_br_args, false_br_args, strict=True): + true_br_node, true_br_type = true_br + assert isinstance(true_br_node, dace.nodes.AccessNode) + false_br_node, _ = false_br + assert isinstance(false_br_node, dace.nodes.AccessNode) + desc = true_br_node.desc(sdfg) + assert false_br_node.desc(sdfg) == desc + data_name, _ = sdfg.add_temp_transient_like(desc) + output_nodes.append((state.add_access(data_name), true_br_type)) + + true_br_output_node = true_state.add_access(data_name) + true_state.add_nedge( + true_br_node, + true_br_output_node, + dace.Memlet.from_array(data_name, desc), + ) - false_br_output_node = false_state.add_access(data_name) - false_state.add_nedge( - false_br_node, - false_br_output_node, - dace.Memlet.from_array(data_name, desc), - ) + false_br_output_node = false_state.add_access(data_name) + false_state.add_nedge( + false_br_node, + false_br_output_node, + dace.Memlet.from_array(data_name, desc), + ) - return output_nodes + return output_nodes -class SymbolRef(PrimitiveTranslator): +def visit_SymbolRef( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> list[TemporaryData]: """Generates the dataflow subgraph for a `ir.SymRef` node.""" + assert isinstance(node, (gtir.Literal, gtir.SymRef)) + + data_type: ts.FieldType | ts.ScalarType + if isinstance(node, gtir.Literal): + sym_value = node.value + data_type = node.type + tasklet_name = "get_literal" + else: + sym_value = str(node.id) + assert sym_value in sdfg_builder.get_symbol_types() + data_type = sdfg_builder.get_symbol_types()[sym_value] + tasklet_name = f"get_{sym_value}" + + if isinstance(data_type, ts.FieldType): + # add access node to current state + sym_node = state.add_access(sym_value) + + else: + # scalar symbols are passed to the SDFG as symbols: build tasklet node + # to write the symbol to a scalar access node + tasklet_node = state.add_tasklet( + tasklet_name, + {}, + {"__out"}, + f"__out = {sym_value}", + ) + temp_name, _ = sdfg.add_temp_transient((1,), dace_fieldview_util.as_dace_type(data_type)) + sym_node = state.add_access(temp_name) + state.add_edge( + tasklet_node, + "__out", + sym_node, + None, + dace.Memlet(data=sym_node.data, subset="0"), + ) + + return [(sym_node, data_type)] - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - ) -> list[TemporaryData]: - assert isinstance(node, (gtir.Literal, gtir.SymRef)) - - data_type: ts.FieldType | ts.ScalarType - if isinstance(node, gtir.Literal): - sym_value = node.value - data_type = node.type - tasklet_name = "get_literal" - else: - sym_value = str(node.id) - assert sym_value in sdfg_builder.get_symbol_types() - data_type = sdfg_builder.get_symbol_types()[sym_value] - tasklet_name = f"get_{sym_value}" - - if isinstance(data_type, ts.FieldType): - # add access node to current state - sym_node = state.add_access(sym_value) - - else: - # scalar symbols are passed to the SDFG as symbols: build tasklet node - # to write the symbol to a scalar access node - tasklet_node = state.add_tasklet( - tasklet_name, - {}, - {"__out"}, - f"__out = {sym_value}", - ) - temp_name, _ = sdfg.add_temp_transient( - (1,), dace_fieldview_util.as_dace_type(data_type) - ) - sym_node = state.add_access(temp_name) - state.add_edge( - tasklet_node, - "__out", - sym_node, - None, - dace.Memlet(data=sym_node.data, subset="0"), - ) - return [(sym_node, data_type)] +if TYPE_CHECKING: + # Use type-checking to assert that all visitor functions implement the `PrimitiveTranslator` protocol + __primitive_translators: list[PrimitiveTranslator] = [ + visit_AsFieldOp, + visit_Cond, + visit_SymbolRef, + ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index fe98d8a98a..d70169bcd1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + from typing import Any import numpy as np diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index ffc1f56027..7c1c904cb1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -12,16 +12,20 @@ # # SPDX-License-Identifier: GPL-3.0-or-later """ -Class to lower GTIR to DaCe SDFG. +Contains visitors to lower GTIR to DaCe SDFG. Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Sequence +from __future__ import annotations + +import abc +from typing import Any, Protocol, Sequence import dace from gt4py import eve +from gt4py.eve import concepts from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -32,7 +36,23 @@ from gt4py.next.type_system import type_specifications as ts, type_translation as tt -class GTIRToSDFG(eve.NodeVisitor, gtir_builtin_translators.SDFGBuilder): +class SDFGBuilder(Protocol): + """Visitor interface available to GTIR-primitive translators.""" + + @abc.abstractmethod + def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: + pass + + @abc.abstractmethod + def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: + pass + + @abc.abstractmethod + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + pass + + +class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation @@ -254,9 +274,9 @@ def visit_FunCall( ) -> list[gtir_builtin_translators.TemporaryData]: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtin_translators.AsFieldOp()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_AsFieldOp(node, sdfg, head_state, self) elif cpm.is_call_to(node.fun, "cond"): - return gtir_builtin_translators.Cond()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_Cond(node, sdfg, head_state, self) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -274,7 +294,7 @@ def visit_Literal( sdfg: dace.SDFG, head_state: dace.SDFGState, ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_SymbolRef(node, sdfg, head_state, self) def visit_SymRef( self, @@ -282,4 +302,4 @@ def visit_SymRef( sdfg: dace.SDFG, head_state: dace.SDFGState, ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_SymbolRef(node, sdfg, head_state, self) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index c985856d60..88270eb697 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -13,8 +13,10 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + +import dataclasses import itertools -from dataclasses import dataclass from typing import Optional, TypeAlias import dace @@ -31,7 +33,7 @@ from gt4py.next.type_system import type_specifications as ts -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class MemletExpr: """Scalar or array data access thorugh a memlet.""" @@ -39,7 +41,7 @@ class MemletExpr: subset: sbs.Indices | sbs.Range -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class SymbolExpr: """Any symbolic expression that is constant in the context of current SDFG.""" @@ -47,7 +49,7 @@ class SymbolExpr: dtype: dace.typeclass -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ValueExpr: """Result of the computation implemented by a tasklet node.""" @@ -66,7 +68,7 @@ class ValueExpr: IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index d0c7582431..8f775de3f3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + from typing import Any, Mapping, Optional import dace @@ -89,6 +91,9 @@ def get_domain( Specialized visit method for domain expressions. Returns for each domain dimension the corresponding range. + + TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea + would be to cache the results of lowering here (e.g. using `functools.lru_cache`) """ assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) @@ -98,11 +103,10 @@ def get_domain( assert len(named_range.args) == 3 axis = named_range.args[0] assert isinstance(axis, gtir.AxisLiteral) - bounds = [] - for arg in named_range.args[1:3]: - sym_str = gtir_python_codegen.get_source(arg) - sym_val = dace.symbolic.SymExpr(sym_str) - bounds.append(sym_val) + bounds = [ + dace.symbolic.SymExpr(gtir_python_codegen.get_source(arg)) + for arg in named_range.args[1:3] + ] dim = gtx_common.Dimension(axis.value, axis.kind) domain.append((dim, bounds[0], bounds[1])) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py similarity index 88% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 3d99225d81..eabfa8f713 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -11,6 +11,13 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +""" +Contains definitions of the workflow steps for GTIR programs with dace as backend for optimization and code generation. + +Note: this module covers the fieldview flavour of GTIR. +""" + +from __future__ import annotations import dace From abf3918b707935336cd7b8ae0cca3e75311acb47 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 11 Jul 2024 13:00:44 +0200 Subject: [PATCH 135/235] Import changes from dace-fieldview-shifts --- .../runners/dace_fieldview/__init__.py | 4 +- .../gtir_builtin_translators.py | 597 +++++++++--------- .../dace_fieldview/gtir_python_codegen.py | 2 + .../runners/dace_fieldview/gtir_to_sdfg.py | 34 +- .../runners/dace_fieldview/gtir_to_tasklet.py | 12 +- .../runners/dace_fieldview/utility.py | 14 +- .../{gtir_dace_backend.py => workflow.py} | 7 + 7 files changed, 339 insertions(+), 331 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/{gtir_dace_backend.py => workflow.py} (88%) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py index 18a753a17c..c39c832e88 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -13,9 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.program_processors.runners.dace_fieldview.gtir_dace_backend import ( - build_sdfg_from_gtir, -) +from gt4py.next.program_processors.runners.dace_fieldview.workflow import build_sdfg_from_gtir __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index db40786644..266cbbff1a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,13 +13,14 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + import abc -from typing import Any, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Optional, Protocol, TypeAlias import dace import dace.subsets as sbs -from gt4py.eve import concepts from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -31,24 +32,12 @@ from gt4py.next.type_system import type_specifications as ts -IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes -TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] - - -class SDFGBuilder(Protocol): - """Visitor interface available to GTIR-primitive translators.""" - - @abc.abstractmethod - def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: - pass +if TYPE_CHECKING: + from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg - @abc.abstractmethod - def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: - pass - @abc.abstractmethod - def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: - pass +IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes +TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] class PrimitiveTranslator(Protocol): @@ -58,7 +47,7 @@ def __call__( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: SDFGBuilder, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, ) -> list[TemporaryData]: """Creates the dataflow subgraph representing a GTIR primitive function. @@ -70,9 +59,6 @@ def __call__( sdfg: The SDFG where the primitive subgraph should be instantiated state: The SDFG state where the result of the primitive function should be made available sdfg_builder: The object responsible for visiting child nodes of the primitive node. - reduce_identity: The value of the reduction identity, in case the primitive node - is visited in the context of a reduction expression. This value is used - by the `neighbors` primitive to provide the value of skip neighbors. Returns: A list of data access nodes and the associated GT4Py data type, which provide @@ -82,304 +68,293 @@ def __call__( """ -class AsFieldOp(PrimitiveTranslator): - """Generates the dataflow subgraph for the `as_field_op` builtin function.""" - - @classmethod - def _parse_node_args( - cls, - args: list[gtir.Expr], - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], - ) -> list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr]: - stencil_args: list[gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr] = [] - for arg in args: - fields: list[TemporaryData] = sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) - assert len(fields) == 1 - data_node, arg_type = fields[0] - # require all argument nodes to be data access nodes (no symbols) - assert isinstance(data_node, dace.nodes.AccessNode) - - arg_definition: gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr - if isinstance(arg_type, ts.ScalarType): - arg_definition = gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) - else: - assert isinstance(arg_type, ts.FieldType) - indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = { - dim: gtir_to_tasklet.SymbolExpr( - dace_fieldview_util.get_map_variable(dim), - IteratorIndexDType, - ) - for dim, _, _ in domain - } - arg_definition = gtir_to_tasklet.IteratorExpr( - data_node, - arg_type.dims, - indices, - ) - stencil_args.append(arg_definition) - - return stencil_args - - @classmethod - def _create_temporary_field( - cls, - sdfg: dace.SDFG, - state: dace.SDFGState, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], - node_type: ts.ScalarType, - output_desc: dace.data.Data, - output_field_type: ts.DataType, - ) -> tuple[dace.nodes.AccessNode, ts.FieldType]: - domain_dims, domain_lbs, domain_ubs = zip(*domain) - field_dims = list(domain_dims) - field_shape = [ - # diff between upper and lower bound - (ub - lb) - for lb, ub in zip(domain_lbs, domain_ubs) - ] - field_offset: Optional[list[dace.symbolic.SymbolicType]] = None - if any(domain_lbs): - field_offset = [-lb for lb in domain_lbs] - - if isinstance(output_desc, dace.data.Array): - # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) - assert isinstance(output_field_type, ts.FieldType) - # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype` - node_type = output_field_type.dtype - field_dims.extend(output_field_type.dims) - field_shape.extend(output_desc.shape) - else: - assert isinstance(output_desc, dace.data.Scalar) - assert isinstance(output_field_type, ts.ScalarType) - # TODO: enable `assert output_field_type == node_type`, remove variable `dtype` - node_type = output_field_type - - # allocate local temporary storage for the result field - temp_name, _ = sdfg.add_temp_transient( - field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset - ) - field_node = state.add_access(temp_name) - field_type = ts.FieldType(field_dims, node_type) - - return field_node, field_type - - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - ) -> list[TemporaryData]: - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - # expect stencil (represented as a lambda function) as first argument - assert isinstance(stencil_expr, gtir.Lambda) - # the domain of the field operator is passed as second argument - assert isinstance(domain_expr, gtir.FunCall) - - # add local storage to compute the field operator over the given domain - # TODO: use type inference to determine the result type - node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - domain = dace_fieldview_util.get_domain(domain_expr) - - # first visit the list of arguments and build a symbol map - stencil_args = self._parse_node_args(node.args, sdfg, state, sdfg_builder, domain) - - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) - input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) - assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) - output_desc = output_expr.node.desc(sdfg) - - # retrieve the tasklet node which writes the result - last_node = state.in_edges(output_expr.node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): - # the last transient node can be deleted - last_node_connector = state.in_edges(output_expr.node)[0].src_conn - state.remove_node(output_expr.node) - else: - last_node = output_expr.node - last_node_connector = None - - # allocate local temporary storage for the result field - field_node, field_type = self._create_temporary_field( - sdfg, state, domain, node_type, output_desc, output_expr.field_type - ) - - # assume tasklet with single output - output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] - if isinstance(output_desc, dace.data.Array): - # additional local dimension for neighbors - assert set(output_desc.offset) == {0} - output_subset.extend(f"0:{size}" for size in output_desc.shape) - - # create map range corresponding to the field operator domain - map_ranges = { - dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain +def _parse_arg_expr( + node: gtir.Expr, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, + domain: list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] + ], +) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr: + fields: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) + + assert len(fields) == 1 + data_node, arg_type = fields[0] + # require all argument nodes to be data access nodes (no symbols) + assert isinstance(data_node, dace.nodes.AccessNode) + + if isinstance(arg_type, ts.ScalarType): + return gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0])) + else: + assert isinstance(arg_type, ts.FieldType) + indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = { + dim: gtir_to_tasklet.SymbolExpr( + dace_fieldview_util.get_map_variable(dim), + IteratorIndexDType, + ) + for dim, _, _ in domain } - me, mx = state.add_map("field_op", map_ranges) - - if len(input_connections) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - state.add_nedge(me, last_node, dace.Memlet()) - else: - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) - state.add_memlet_path( - last_node, - mx, - field_node, - src_conn=last_node_connector, - memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), + return gtir_to_tasklet.IteratorExpr( + data_node, + arg_type.dims, + indices, ) - return [(field_node, field_type)] - -class Cond(PrimitiveTranslator): - """Generates the dataflow subgraph for the `cond` builtin function.""" - - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - ) -> list[TemporaryData]: - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "cond") - assert len(node.args) == 0 - - fun_node = node.fun - assert len(fun_node.args) == 3 - cond_expr, true_expr, false_expr = fun_node.args - - # expect condition as first argument - cond = gtir_python_codegen.get_source(cond_expr) - - # use current head state to terminate the dataflow, and add a entry state - # to connect the true/false branch states as follows: - # - # ------------ - # === | cond | === - # || ------------ || - # \/ \/ - # ------------ ------------- - # | true | | false | - # ------------ ------------- - # || || - # || ------------ || - # ==> | head | <== - # ------------ - # - cond_state = sdfg.add_state_before(state, state.label + "_cond") - sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) - - # expect true branch as second argument - true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) - sdfg.add_edge(true_state, state, dace.InterstateEdge()) - - # and false branch as third argument - false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})"))) - sdfg.add_edge(false_state, state, dace.InterstateEdge()) - - true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) - false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state) - - output_nodes = [] - for true_br, false_br in zip(true_br_args, false_br_args, strict=True): - true_br_node, true_br_type = true_br - assert isinstance(true_br_node, dace.nodes.AccessNode) - false_br_node, _ = false_br - assert isinstance(false_br_node, dace.nodes.AccessNode) - desc = true_br_node.desc(sdfg) - assert false_br_node.desc(sdfg) == desc - data_name, _ = sdfg.add_temp_transient_like(desc) - output_nodes.append((state.add_access(data_name), true_br_type)) - - true_br_output_node = true_state.add_access(data_name) - true_state.add_nedge( - true_br_node, - true_br_output_node, - dace.Memlet.from_array(data_name, desc), +def _create_temporary_field( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] + ], + node_type: ts.ScalarType, + output_desc: dace.data.Data, + output_field_type: ts.DataType, +) -> tuple[dace.nodes.AccessNode, ts.FieldType]: + domain_dims, domain_lbs, domain_ubs = zip(*domain) + field_dims = list(domain_dims) + field_shape = [ + # diff between upper and lower bound + (ub - lb) + for lb, ub in zip(domain_lbs, domain_ubs) + ] + field_offset: Optional[list[dace.symbolic.SymbolicType]] = None + if any(domain_lbs): + field_offset = [-lb for lb in domain_lbs] + + if isinstance(output_desc, dace.data.Array): + # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) + assert isinstance(output_field_type, ts.FieldType) + # TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype` + node_type = output_field_type.dtype + field_dims.extend(output_field_type.dims) + field_shape.extend(output_desc.shape) + else: + assert isinstance(output_desc, dace.data.Scalar) + assert isinstance(output_field_type, ts.ScalarType) + # TODO: enable `assert output_field_type == node_type`, remove variable `dtype` + node_type = output_field_type + + # allocate local temporary storage for the result field + temp_name, _ = sdfg.add_temp_transient( + field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset + ) + field_node = state.add_access(temp_name) + field_type = ts.FieldType(field_dims, node_type) + + return field_node, field_type + + +def visit_AsFieldOp( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> list[TemporaryData]: + """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + + fun_node = node.fun + assert len(fun_node.args) == 2 + stencil_expr, domain_expr = fun_node.args + # expect stencil (represented as a lambda function) as first argument + assert isinstance(stencil_expr, gtir.Lambda) + # the domain of the field operator is passed as second argument + assert isinstance(domain_expr, gtir.FunCall) + + # add local storage to compute the field operator over the given domain + # TODO: use type inference to determine the result type + node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + domain = dace_fieldview_util.get_domain(domain_expr) + + # first visit the list of arguments and build a symbol map + stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + + # represent the field operator as a mapped tasklet graph, which will range over the field domain + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) + input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) + assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) + output_desc = output_expr.node.desc(sdfg) + + # retrieve the tasklet node which writes the result + last_node = state.in_edges(output_expr.node)[0].src + if isinstance(last_node, dace.nodes.Tasklet): + # the last transient node can be deleted + last_node_connector = state.in_edges(output_expr.node)[0].src_conn + state.remove_node(output_expr.node) + else: + last_node = output_expr.node + last_node_connector = None + + # allocate local temporary storage for the result field + field_node, field_type = _create_temporary_field( + sdfg, state, domain, node_type, output_desc, output_expr.field_type + ) + + # assume tasklet with single output + output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] + if isinstance(output_desc, dace.data.Array): + # additional local dimension for neighbors + assert set(output_desc.offset) == {0} + output_subset.extend(f"0:{size}" for size in output_desc.shape) + + # create map range corresponding to the field operator domain + map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} + me, mx = state.add_map("field_op", map_ranges) + + if len(input_connections) == 0: + # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets + state.add_nedge(me, last_node, dace.Memlet()) + else: + for data_node, data_subset, lambda_node, lambda_connector in input_connections: + memlet = dace.Memlet(data=data_node.data, subset=data_subset) + state.add_memlet_path( + data_node, + me, + lambda_node, + dst_conn=lambda_connector, + memlet=memlet, ) + state.add_memlet_path( + last_node, + mx, + field_node, + src_conn=last_node_connector, + memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)), + ) + + return [(field_node, field_type)] + + +def visit_Cond( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> list[TemporaryData]: + """Generates the dataflow subgraph for the `cond` builtin function.""" + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "cond") + assert len(node.args) == 0 + + fun_node = node.fun + assert len(fun_node.args) == 3 + cond_expr, true_expr, false_expr = fun_node.args + + # expect condition as first argument + cond = gtir_python_codegen.get_source(cond_expr) + + # use current head state to terminate the dataflow, and add a entry state + # to connect the true/false branch states as follows: + # + # ------------ + # === | cond | === + # || ------------ || + # \/ \/ + # ------------ ------------- + # | true | | false | + # ------------ ------------- + # || || + # || ------------ || + # ==> | head | <== + # ------------ + # + cond_state = sdfg.add_state_before(state, state.label + "_cond") + sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) + + # expect true branch as second argument + true_state = sdfg.add_state(state.label + "_true_branch") + sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})")) + sdfg.add_edge(true_state, state, dace.InterstateEdge()) + + # and false branch as third argument + false_state = sdfg.add_state(state.label + "_false_branch") + sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})"))) + sdfg.add_edge(false_state, state, dace.InterstateEdge()) + + true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) + false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state) + + output_nodes = [] + for true_br, false_br in zip(true_br_args, false_br_args, strict=True): + true_br_node, true_br_type = true_br + assert isinstance(true_br_node, dace.nodes.AccessNode) + false_br_node, _ = false_br + assert isinstance(false_br_node, dace.nodes.AccessNode) + desc = true_br_node.desc(sdfg) + assert false_br_node.desc(sdfg) == desc + data_name, _ = sdfg.add_temp_transient_like(desc) + output_nodes.append((state.add_access(data_name), true_br_type)) + + true_br_output_node = true_state.add_access(data_name) + true_state.add_nedge( + true_br_node, + true_br_output_node, + dace.Memlet.from_array(data_name, desc), + ) - false_br_output_node = false_state.add_access(data_name) - false_state.add_nedge( - false_br_node, - false_br_output_node, - dace.Memlet.from_array(data_name, desc), - ) + false_br_output_node = false_state.add_access(data_name) + false_state.add_nedge( + false_br_node, + false_br_output_node, + dace.Memlet.from_array(data_name, desc), + ) - return output_nodes + return output_nodes -class SymbolRef(PrimitiveTranslator): +def visit_SymbolRef( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> list[TemporaryData]: """Generates the dataflow subgraph for a `ir.SymRef` node.""" + assert isinstance(node, (gtir.Literal, gtir.SymRef)) + + data_type: ts.FieldType | ts.ScalarType + if isinstance(node, gtir.Literal): + sym_value = node.value + data_type = node.type + tasklet_name = "get_literal" + else: + sym_value = str(node.id) + assert sym_value in sdfg_builder.get_symbol_types() + data_type = sdfg_builder.get_symbol_types()[sym_value] + tasklet_name = f"get_{sym_value}" + + if isinstance(data_type, ts.FieldType): + # add access node to current state + sym_node = state.add_access(sym_value) + + else: + # scalar symbols are passed to the SDFG as symbols: build tasklet node + # to write the symbol to a scalar access node + tasklet_node = state.add_tasklet( + tasklet_name, + {}, + {"__out"}, + f"__out = {sym_value}", + ) + temp_name, _ = sdfg.add_temp_transient((1,), dace_fieldview_util.as_dace_type(data_type)) + sym_node = state.add_access(temp_name) + state.add_edge( + tasklet_node, + "__out", + sym_node, + None, + dace.Memlet(data=sym_node.data, subset="0"), + ) + + return [(sym_node, data_type)] - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: SDFGBuilder, - ) -> list[TemporaryData]: - assert isinstance(node, (gtir.Literal, gtir.SymRef)) - - data_type: ts.FieldType | ts.ScalarType - if isinstance(node, gtir.Literal): - sym_value = node.value - data_type = node.type - tasklet_name = "get_literal" - else: - sym_value = str(node.id) - assert sym_value in sdfg_builder.get_symbol_types() - data_type = sdfg_builder.get_symbol_types()[sym_value] - tasklet_name = f"get_{sym_value}" - - if isinstance(data_type, ts.FieldType): - # add access node to current state - sym_node = state.add_access(sym_value) - - else: - # scalar symbols are passed to the SDFG as symbols: build tasklet node - # to write the symbol to a scalar access node - tasklet_node = state.add_tasklet( - tasklet_name, - {}, - {"__out"}, - f"__out = {sym_value}", - ) - temp_name, _ = sdfg.add_temp_transient( - (1,), dace_fieldview_util.as_dace_type(data_type) - ) - sym_node = state.add_access(temp_name) - state.add_edge( - tasklet_node, - "__out", - sym_node, - None, - dace.Memlet(data=sym_node.data, subset="0"), - ) - return [(sym_node, data_type)] +if TYPE_CHECKING: + # Use type-checking to assert that all visitor functions implement the `PrimitiveTranslator` protocol + __primitive_translators: list[PrimitiveTranslator] = [ + visit_AsFieldOp, + visit_Cond, + visit_SymbolRef, + ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index fe98d8a98a..d70169bcd1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + from typing import Any import numpy as np diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index f94bbc6953..861cc0a475 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -12,16 +12,20 @@ # # SPDX-License-Identifier: GPL-3.0-or-later """ -Class to lower GTIR to DaCe SDFG. +Contains visitors to lower GTIR to DaCe SDFG. Note: this module covers the fieldview flavour of GTIR. """ -from typing import Any, Sequence +from __future__ import annotations + +import abc +from typing import Any, Protocol, Sequence import dace from gt4py import eve +from gt4py.eve import concepts from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -32,7 +36,23 @@ from gt4py.next.type_system import type_specifications as ts -class GTIRToSDFG(eve.NodeVisitor, gtir_builtin_translators.SDFGBuilder): +class SDFGBuilder(Protocol): + """Visitor interface available to GTIR-primitive translators.""" + + @abc.abstractmethod + def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: + pass + + @abc.abstractmethod + def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: + pass + + @abc.abstractmethod + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + pass + + +class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. This class is responsible for translation of `ir.Program`, that is the top level representation @@ -235,9 +255,9 @@ def visit_FunCall( ) -> list[gtir_builtin_translators.TemporaryData]: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtin_translators.AsFieldOp()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_AsFieldOp(node, sdfg, head_state, self) elif cpm.is_call_to(node.fun, "cond"): - return gtir_builtin_translators.Cond()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_Cond(node, sdfg, head_state, self) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -255,7 +275,7 @@ def visit_Literal( sdfg: dace.SDFG, head_state: dace.SDFGState, ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_SymbolRef(node, sdfg, head_state, self) def visit_SymRef( self, @@ -263,4 +283,4 @@ def visit_SymRef( sdfg: dace.SDFG, head_state: dace.SDFGState, ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.SymbolRef()(node, sdfg, head_state, self) + return gtir_builtin_translators.visit_SymbolRef(node, sdfg, head_state, self) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 25a80892d0..44c3722f58 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -13,7 +13,9 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +from __future__ import annotations + +import dataclasses from typing import Optional, TypeAlias import dace @@ -30,7 +32,7 @@ from gt4py.next.type_system import type_specifications as ts -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class MemletExpr: """Scalar or array data access thorugh a memlet.""" @@ -38,7 +40,7 @@ class MemletExpr: subset: sbs.Indices | sbs.Range -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class SymbolExpr: """Any symbolic expression that is constant in the context of current SDFG.""" @@ -46,7 +48,7 @@ class SymbolExpr: dtype: dace.typeclass -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ValueExpr: """Result of the computation implemented by a tasklet node.""" @@ -65,7 +67,7 @@ class ValueExpr: IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class IteratorExpr: """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 7e9bec2545..34e95506d8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + from typing import Any, Mapping, Optional import dace @@ -85,6 +87,9 @@ def get_domain( Specialized visit method for domain expressions. Returns for each domain dimension the corresponding range. + + TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea + would be to cache the results of lowering here (e.g. using `functools.lru_cache`) """ assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) @@ -94,11 +99,10 @@ def get_domain( assert len(named_range.args) == 3 axis = named_range.args[0] assert isinstance(axis, gtir.AxisLiteral) - bounds = [] - for arg in named_range.args[1:3]: - sym_str = gtir_python_codegen.get_source(arg) - sym_val = dace.symbolic.SymExpr(sym_str) - bounds.append(sym_val) + bounds = [ + dace.symbolic.SymExpr(gtir_python_codegen.get_source(arg)) + for arg in named_range.args[1:3] + ] dim = gtx_common.Dimension(axis.value, axis.kind) domain.append((dim, bounds[0], bounds[1])) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py similarity index 88% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 3d99225d81..eabfa8f713 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -11,6 +11,13 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +""" +Contains definitions of the workflow steps for GTIR programs with dace as backend for optimization and code generation. + +Note: this module covers the fieldview flavour of GTIR. +""" + +from __future__ import annotations import dace From 4a2ccaaf87fe1f9ad16e47392ec3d6711bb44b85 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 13:12:45 +0200 Subject: [PATCH 136/235] More debugger friendly. --- .../transformations/auto_opt.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 444e5506b0..d6512671b5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -20,6 +20,8 @@ from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize +from .map_seriall_fusion import SerialMapFusion + def dace_auto_optimize( sdfg: dace.SDFG, @@ -64,18 +66,25 @@ def gt_auto_optimize( device: The device for which we should optimize. """ - # Initial cleaning - sdfg.simplify() + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) - # Due to the structure of the generated SDFG getting rid of Maps, - # i.e. fusing them, is the best we can currently do. - sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) + # Initial cleaning + sdfg.simplify() - # These are the part that we copy from DaCe built in auto optimization. - dace_aoptimize.set_fast_implementations(sdfg, device) - dace_aoptimize.make_transients_persistent(sdfg, device) - dace_aoptimize.move_small_arrays_to_stack(sdfg) + # Due to the structure of the generated SDFG getting rid of Maps, + # i.e. fusing them, is the best we can currently do. + if kwargs.get("use_dace_fusion", False): + sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) + else: + xform = SerialMapFusion() + sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) - sdfg.simplify() + # These are the part that we copy from DaCe built in auto optimization. + dace_aoptimize.set_fast_implementations(sdfg, device) + dace_aoptimize.make_transients_persistent(sdfg, device) + dace_aoptimize.move_small_arrays_to_stack(sdfg) - return sdfg + sdfg.simplify() + + return sdfg From d353d0e310979e70b8dfe8679e01008bd6a55aa6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 13:13:06 +0200 Subject: [PATCH 137/235] It should now work and I figured out why it was not working before. However, this would mean that the IDs are not stable. --- .../dace_fieldview/transformations/map_fusion_helper.py | 2 +- .../dace_fieldview/transformations/map_seriall_fusion.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 32ef6a6be5..00b7d0ab7a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -208,8 +208,8 @@ def relocate_nodes( state.out_degree(from_node) == 0 ), f"After moving source node '{from_node}' still has an output degree of {state.out_degree(from_node)}" + @staticmethod def map_parameter_compatible( - self, map_1: nodes.Map, map_2: nodes.Map, state: Union[SDFGState, SDFG], diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py index ac83441e90..491fdddb48 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -130,8 +130,9 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non assert isinstance(self.map_exit1, nodes.MapExit) assert isinstance(self.map_entry2, nodes.MapEntry) - # From here on forward we can no longer use `self.map_*`!! - # For some reason they are not stable and change. + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here. map_exit_1: nodes.MapExit = self.map_exit1 map_entry_2: nodes.MapEntry = self.map_entry2 map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) @@ -193,8 +194,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # Now turn the second output node into the output node of the first Map. map_exit_2.map = map_entry_1.map + @staticmethod def handle_intermediate_set( - self, intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], state: SDFGState, sdfg: SDFG, @@ -272,7 +273,7 @@ def handle_intermediate_set( # It will only have the shape `new_inter_shape` which is basically its # output within one Map iteration. # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" # Now generate the intermediate data container. if len(new_inter_shape) == 0: From 073065d0721e1e3a3c840894a6192eb4bcafa888 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 14:25:59 +0200 Subject: [PATCH 138/235] A fix. --- .../dace_fieldview/transformations/map_fusion_helper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 00b7d0ab7a..c5d09d73b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -194,12 +194,14 @@ def relocate_nodes( old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix new_conn = to_node.next_connector(old_conn) + to_node.add_in_connector("IN_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + from_node.remove_out_connector("OUT_" + old_conn) + to_node.add_out_connector("OUT_" + new_conn) for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) assert ( state.in_degree(from_node) == 0 From 19be2c46091b07212d4799c833fdb1b1e2e2e09f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 15:02:25 +0200 Subject: [PATCH 139/235] Added some "test" for the merger. --- my_playground/map_fusion_test.py | 327 +++++++++++++++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 my_playground/map_fusion_test.py diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py new file mode 100644 index 0000000000..618802e200 --- /dev/null +++ b/my_playground/map_fusion_test.py @@ -0,0 +1,327 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Simple tests top verify the map fusion tests. +""" + +import dace +import copy +from gt4py.next.common import NeighborTable +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.program_processors.runners import dace_fieldview as dace_backend +from gt4py.next.type_system import type_specifications as ts +from functools import reduce +import numpy as np + +from typing import Sequence, Any + +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations, # noqa: F401 [unused-import] # For development. +) + +from simple_icon_mesh import ( + IDim, # Dimensions + JDim, + KDim, + EdgeDim, + VertexDim, + CellDim, + ECVDim, + E2C2VDim, + NbCells, # Constants of the size + NbEdges, + NbVertices, + E2C2VDim, # Offsets + E2C2V, + SIZE_TYPE, # Type definitions + E2C2V_connectivity, + E2ECV_connectivity, + make_syms, # Helpers +) + +# For cartesian stuff. +N = 10 +IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +IJFTYPE = ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + + +def _perform_test( + sdfg: dace.SDFG, ref: Any, return_names: Sequence[str] | str, args: dict[str, Any] +) -> dace.SDFG: + unopt_sdfg = copy.deepcopy(sdfg) + + if not isinstance(ref, list): + ref = [ref] + if isinstance(return_names, str): + return_names = [return_names] + + SYMBS = make_syms(**args) + + # Call the unoptimized version of the SDFG + unopt_sdfg(**args, **SYMBS) + unopt_res = [args[name] for name in return_names] + + assert np.allclose(ref, unopt_res), "The unoptimized verification failed." + + # Reset the results + for name in return_names: + args[name][:] = 0 + assert not np.allclose(ref, unopt_res) + + # Now perform the optimization + opt_sdfg = copy.deepcopy(sdfg) + transformations.gt_auto_optimize(opt_sdfg) + opt_sdfg.validate() + opt_sdfg(**args, **SYMBS) + opt_res = [args[name] for name in return_names] + + assert np.allclose(ref, opt_res), "The optimized verification failed." + + return opt_sdfg + + +def _count_nodes( + sdfg: dace.SDFG, + state: dace.SDFGState | None = None, + node_type: Sequence[type] | type = dace_nodes.MapEntry, +) -> int: + states = sdfg.states() if state is None else [state] + found_matches = 0 + for state_nodes in states: + for node in state_nodes.nodes(): + if isinstance(node, node_type): + found_matches += 1 + return found_matches + + +###################### +# TESTS + + +def exclusive_only(): + """Tests the sxclusive set merging mechanism only.""" + + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1.0)), + domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 2.0)), + domain, + ) + )("x"), + ) + + a = np.random.rand(N) + + testee = itir.Program( + id=f"sum_3fields_1", + function_definitions=[], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil1, + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg.validate() + + assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 3 + assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 + + a = np.random.rand(N) + res1 = np.empty_like(a) + + args = { + "x": a, + "z": res1, + "size": N, + } + return_names = ["z"] + + opt_sdfg = _perform_test( + sdfg=sdfg, + ref=a + 3.0, + return_names="z", + args=args, + ) + + assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 3 + assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 + + +def exclusive_only_2(): + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") + ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + "y", + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 2.0)), + domain, + ) + )("x"), + ) + + a = np.random.rand(N) + b = np.random.rand(N) + + testee = itir.Program( + id=f"sum_3fields_1", + function_definitions=[], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="y", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil1, + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg.validate() + + assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 4 + assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 + + a = np.random.rand(N) + res1 = np.empty_like(a) + + args = { + "x": a, + "y": b, + "z": res1, + "size": N, + } + return_names = ["z"] + + opt_sdfg = _perform_test( + sdfg=sdfg, + ref=(a + b + 2.0), + return_names="z", + args=args, + ) + + assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 4 + assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 + + +def intermediate_branch(): + sdfg = dace.SDFG("intermediate") + state = sdfg.add_state("state") + + ac: list[nodes.AccessNode] = [] + for i in range(3): + name = "input" if i == 0 else f"output{i-1}" + sdfg.add_array( + name, + shape=(N,), + dtype=dace.float64, + transient=False, # All are global. + ) + ac.append(state.add_access(name)) + sdfg.add_array( + name="tmp", + shape=(N,), + dtype=dace.float64, + transient=True, + ) + ac.append(state.add_access("tmp")) + + state.add_mapped_tasklet( + "first_add", + map_ranges=[("i", f"0:{N}")], + code="__out = __in0 + 1.0", + inputs=dict(__in0=dace.Memlet("input[i]")), + outputs=dict(__out=dace.Memlet("tmp[i]")), + input_nodes=dict(input=ac[0]), + output_nodes=dict(tmp=ac[-1]), + external_edges=True, + ) + + for i in range(2): + state.add_mapped_tasklet( + f"level_{i}_add", + map_ranges=[("i", f"0:{N}")], + code=f"__out = __in0 + {i+3}", + inputs=dict(__in0=dace.Memlet("tmp[i]")), + outputs=dict(__out=dace.Memlet(f"output{i}[i]")), + input_nodes=dict(tmp=ac[-1]), + output_nodes={f"output{i}": ac[1 + i]}, + external_edges=True, + ) + + assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 4 + assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 3 + + a = np.random.rand(N) + ref0 = a + 1 + 3 + ref1 = a + 1 + 4 + + res0 = np.empty_like(a) + res1 = np.empty_like(a) + + args = { + "input": a, + "output0": res0, + "output1": res1, + } + return_names = ["output0", "output1"] + + opt_sdfg = _perform_test( + sdfg=sdfg, + ref=[ref0, ref1], + return_names=return_names, + args=args, + ) + assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 4 + assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 + + +if "__main__" == __name__: + # exclusive_only() + # exclusive_only_2() + intermediate_branch() + print("SUCCESS") From 5237b1399f432bde28e7242b75f0c88c71ab6ac6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 11 Jul 2024 15:10:08 +0200 Subject: [PATCH 140/235] Now the nabla4 optimizes with my fusion operation. --- my_playground/nabla4.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py index a497bda6a4..f3252fe85b 100644 --- a/my_playground/nabla4.py +++ b/my_playground/nabla4.py @@ -25,6 +25,8 @@ from gt4py.next.program_processors.runners import dace_fieldview as dace_backend from gt4py.next.type_system import type_specifications as ts +from gt4py.next.program_processors.runners.dace_fieldview import transformations as fw_trans + from simple_icon_mesh import ( IDim, # Dimensions JDim, @@ -469,11 +471,23 @@ def verify_nabla4( SYMBS = make_syms(**call_args) - sdfg(**call_args, **SYMBS) - ref = nabla4_np(**call_args, **offset_provider) + org_sdfg = copy.deepcopy(sdfg) + + for i in range(2): + sdfg = copy.deepcopy(org_sdfg) + if i != 0: + fw_trans.gt_auto_optimize(sdfg) + + sdfg.view() - assert np.allclose(ref, nab4) - print(f"Version({version}): Succeeded") + sdfg(**call_args, **SYMBS) + ref = nabla4_np(**call_args, **offset_provider) + assert np.allclose(ref, nab4) + nab4[:] = 0 + if i == 0: + print(f"Version({version} | unoptimized): Succeeded") + else: + print(f"Version({version} | optimized): Succeeded") if "__main__" == __name__: From 489bb4a0063d454c71302dd87c9541ff3a69474a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 08:16:44 +0200 Subject: [PATCH 141/235] Added some more test. --- my_playground/map_fusion_test.py | 154 ++++++++++++++++++++++++++++++- 1 file changed, 152 insertions(+), 2 deletions(-) diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py index 618802e200..737d56efa5 100644 --- a/my_playground/map_fusion_test.py +++ b/my_playground/map_fusion_test.py @@ -320,8 +320,158 @@ def intermediate_branch(): assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 +def shifting(): + """Tests what happens if we have a sift.""" + + IOffset = 3 + + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size"), + im.call("named_range")(itir.AxisLiteral(value=JDim.value), 0, "size"), + ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + domain, + ) + )( + "y", + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.deref(im.shift("IDim", IOffset)("a"))), + domain, + ) + )("x"), + ) + + testee = itir.Program( + id=f"shift_test", + function_definitions=[], + params=[ + itir.Sym(id="x", type=IJFTYPE), + itir.Sym(id="y", type=IJFTYPE), + itir.Sym(id="z", type=IJFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil1, + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + offset_provider = { + "IDim": IDim, + "JDim": JDim, + } + sdfg = dace_backend.build_sdfg_from_gtir(testee, offset_provider) + sdfg.validate() + + assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 4 + assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 + + x = np.random.rand(N + 2 * IOffset, N) + y = np.random.rand(N, N) + z = np.empty_like(y) + ref = x[IOffset : (IOffset + N), :] + y + + args = { + "x": x, + "y": y, + "z": z, + "size": N, + } + return_names = ["z"] + + opt_sdfg = _perform_test( + sdfg=sdfg, + ref=ref, + return_names="z", + args=args, + ) + + assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 4 + assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 + + +def non_zero_start(): + """Tests what happens if there are two maps, that does not start at zero.""" + + # NOTE: + # Currently the translator has an error, that leads to an invalid SDFG. + # However, I am now, also unsure what the result should be in the first place. + + dom_start = 3 + domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value=IDim.value), dom_start, "size") + ) + stencil1 = im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 1.0)), + domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("a")(im.plus(im.deref("a"), 2.0)), + domain, + ) + )("x"), + ) + + testee = itir.Program( + id=f"non_zero_start_test", + function_definitions=[], + params=[ + itir.Sym(id="x", type=IFTYPE), + itir.Sym(id="z", type=IFTYPE), + itir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + itir.SetAt( + expr=stencil1, + domain=domain, + target=itir.SymRef(id="z"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg.validate() + + assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 3 + assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 + + x = np.random.rand(N) + z = np.empty_like(x) + + args = { + "x": x, + "z": z, + "size": N, + } + ref = x[3:N] + 3.0 + return_names = ["z"] + + opt_sdfg = _perform_test( + sdfg=sdfg, + ref=ref, + return_names="z", + args=args, + ) + + assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 3 + assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 + + if "__main__" == __name__: - # exclusive_only() - # exclusive_only_2() + exclusive_only() + exclusive_only_2() intermediate_branch() + shifting() + # non_zero_start() print("SUCCESS") From 4d1a3cc24bfb5d5fd281b9bd8682aeeb6ec7971e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 11:01:17 +0200 Subject: [PATCH 142/235] Fixed some small problem in detecting recursive dataflow. --- .../transformations/map_fusion_helper.py | 75 +++++++++++-------- .../transformations/map_seriall_fusion.py | 54 ++++++++----- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index c5d09d73b0..6ef4e473d2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -28,14 +28,14 @@ @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): - """ - Contains common part of the map fusion for parallel and serial map fusion. - - See also [this HackMD document](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - about the underlying assumption this transformation makes. + """Contains common part of the map fusion for parallel and serial map fusion. - After every transformation that manipulates the state machine, you shouls recreate - the transformation. + The transformation assumes that the SDFG obeys the principals outlined in [this + HackMD document](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + The main advantage of this structure is, that it is rather easy to determine + if a transient can be used. This check, performed by `is_interstate_transient()`, + is speed up by cashing some computation, thus such an object should not be used + after interstate optimizations were applied to the SDFG. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. @@ -90,8 +90,8 @@ def can_be_fused( ) -> bool: """Performs some checks if the maps can be fused. - Essentially, this function only checks constrains that does not depend if - a serial or a parallel map fusion happens. Thus it tests: + Essentially, this function only checks constrains that are common between + the serial and parallel map fusion process. It tests: - The scope of the maps. - The scheduling of the maps. - The map parameters. @@ -99,7 +99,6 @@ def can_be_fused( However, for performance reasons, the function does not check if the node decomposition exists. """ - if self.only_inner_maps and self.only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") @@ -203,12 +202,10 @@ def relocate_nodes( for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) - assert ( - state.in_degree(from_node) == 0 - ), f"After moving source node '{from_node}' still has an input degree of {state.in_degree(from_node)}" - assert ( - state.out_degree(from_node) == 0 - ), f"After moving source node '{from_node}' still has an output degree of {state.out_degree(from_node)}" + assert state.in_degree(from_node) == 0 + assert len(from_node.in_connectors) == 0 + assert state.out_degree(from_node) == 0 + assert len(from_node.out_connectors) == 0 @staticmethod def map_parameter_compatible( @@ -252,20 +249,25 @@ def is_interstate_transient( self, transient: Union[str, nodes.AccessNode], sdfg: dace.SDFG, + state: dace.SDFGState | None = None, ) -> bool: """Tests if `transient` is an interstate transient, an can not be removed. - Essentially this function checks if a transient is needed in a + Essentially this function checks if a transient might be needed in a different state in the SDFG, because it transmit information from - one state to the other. However, this function only checks if the - transient is needed for transmitting information between states. - It does _not_ check if the transient is needed multiple times within - the state. This case can be detected by checking the number of outgoing - edges. + one state to the other. If only the name of the transient is passed, + then the function will only check if it is used in another state. + If the access node and the state are passed the function will also + check if it is used inside the state. Args: transient: The transient that should be checked. sdfg: The SDFG containing the array. + state: If given the state the node is located in. + + Note: + This function build upon the structure of the SDFG that is outlined + in the HackMD document. """ # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) @@ -283,8 +285,12 @@ def is_interstate_transient( # If a scalar is not a source node then it is not included in this set. # Thus we do not have to look for it, instead we will check for them # explicitly. + def decent(node: nodes.Node, graph: Any) -> bool: + return not isinstance(node, nodes.NestedSDFG) + shared_sdfg_transients = set() - for state in sdfg.states(): + # TODO(phimuell): use `sdfg.all_nodes_recursive(decent)` if it is available. + for state in sdfg.all_states(): shared_sdfg_transients.update( filter( lambda node: isinstance(node, nodes.AccessNode) @@ -295,13 +301,17 @@ def is_interstate_transient( self.shared_transients[sdfg] = shared_sdfg_transients if isinstance(transient, nodes.AccessNode): + if state is not None: + # Rule 8: Used within the state. + if state.out_degree(transient) > 1: + return True transient = transient.data - desc: data.Data = sdfg.arrays[transient] + desc: data.Data = sdfg.arrays[transient] if not desc.transient: - return False + return True if isinstance(desc, data.Scalar): - return False + return True return transient in shared_sdfg_transients def partition_first_outputs( @@ -408,13 +418,13 @@ def partition_first_outputs( if len(inner_collector_edges) > 1: return None - # For us an intermediate node must always be an access node, pointing to a - # transient value, since it is the only thing that we know how to handle. + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. if not isinstance(intermediate_node, nodes.AccessNode): return None intermediate_desc: data.Data = intermediate_node.desc(sdfg) - if not intermediate_desc.transient: - return None if isinstance(intermediate_desc, data.View): return None @@ -464,9 +474,8 @@ def partition_first_outputs( # Note that "remove" has a special meaning here, regardless of the # output of the check function, from within the second map we remove # the intermediate, it has more the meaning of "do we need to - # reconstruct it after the second map again?". - # NOTE: The case "used in this state" is handled above!! - if self.is_interstate_transient(intermediate_node, sdfg): + # reconstruct it after the second map again?" + if self.is_interstate_transient(intermediate_node, sdfg, state): shared_outputs.add(out_edge) else: exclusive_outputs.add(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py index 491fdddb48..d23e1e758c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -51,7 +51,9 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): only_toplevel_maps: Only consider Maps that are at the top. Notes: - This transformation modifies more nodes than it matches! + - This transformation modifies more nodes than it matches! + - The consolidate edge transformation (part of simplify) is probably + harmful to the applicability of this transformation. """ map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) @@ -96,6 +98,7 @@ def can_be_applied( map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) map_entry_2: nodes.MapEntry = self.map_entry2 + # This essentially test the structural properties of the two Maps. if not self.can_be_fused( map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg ): @@ -113,8 +116,6 @@ def can_be_applied( return False if not (output_partition[1] or output_partition[2]): return False - assert isinstance(self.map_exit1, nodes.MapExit) - assert isinstance(self.map_entry2, nodes.MapEntry) return True def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: @@ -125,14 +126,17 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non the two intermediate sets are handled by `handle_intermediate_set()`. By assumption we do not have to rename anything. - """ - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, nodes.MapExit) - assert isinstance(self.map_entry2, nodes.MapEntry) + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ # NOTE: `self.map_*` actually stores the ID of the node. # once we start adding and removing nodes it seems that their ID changes. # Thus we have to save them here. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit1, nodes.MapExit) + assert isinstance(self.map_entry2, nodes.MapEntry) map_exit_1: nodes.MapExit = self.map_exit1 map_entry_2: nodes.MapEntry = self.map_entry2 map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) @@ -204,23 +208,33 @@ def handle_intermediate_set( map_exit_2: nodes.MapExit, is_exclusive_set: bool, ) -> None: - """Handle the intermediate output sets. + """This function handles the intermediate sets. The function is able to handle both the shared and exclusive intermediate - output set that was computed by `partition_first_outputs()`. Which one is - handled is indicated by `is_exclusive_set`. - - The main difference is that in exclusive mode is that the intermediate node - will be fully removed from the SDFG. However, in shared mode, the intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode is that the intermediate node will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + However, the function just performs some rewiring of the outputs and + manipulation of the intermediate node set. Args: - intermediateOutputSet: The set of edges that are intermediate outputs of the first Map. - graph: The graph we operate on. - sdfg: The SDFG we operate on. - mapExit1: The node that serves as exit node of the first Map. - mapEntry2: The node that serves as entry node of the second Map. - mapExit2: The node that serves as exit node of the second Map. - isExclusiveSet: If `True` process the exclusive set. + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_1: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not be to be valid and + after this function has run the state is invalid. + The function is static and the map nodes have to be explicitly passed + because the `self.map_*` properties are modified by the modification. + This is a known behaviour in DaCe. + + Todo: + Rewrite using `MemletTree`. """ # Essentially this function removes the AccessNode between the two maps. From 2da74533b1954f2d53fc79a5370507e1fa55a37e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 11:06:47 +0200 Subject: [PATCH 143/235] Made some imporvements to the test. --- my_playground/map_fusion_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py index 737d56efa5..20f8b1d5b6 100644 --- a/my_playground/map_fusion_test.py +++ b/my_playground/map_fusion_test.py @@ -447,14 +447,15 @@ def non_zero_start(): assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 x = np.random.rand(N) - z = np.empty_like(x) + z = np.zeros_like(x) args = { "x": x, "z": z, "size": N, } - ref = x[3:N] + 3.0 + ref = np.zeros_like(x) + ref[dom_start:N] = x[dom_start:N] + 3.0 return_names = ["z"] opt_sdfg = _perform_test( From ba97fd2b5f9bef7e5d16f3416b2ee1ebce03d0f1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 11:15:21 +0200 Subject: [PATCH 144/235] Made some comments better. --- my_playground/map_fusion_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py index 20f8b1d5b6..69bef188c5 100644 --- a/my_playground/map_fusion_test.py +++ b/my_playground/map_fusion_test.py @@ -323,6 +323,8 @@ def intermediate_branch(): def shifting(): """Tests what happens if we have a sift.""" + # Currently the transformer fails to parse the IR. + IOffset = 3 domain = im.call("cartesian_domain")( @@ -400,10 +402,6 @@ def shifting(): def non_zero_start(): """Tests what happens if there are two maps, that does not start at zero.""" - # NOTE: - # Currently the translator has an error, that leads to an invalid SDFG. - # However, I am now, also unsure what the result should be in the first place. - dom_start = 3 domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value=IDim.value), dom_start, "size") @@ -473,6 +471,6 @@ def non_zero_start(): exclusive_only() exclusive_only_2() intermediate_branch() - shifting() - # non_zero_start() + # shifting() + non_zero_start() print("SUCCESS") From 9301dbe794517bca9a71cc611396d78a9dbc8518 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 12 Jul 2024 11:20:08 +0200 Subject: [PATCH 145/235] Import changes from branch dace-fieldview-neighbors --- .../gtir_builtin_translators.py | 10 +-- .../runners/dace_fieldview/gtir_to_sdfg.py | 89 +++++++++++++++---- .../runners/dace_fieldview/gtir_to_tasklet.py | 54 +++++++---- 3 files changed, 113 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 266cbbff1a..3ace4a12a5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -173,7 +173,7 @@ def visit_AsFieldOp( stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder) input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) output_desc = output_expr.node.desc(sdfg) @@ -202,7 +202,7 @@ def visit_AsFieldOp( # create map range corresponding to the field operator domain map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} - me, mx = state.add_map("field_op", map_ranges) + me, mx = sdfg_builder.add_map("field_op", state, map_ranges) if len(input_connections) == 0: # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets @@ -321,8 +321,7 @@ def visit_SymbolRef( tasklet_name = "get_literal" else: sym_value = str(node.id) - assert sym_value in sdfg_builder.get_symbol_types() - data_type = sdfg_builder.get_symbol_types()[sym_value] + data_type = sdfg_builder.get_symbol_type(sym_value) tasklet_name = f"get_{sym_value}" if isinstance(data_type, ts.FieldType): @@ -332,8 +331,9 @@ def visit_SymbolRef( else: # scalar symbols are passed to the SDFG as symbols: build tasklet node # to write the symbol to a scalar access node - tasklet_node = state.add_tasklet( + tasklet_node = sdfg_builder.add_tasklet( tasklet_name, + state, {}, {"__out"}, f"__out = {sym_value}", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 7c1c904cb1..77dd56b232 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -20,12 +20,14 @@ from __future__ import annotations import abc -from typing import Any, Protocol, Sequence +import dataclasses +from typing import Any, Dict, List, Protocol, Sequence, Set, Tuple, Union import dace from gt4py import eve from gt4py.eve import concepts +from gt4py.eve.utils import UIDGenerator from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -36,15 +38,54 @@ from gt4py.next.type_system import type_specifications as ts, type_translation as tt -class SDFGBuilder(Protocol): - """Visitor interface available to GTIR-primitive translators.""" +class DataflowBuilder(Protocol): + """Visitor interface to build a dataflow subgraph.""" + + @abc.abstractmethod + def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: + pass @abc.abstractmethod - def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: + def unique_map_name(self, name: str) -> str: pass @abc.abstractmethod - def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: + def unique_tasklet_name(self, name: str) -> str: + pass + + def add_map( + self, + name: str, + state: dace.SDFGState, + ndrange: Union[ + Dict[str, Union[str, dace.subsets.Subset]], + List[Tuple[str, Union[str, dace.subsets.Subset]]], + ], + **kwargs: Any, + ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + """Wrapper of `dace.SDFGState.add_map` that assigns unique name.""" + unique_name = self.unique_map_name(name) + return state.add_map(unique_name, ndrange, **kwargs) + + def add_tasklet( + self, + name: str, + state: dace.SDFGState, + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + **kwargs: Any, + ) -> dace.nodes.Tasklet: + """Wrapper of `dace.SDFGState.add_tasklet` that assigns unique name.""" + unique_name = self.unique_tasklet_name(name) + return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs) + + +class SDFGBuilder(DataflowBuilder, Protocol): + """Visitor interface available to GTIR-primitive translators.""" + + @abc.abstractmethod + def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: pass @abc.abstractmethod @@ -52,6 +93,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: pass +@dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -66,20 +108,29 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] - symbol_types: dict[str, ts.FieldType | ts.ScalarType] - - def __init__( - self, - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], - ): - self.offset_provider = offset_provider - self.symbol_types = {} - - def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: - return self.offset_provider - - def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: - return self.symbol_types + symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field( + default_factory=lambda: {} + ) + map_uids: UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="map") + ) + tesklet_uids: UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="tlet") + ) + + def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: + assert offset in self.offset_provider + return self.offset_provider[offset] + + def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: + assert symbol_name in self.symbol_types + return self.symbol_types[symbol_name] + + def unique_map_name(self, name: str) -> str: + return f"{self.map_uids.sequential_id()}_{name}" + + def unique_tasklet_name(self, name: str) -> str: + return f"{self.tesklet_uids.sequential_id()}_{name}" def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 88270eb697..b59fd80881 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -17,7 +17,7 @@ import dataclasses import itertools -from typing import Optional, TypeAlias +from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union import dace import dace.subsets as sbs @@ -28,6 +28,7 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, + gtir_to_sdfg, utility as dace_fieldview_util, ) from gt4py.next.type_system import type_specifications as ts @@ -87,7 +88,7 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] + subgraph_builder: gtir_to_sdfg.DataflowBuilder input_connections: list[InputConnection] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -95,11 +96,11 @@ def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], + subgraph_builder: gtir_to_sdfg.DataflowBuilder, ): self.sdfg = sdfg self.state = state - self.offset_provider = offset_provider + self.subgraph_builder = subgraph_builder self.input_connections = [] self.symbol_map = {} @@ -112,6 +113,29 @@ def _add_entry_memlet_path( ) -> None: self.input_connections.append((src, src_subset, dst_node, dst_conn)) + def _add_map( + self, + name: str, + ndrange: Union[ + Dict[str, Union[str, dace.subsets.Subset]], + List[Tuple[str, Union[str, dace.subsets.Subset]]], + ], + **kwargs: Any, + ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + """Helper method to add a map with unique ame in current state.""" + return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs) + + def _add_tasklet( + self, + name: str, + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + **kwargs: Any, + ) -> dace.nodes.Tasklet: + """Helper method to add a tasklet with unique ame in current state.""" + return self.subgraph_builder.add_tasklet(name, self.state, inputs, outputs, code, **kwargs) + def _get_tasklet_result( self, dtype: dace.typeclass, @@ -161,7 +185,7 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr: else IndexConnectorFmt.format(dim=dim.value) for dim, index in field_indices ) - deref_node = self.state.add_tasklet( + deref_node = self._add_tasklet( "deref_field_indirection", {"field"} | set(index_connectors), {"val"}, @@ -239,21 +263,21 @@ def _make_cartesian_shift( # the offset needs to be calculate by means of a tasklet new_index_connector = "shifted_index" if isinstance(index_expr, SymbolExpr): - dynamic_offset_tasklet = self.state.add_tasklet( + dynamic_offset_tasklet = self._add_tasklet( "dynamic_offset", {"offset"}, {new_index_connector}, f"{new_index_connector} = {index_expr.value} + offset", ) elif isinstance(offset_expr, SymbolExpr): - dynamic_offset_tasklet = self.state.add_tasklet( + dynamic_offset_tasklet = self._add_tasklet( "dynamic_offset", {"index"}, {new_index_connector}, f"{new_index_connector} = index + {offset_expr}", ) else: - dynamic_offset_tasklet = self.state.add_tasklet( + dynamic_offset_tasklet = self._add_tasklet( "dynamic_offset", {"index", "offset"}, {new_index_connector}, @@ -306,7 +330,7 @@ def _make_dynamic_neighbor_offset( or computed byanother tasklet (`ValueExpr`). """ new_index_connector = "neighbor_index" - tasklet_node = self.state.add_tasklet( + tasklet_node = self._add_tasklet( "dynamic_neighbor_offset", {"table", "offset"}, {new_index_connector}, @@ -390,7 +414,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: assert isinstance(head[0], gtir.OffsetLiteral) offset = head[0].value assert isinstance(offset, str) - offset_provider = self.offset_provider[offset] + offset_provider = self.subgraph_builder.get_offset_provider(offset) # second argument should be the offset value, which could be a symbolic expression or a dynamic offset offset_expr: IteratorIndexExpr if isinstance(head[1], gtir.OffsetLiteral): @@ -443,9 +467,9 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value code = gtir_python_codegen.format_builtin(builtin_name, *node_internals) out_connector = "result" - tasklet_node = self.state.add_tasklet( + tasklet_node = self._add_tasklet( builtin_name, - node_connections.keys(), + set(node_connections.keys()), {out_connector}, "{} = {}".format(out_connector, code), ) @@ -494,7 +518,7 @@ def visit_Lambda( if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node output_dtype = output_expr.node.desc(self.sdfg).dtype - tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_entry_memlet_path( output_expr.node, output_expr.subset, @@ -504,9 +528,7 @@ def visit_Lambda( else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype - tasklet_node = self.state.add_tasklet( - "write", {}, {"__out"}, f"__out = {output_expr.value}" - ) + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: From 7f60cfe2ba3fd5ab8dbc91dd942b6bf58fcbf8b6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 12 Jul 2024 11:24:33 +0200 Subject: [PATCH 146/235] Import changes from branch dace-fieldview-shifts --- .../gtir_builtin_translators.py | 10 +-- .../runners/dace_fieldview/gtir_to_sdfg.py | 89 +++++++++++++++---- .../runners/dace_fieldview/gtir_to_tasklet.py | 42 ++++++--- 3 files changed, 107 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 266cbbff1a..3ace4a12a5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -173,7 +173,7 @@ def visit_AsFieldOp( stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder.get_offset_provider()) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder) input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) output_desc = output_expr.node.desc(sdfg) @@ -202,7 +202,7 @@ def visit_AsFieldOp( # create map range corresponding to the field operator domain map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} - me, mx = state.add_map("field_op", map_ranges) + me, mx = sdfg_builder.add_map("field_op", state, map_ranges) if len(input_connections) == 0: # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets @@ -321,8 +321,7 @@ def visit_SymbolRef( tasklet_name = "get_literal" else: sym_value = str(node.id) - assert sym_value in sdfg_builder.get_symbol_types() - data_type = sdfg_builder.get_symbol_types()[sym_value] + data_type = sdfg_builder.get_symbol_type(sym_value) tasklet_name = f"get_{sym_value}" if isinstance(data_type, ts.FieldType): @@ -332,8 +331,9 @@ def visit_SymbolRef( else: # scalar symbols are passed to the SDFG as symbols: build tasklet node # to write the symbol to a scalar access node - tasklet_node = state.add_tasklet( + tasklet_node = sdfg_builder.add_tasklet( tasklet_name, + state, {}, {"__out"}, f"__out = {sym_value}", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 861cc0a475..dd8903e0c2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -20,12 +20,14 @@ from __future__ import annotations import abc -from typing import Any, Protocol, Sequence +import dataclasses +from typing import Any, Dict, List, Protocol, Sequence, Set, Tuple, Union import dace from gt4py import eve from gt4py.eve import concepts +from gt4py.eve.utils import UIDGenerator from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -36,15 +38,54 @@ from gt4py.next.type_system import type_specifications as ts -class SDFGBuilder(Protocol): - """Visitor interface available to GTIR-primitive translators.""" +class DataflowBuilder(Protocol): + """Visitor interface to build a dataflow subgraph.""" + + @abc.abstractmethod + def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: + pass @abc.abstractmethod - def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: + def unique_map_name(self, name: str) -> str: pass @abc.abstractmethod - def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: + def unique_tasklet_name(self, name: str) -> str: + pass + + def add_map( + self, + name: str, + state: dace.SDFGState, + ndrange: Union[ + Dict[str, Union[str, dace.subsets.Subset]], + List[Tuple[str, Union[str, dace.subsets.Subset]]], + ], + **kwargs: Any, + ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + """Wrapper of `dace.SDFGState.add_map` that assigns unique name.""" + unique_name = self.unique_map_name(name) + return state.add_map(unique_name, ndrange, **kwargs) + + def add_tasklet( + self, + name: str, + state: dace.SDFGState, + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + **kwargs: Any, + ) -> dace.nodes.Tasklet: + """Wrapper of `dace.SDFGState.add_tasklet` that assigns unique name.""" + unique_name = self.unique_tasklet_name(name) + return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs) + + +class SDFGBuilder(DataflowBuilder, Protocol): + """Visitor interface available to GTIR-primitive translators.""" + + @abc.abstractmethod + def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: pass @abc.abstractmethod @@ -52,6 +93,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: pass +@dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -66,20 +108,29 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] - symbol_types: dict[str, ts.FieldType | ts.ScalarType] - - def __init__( - self, - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], - ): - self.offset_provider = offset_provider - self.symbol_types = {} - - def get_offset_provider(self) -> dict[str, gtx_common.Connectivity | gtx_common.Dimension]: - return self.offset_provider - - def get_symbol_types(self) -> dict[str, ts.FieldType | ts.ScalarType]: - return self.symbol_types + symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field( + default_factory=lambda: {} + ) + map_uids: UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="map") + ) + tesklet_uids: UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="tlet") + ) + + def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: + assert offset in self.offset_provider + return self.offset_provider[offset] + + def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: + assert symbol_name in self.symbol_types + return self.symbol_types[symbol_name] + + def unique_map_name(self, name: str) -> str: + return f"{self.map_uids.sequential_id()}_{name}" + + def unique_tasklet_name(self, name: str) -> str: + return f"{self.tesklet_uids.sequential_id()}_{name}" def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 44c3722f58..485db254ea 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -16,7 +16,7 @@ from __future__ import annotations import dataclasses -from typing import Optional, TypeAlias +from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union import dace import dace.subsets as sbs @@ -27,6 +27,7 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, + gtir_to_sdfg, utility as dace_fieldview_util, ) from gt4py.next.type_system import type_specifications as ts @@ -86,7 +87,7 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] + subgraph_builder: gtir_to_sdfg.DataflowBuilder input_connections: list[InputConnection] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -94,11 +95,11 @@ def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], + subgraph_builder: gtir_to_sdfg.DataflowBuilder, ): self.sdfg = sdfg self.state = state - self.offset_provider = offset_provider + self.subgraph_builder = subgraph_builder self.input_connections = [] self.symbol_map = {} @@ -111,6 +112,29 @@ def _add_entry_memlet_path( ) -> None: self.input_connections.append((src, src_subset, dst_node, dst_conn)) + def _add_map( + self, + name: str, + ndrange: Union[ + Dict[str, Union[str, dace.subsets.Subset]], + List[Tuple[str, Union[str, dace.subsets.Subset]]], + ], + **kwargs: Any, + ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + """Helper method to add a map with unique ame in current state.""" + return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs) + + def _add_tasklet( + self, + name: str, + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + **kwargs: Any, + ) -> dace.nodes.Tasklet: + """Helper method to add a tasklet with unique ame in current state.""" + return self.subgraph_builder.add_tasklet(name, self.state, inputs, outputs, code, **kwargs) + def _get_tasklet_result( self, dtype: dace.typeclass, @@ -175,9 +199,9 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value code = gtir_python_codegen.format_builtin(builtin_name, *node_internals) out_connector = "result" - tasklet_node = self.state.add_tasklet( + tasklet_node = self._add_tasklet( builtin_name, - node_connections.keys(), + set(node_connections.keys()), {out_connector}, "{} = {}".format(out_connector, code), ) @@ -226,7 +250,7 @@ def visit_Lambda( if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node output_dtype = output_expr.node.desc(self.sdfg).dtype - tasklet_node = self.state.add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_entry_memlet_path( output_expr.node, output_expr.subset, @@ -236,9 +260,7 @@ def visit_Lambda( else: # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype - tasklet_node = self.state.add_tasklet( - "write", {}, {"__out"}, f"__out = {output_expr.value}" - ) + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: From b3131dbb713fcf67691fa9bb7b536a0d20e1a58f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 12 Jul 2024 11:46:48 +0200 Subject: [PATCH 147/235] Avoid direct import of symbols from module --- .../runners/dace_fieldview/gtir_to_sdfg.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index dd8903e0c2..0a4e86fc57 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -27,7 +27,6 @@ from gt4py import eve from gt4py.eve import concepts -from gt4py.eve.utils import UIDGenerator from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -111,11 +110,11 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field( default_factory=lambda: {} ) - map_uids: UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="map") + map_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) - tesklet_uids: UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="tlet") + tesklet_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: From 130c877072b3832c4ceaafbe901853915d55f6f1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 12 Jul 2024 12:40:57 +0200 Subject: [PATCH 148/235] Address review comments --- .../runners/dace_fieldview/__init__.py | 2 +- .../gtir_builtin_translators.py | 14 ++--- .../dace_fieldview/gtir_python_codegen.py | 2 +- .../runners/dace_fieldview/gtir_to_sdfg.py | 35 ++++++++++-- .../runners/dace_fieldview/workflow.py | 55 ------------------- 5 files changed, 40 insertions(+), 68 deletions(-) delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py index c39c832e88..54b2e0e29d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.program_processors.runners.dace_fieldview.workflow import build_sdfg_from_gtir +from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import build_sdfg_from_gtir __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 3ace4a12a5..f401f46449 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -146,7 +146,7 @@ def _create_temporary_field( return field_node, field_type -def visit_AsFieldOp( +def translate_as_field_op( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, @@ -228,7 +228,7 @@ def visit_AsFieldOp( return [(field_node, field_type)] -def visit_Cond( +def translate_cond( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, @@ -305,7 +305,7 @@ def visit_Cond( return output_nodes -def visit_SymbolRef( +def translate_symbol_ref( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, @@ -352,9 +352,9 @@ def visit_SymbolRef( if TYPE_CHECKING: - # Use type-checking to assert that all visitor functions implement the `PrimitiveTranslator` protocol + # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ - visit_AsFieldOp, - visit_Cond, - visit_SymbolRef, + translate_as_field_op, + translate_cond, + translate_symbol_ref, ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index d70169bcd1..fcb71e4e6d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -119,10 +119,10 @@ def visit_FunCall(self, node: gtir.FunCall) -> str: raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") +get_source = PythonCodegen.apply """ Specialized visit method for symbolic expressions. Returns: A string containing the Python code corresponding to a symbolic expression """ -get_source = PythonCodegen.apply diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 0a4e86fc57..7468056f0c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -305,9 +305,9 @@ def visit_FunCall( ) -> list[gtir_builtin_translators.TemporaryData]: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtin_translators.visit_AsFieldOp(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_as_field_op(node, sdfg, head_state, self) elif cpm.is_call_to(node.fun, "cond"): - return gtir_builtin_translators.visit_Cond(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_cond(node, sdfg, head_state, self) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -325,7 +325,7 @@ def visit_Literal( sdfg: dace.SDFG, head_state: dace.SDFGState, ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.visit_SymbolRef(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) def visit_SymRef( self, @@ -333,4 +333,31 @@ def visit_SymRef( sdfg: dace.SDFG, head_state: dace.SDFGState, ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.visit_SymbolRef(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) + + +def build_sdfg_from_gtir( + program: gtir.Program, + offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], +) -> dace.SDFG: + """ + Receives a GTIR program and lowers it to a DaCe SDFG. + + The lowering to SDFG requires that the program node is type-annotated, therefore this function + runs type ineference as first step. + As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form. + + Arguments: + program: The GTIR program node to be lowered to SDFG + offset_provider: The definitions of offset providers used by the program node + + Returns: + An SDFG in the DaCe canonical form (simplified) + """ + sdfg_genenerator = GTIRToSDFG(offset_provider) + # TODO: run type inference on the `program` node before passing it to `GTIRToSDFG` + sdfg = sdfg_genenerator.visit(program) + assert isinstance(sdfg, dace.SDFG) + + sdfg.simplify() + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py deleted file mode 100644 index eabfa8f713..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ /dev/null @@ -1,55 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later -""" -Contains definitions of the workflow steps for GTIR programs with dace as backend for optimization and code generation. - -Note: this module covers the fieldview flavour of GTIR. -""" - -from __future__ import annotations - -import dace - -from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir -from gt4py.next.program_processors.runners.dace_fieldview import ( - gtir_to_sdfg as gtir_dace_translator, -) - - -def build_sdfg_from_gtir( - program: gtir.Program, - offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension], -) -> dace.SDFG: - """ - Receives a GTIR program and lowers it to a DaCe SDFG. - - The lowering to SDFG requires that the program node is type-annotated, therefore this function - runs type ineference as first step. - As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form. - - Arguments: - program: The GTIR program node to be lowered to SDFG - offset_provider: The definitions of offset providers used by the program node - - Returns: - An SDFG in the DaCe canonical form (simplified) - """ - sdfg_genenerator = gtir_dace_translator.GTIRToSDFG(offset_provider) - # TODO: run type inference on the `program` node before passing it to `GTIRToSDFG` - sdfg = sdfg_genenerator.visit(program) - assert isinstance(sdfg, dace.SDFG) - - sdfg.simplify() - return sdfg From fb2ba90086badb1fd23b747d1fe515396206341d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 13:19:19 +0200 Subject: [PATCH 149/235] Started with an untested map promoted. --- .../transformations/__init__.py | 2 + .../transformations/map_promoter.py | 196 ++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 46fe340027..5bcddf085c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -13,9 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from .auto_opt import dace_auto_optimize, gt_auto_optimize +from .map_seriall_fusion import SerialMapFusion __all__ = [ "dace_auto_optimize", "gt_auto_optimize", + "SerialMapFusion", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py new file mode 100644 index 0000000000..6ee9b4236c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -0,0 +1,196 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Mapping, Optional, Sequence, Union + +import dace +from dace import properties, subsets, transformation +from dace.sdfg import SDFG, SDFGState, nodes + + +@properties.make_properties +class BaseMapPromoter(transformation.SingleStateTransformation): + """Base transformation to add certain missing dimension to a map. + + By adding certain dimension to a map it will became possible to fuse them. + This class acts as a base and the actual matching and checking must be + implemented by a concrete implementation. + + In order to properly work, the parameters of `source_map` must be a strict + superset of the ones of `map_to_promote`. Furthermore, this transformation + builds upon the structure defined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + Thus it only checks the name of the parameters. + + Attributes: + map_to_promote: This is the map entry that should be promoted, i.e. dimensions + will be added such that its parameter matches `source_map`. + source_map: The map entry node that describes how `map_to_promote` should + look after the promotion. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + + Note: + This ignores tiling. + This only works with constant sized maps. + """ + + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are on the top level.", + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + + # Pattern Matching + map_to_promote = transformation.transformation.PatternNode(nodes.MapEntry) + source_map = transformation.transformation.PatternNode(nodes.MapEntry) + + @classmethod + def expressions(cls) -> Any: + raise TypeError("You must implement 'expressions' yourself.") + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if only_inner_maps is not None: + self.only_inner_maps = bool(only_inner_maps) + if only_toplevel_maps is not None: + self.only_toplevel_maps = bool(only_toplevel_maps) + if only_inner_maps and only_toplevel_maps: + raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Perform some basic structural tests on the map. + + A subclass should call this function before checking anything else. If a + subclass has not called this function, the behaviour will be undefined. + The function checks: + - If the map to promote is in the right scope (it is not required that + the two maps are in the same scope). + - If the parameter of the second map are compatible with each other. + """ + map_to_promote_entry: nodes.MapEntry = self.map_to_promote + map_to_promote: nodes.Map = map_to_promote_entry.map + source_map_entry: nodes.MapEntry = self.source_map + source_map: nodes.Map = source_map_entry.map + + # Test the scope of the promotee. + if self.only_inner_maps or self.only_toplevel_maps: + scopeDict: Mapping[nodes.Node, Union[nodes.Node, None]] = graph.scope_dict() + if self.only_inner_maps and (scopeDict[map_to_promote_entry] is None): + return False + if self.only_toplevel_maps and (scopeDict[map_to_promote_entry] is not None): + return False + + # Test if the map ranges are compatible with each other. + if not self.missing_map_params( + map_to_promote=map_to_promote, + source_map=source_map, + be_strict=True, + ): + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map Promoting. + + Add all parameters that `self.source_map` has but `self.map_to_promote` + lacks to `self.map_to_promote` the range of these new dimensions is taken + from the source map. + The order of the parameters of these new dimensions is undetermined. + """ + map_to_promote: nodes.Map = self.map_to_promote.map + source_map: nodes.Map = self.source_map.map + source_params: Sequence[str] = source_map.params + source_ranges: subsets.Range = source_map.range + + missing_params: Sequence[str] = self.missing_map_params( # type: ignore[assignment] # Will never be `None` + map_to_promote=map_to_promote, + source_map=source_map, + be_strict=False, + ) + + # Maps the map parameter of the source map to its index, i.e. which map + # parameter it is. + map_source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_params)} + + promoted_params = list(map_to_promote.params) + promoted_ranges = list(map_to_promote.range.ranges) + + for missing_param in missing_params: + promoted_params.append(missing_param) + promoted_ranges.append(source_ranges[map_source_param_to_idx[missing_param]]) + + # Now update the map properties + # This action will also remove the tiles + map_to_promote.range = subsets.Range(promoted_ranges) + map_to_promote.params = promoted_params + + def missing_map_params( + self, + map_to_promote: nodes.Map, + source_map: nodes.Map, + be_strict: bool = True, + ) -> Sequence[str] | None: + """Returns the parameter that are missing in the map that should be promoted. + + The returned sequence is empty if they are already have the same parameters. + The function will return `None` is promoting is not possible. + + Args: + map_to_promote: The map that should be promoted. + source_map: The map acting as template. + be_strict: Ensure that the ranges that are already there are correct. + """ + source_params: set[str] = set(source_map.params) + curr_params: set[str] = set(map_to_promote.params) + + # The promotion can only work if the source map's parameters + # if a superset of the ones the map that should be promoted is. + if not source_params.issuperset(curr_params): + return None + + if be_strict: + # Check if the parameters that are already in the map to promote have + # the same range as in the source map. + source_ranges: subsets.Range = source_map.range + curr_ranges: subsets.Range = map_to_promote.range + curr_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(map_to_promote.params)} + source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_map.params)} + for param_to_check in curr_params: + curr_range = curr_ranges[curr_param_to_idx[param_to_check]] + source_range = source_ranges[source_param_to_idx[param_to_check]] + if curr_range != source_range: + return None + return list(source_params - curr_params) From 52c1d019e62dd0c2cc4ab951e771aad72d04ae18 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 13:32:55 +0200 Subject: [PATCH 150/235] Updated the tests, but it still does not work. --- my_playground/map_fusion_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py index 69bef188c5..52c1d44c24 100644 --- a/my_playground/map_fusion_test.py +++ b/my_playground/map_fusion_test.py @@ -328,12 +328,17 @@ def shifting(): IOffset = 3 domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size"), + im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "red_size"), im.call("named_range")(itir.AxisLiteral(value=JDim.value), 0, "size"), ) stencil1 = im.call( im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + im.lambda_("a", "b")( + im.plus( + im.deref("a"), + im.deref("b"), + ), + ), domain, ) )( @@ -354,6 +359,7 @@ def shifting(): itir.Sym(id="y", type=IJFTYPE), itir.Sym(id="z", type=IJFTYPE), itir.Sym(id="size", type=SIZE_TYPE), + itir.Sym(id="red_size", type=SIZE_TYPE), ], declarations=[], body=[ @@ -385,6 +391,7 @@ def shifting(): "y": y, "z": z, "size": N, + "red_size": N - IOffset, } return_names = ["z"] @@ -468,9 +475,9 @@ def non_zero_start(): if "__main__" == __name__: + shifting() exclusive_only() exclusive_only_2() intermediate_branch() - # shifting() non_zero_start() print("SUCCESS") From 42f4aba734e32ed65beb7f9e72c337c65205a125 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 13:42:43 +0200 Subject: [PATCH 151/235] Now the shift test works too. --- my_playground/map_fusion_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py index 52c1d44c24..15f30a0d40 100644 --- a/my_playground/map_fusion_test.py +++ b/my_playground/map_fusion_test.py @@ -383,8 +383,8 @@ def shifting(): x = np.random.rand(N + 2 * IOffset, N) y = np.random.rand(N, N) - z = np.empty_like(y) - ref = x[IOffset : (IOffset + N), :] + y + z = np.zeros((N - IOffset, N)) + ref = x[IOffset:N, :] + y[: (N - IOffset), :] args = { "x": x, From 84b2ba72eccac79e5ea92f62fb33772530930584 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 14:50:29 +0200 Subject: [PATCH 152/235] Added some more checking functionality to the base promoter. That thing is now also able to see if a promotion is allowed or not. --- .../transformations/map_promoter.py | 90 +++++++++++++++---- .../transformations/map_seriall_fusion.py | 2 - 2 files changed, 71 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 6ee9b4236c..d8638edd06 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -32,15 +32,15 @@ class BaseMapPromoter(transformation.SingleStateTransformation): builds upon the structure defined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). Thus it only checks the name of the parameters. - Attributes: - map_to_promote: This is the map entry that should be promoted, i.e. dimensions - will be added such that its parameter matches `source_map`. - source_map: The map entry node that describes how `map_to_promote` should - look after the promotion. + To influence what to promote the user must implement the `map_to_promote()` + and `source_map()` must be implemented. They have to return the map entry node. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. + promote_vertical: If `True` promote vertical dimensions; `True` by default. + promote_local: If `True` promote local dimensions; `True` by default. + promote_horizontal: If `True` promote horizontal dimensions; `False` by default. Note: This ignores tiling. @@ -59,10 +59,37 @@ class BaseMapPromoter(transformation.SingleStateTransformation): allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) + promote_vertical = properties.Property( + dtype=bool, + default=True, + desc="If `True` promote vertical dimensions.", + ) + promote_local = properties.Property( + dtype=bool, + default=True, + desc="If `True` promote local dimensions.", + ) + promote_horizontal = properties.Property( + dtype=bool, + default=False, + desc="If `True` promote horizontal dimensions.", + ) - # Pattern Matching - map_to_promote = transformation.transformation.PatternNode(nodes.MapEntry) - source_map = transformation.transformation.PatternNode(nodes.MapEntry) + def map_to_promote( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> nodes.MapEntry: + """Returns the map entry that should be promoted.""" + raise NotImplementedError(f"{type(self).__name__} must implement 'map_to_promote'.") + + def source_map( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> nodes.MapEntry: + """Returns the map entry that is used as source/template.""" + raise NotImplementedError(f"{type(self).__name__} must implement 'source_map'.") @classmethod def expressions(cls) -> Any: @@ -72,6 +99,9 @@ def __init__( self, only_inner_maps: Optional[bool] = None, only_toplevel_maps: Optional[bool] = None, + promote_local: Optional[bool] = None, + promote_vertical: Optional[bool] = None, + promote_horizontal: Optional[bool] = None, *args: Any, **kwargs: Any, ) -> None: @@ -80,6 +110,12 @@ def __init__( self.only_inner_maps = bool(only_inner_maps) if only_toplevel_maps is not None: self.only_toplevel_maps = bool(only_toplevel_maps) + if promote_local is not None: + self.promote_local = bool(promote_local) + if promote_vertical is not None: + self.promote_vertical = bool(promote_vertical) + if promote_horizontal is not None: + self.promote_horizontal = bool(promote_horizontal) if only_inner_maps and only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") @@ -98,10 +134,11 @@ def can_be_applied( - If the map to promote is in the right scope (it is not required that the two maps are in the same scope). - If the parameter of the second map are compatible with each other. + - If a dimension would be promoted that should not. """ - map_to_promote_entry: nodes.MapEntry = self.map_to_promote + map_to_promote_entry: nodes.MapEntry = self.map_to_promote(state=graph, sdfg=sdfg) map_to_promote: nodes.Map = map_to_promote_entry.map - source_map_entry: nodes.MapEntry = self.source_map + source_map_entry: nodes.MapEntry = self.source_map(state=graph, sdfg=sdfg) source_map: nodes.Map = source_map_entry.map # Test the scope of the promotee. @@ -113,10 +150,25 @@ def can_be_applied( return False # Test if the map ranges are compatible with each other. - if not self.missing_map_params( + params_to_promote: list[str] | None = self.missing_map_params( map_to_promote=map_to_promote, source_map=source_map, be_strict=True, + ) + if not params_to_promote: + return False + + # Now we must check if there are dimensions that we do not want to promote. + if (not self.promote_local) and any( + param.endswith("__gtx_localdim") for param in params_to_promote + ): + return False + if (not self.promote_vertical) and any( + param.endswith("__gtx_vertical") for param in params_to_promote + ): + return False + if (not self.promote_horizontal) and any( + param.endswith("__gtx_horizontal") for param in params_to_promote ): return False @@ -130,8 +182,8 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: from the source map. The order of the parameters of these new dimensions is undetermined. """ - map_to_promote: nodes.Map = self.map_to_promote.map - source_map: nodes.Map = self.source_map.map + map_to_promote: nodes.Map = self.map_to_promote(state=graph, sdfg=sdfg).map + source_map: nodes.Map = self.source_map(state=graph, sdfg=sdfg).map source_params: Sequence[str] = source_map.params source_ranges: subsets.Range = source_map.range @@ -162,7 +214,7 @@ def missing_map_params( map_to_promote: nodes.Map, source_map: nodes.Map, be_strict: bool = True, - ) -> Sequence[str] | None: + ) -> list[str] | None: """Returns the parameter that are missing in the map that should be promoted. The returned sequence is empty if they are already have the same parameters. @@ -173,12 +225,12 @@ def missing_map_params( source_map: The map acting as template. be_strict: Ensure that the ranges that are already there are correct. """ - source_params: set[str] = set(source_map.params) - curr_params: set[str] = set(map_to_promote.params) + source_params_set: set[str] = set(source_map.params) + curr_params_set: set[str] = set(map_to_promote.params) # The promotion can only work if the source map's parameters # if a superset of the ones the map that should be promoted is. - if not source_params.issuperset(curr_params): + if not source_params_set.issuperset(curr_params_set): return None if be_strict: @@ -188,9 +240,9 @@ def missing_map_params( curr_ranges: subsets.Range = map_to_promote.range curr_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(map_to_promote.params)} source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_map.params)} - for param_to_check in curr_params: + for param_to_check in curr_params_set: curr_range = curr_ranges[curr_param_to_idx[param_to_check]] source_range = source_ranges[source_param_to_idx[param_to_check]] if curr_range != source_range: return None - return list(source_params - curr_params) + return list(source_params_set - curr_params_set) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py index d23e1e758c..7eafc67e6d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -77,8 +77,6 @@ def expressions(cls) -> Any: """ return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] - # end def: expressions - def can_be_applied( self, graph: Union[SDFGState, SDFG], From 24bde919713880006077e5874103033299499dfc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 15:09:11 +0200 Subject: [PATCH 153/235] Added a concrete promoter. It works for the case of serial maps. --- .../transformations/map_promoter.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index d8638edd06..da10af434a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -246,3 +246,57 @@ def missing_map_params( if curr_range != source_range: return None return list(source_params_set - curr_params_set) + + +@properties.make_properties +class SerialMapPromoter(BaseMapPromoter): + """This class promotes serial maps, such that they can be fused.""" + + # Pattern Matching + exit_first_map = transformation.transformation.PatternNode(nodes.MapExit) + access_node = transformation.transformation.PatternNode(nodes.AccessNode) + entry_second_map = transformation.transformation.PatternNode(nodes.MapEntry) + + @classmethod + def expressions(cls) -> Any: + """Get the match expressions. + + The function generates two different match expression. The first match + describes the case where the top map must be promoted, while the second + case is the second/lower map must be promoted. + """ + return [ + dace.sdfg.utils.node_path_graph( + cls.exit_first_map, cls.access_node, cls.entry_second_map + ), + dace.sdfg.utils.node_path_graph( + cls.exit_first_map, cls.access_node, cls.entry_second_map + ), + ] + + def map_to_promote( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> nodes.MapEntry: + if self.expr_index == 0: + # The first the top map will be promoted. + return state.entry_node(self.exit_first_map) + assert self.expr_index == 1 + + # The second map will be promoted. + return self.entry_second_map + + def source_map( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> nodes.MapEntry: + """Returns the map entry that is used as source/template.""" + if self.expr_index == 0: + # The first the top map will be promoted, so the second map is the source. + return self.entry_second_map + assert self.expr_index == 1 + + # The second map will be promoted, so the first is used as source + return state.entry_node(self.exit_first_map) From fd81e75b90addb422dd264f083d4d8e2c8a0682c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 15:10:27 +0200 Subject: [PATCH 154/235] Added a custom (okay currently not really custom) simplification pass. --- .../transformations/auto_opt.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index d6512671b5..aeca9c6d02 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -14,7 +14,7 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any +from typing import Any, Optional import dace from dace.transformation import dataflow as dace_dataflow @@ -54,6 +54,37 @@ def dace_auto_optimize( return sdfg +def gt_simplify( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, + skip: Optional[set[str]] = None, +) -> Any: + """Performs simplifications on the SDFG in place. + + Instead of calling `sdfg.simplify()` directly, you should use this function, + as it is specially tuned for GridTool based SDFGs. + + Args: + sdfg: The SDFG to optimize. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + skip: List of simplify passes that should not be applied. + + Note: + The reason for this function is that we can influence how simplify works. + Since some parts in simplify might break things in the SDFG. + """ + from dace.transformation.passes.simplify import SimplifyPass + + return SimplifyPass( + validate=validate, + validate_all=validate_all, + verbose=False, + skip=skip, + ).apply_pass(sdfg, {}) + + def gt_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, From 284b6a8900c49d19ffd557c576e93d79841c9d05 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 12 Jul 2024 15:11:18 +0200 Subject: [PATCH 155/235] Updated the auto fusion stuff. --- .../transformations/auto_opt.py | 41 +++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index aeca9c6d02..8b5de0ce31 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -20,6 +20,7 @@ from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize +from .map_promoter import SerialMapPromoter from .map_seriall_fusion import SerialMapFusion @@ -101,21 +102,45 @@ def gt_auto_optimize( dace.Config.set("optimizer", "match_exception", value=True) # Initial cleaning - sdfg.simplify() + gt_simplify(sdfg) + + # Compute the SDFG to see if something has changed. + sdfg_hash = sdfg.hash_sdfg() + + for _ in range(100): + # Due to the structure of the generated SDFG getting rid of Maps, + # i.e. fusing them, is the best we can currently do. + if kwargs.get("use_dace_fusion", False): + sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) + else: + xform = SerialMapFusion() + sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) + + sdfg.apply_transformations_repeated( + [SerialMapPromoter(promote_horizontal=False)], + validate=True, + validate_all=True, + ) + + # Maybe running the fusion has opened more opportunities. + gt_simplify(sdfg) + + # check if something has changed and if so end it here. + old_sdfg_hash = sdfg_hash + sdfg_hash = sdfg.hash_sdfg() + + if old_sdfg_hash == sdfg_hash: + break - # Due to the structure of the generated SDFG getting rid of Maps, - # i.e. fusing them, is the best we can currently do. - if kwargs.get("use_dace_fusion", False): - sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) else: - xform = SerialMapFusion() - sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) + raise RuntimeWarning("Optimization of the SDFG did not converged.") # These are the part that we copy from DaCe built in auto optimization. dace_aoptimize.set_fast_implementations(sdfg, device) dace_aoptimize.make_transients_persistent(sdfg, device) dace_aoptimize.move_small_arrays_to_stack(sdfg) - sdfg.simplify() + # Final simplify + gt_simplify(sdfg) return sdfg From 033db6bb67dc31c5aa4a97cc6fc0042c850f5204 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 15 Jul 2024 08:07:28 +0200 Subject: [PATCH 156/235] Removed all my non gt4py parts and moved it to a separate repo. The transformations are still here. --- my_playground/map_fusion_test.py | 483 ----------------------------- my_playground/my_stuff.py | 236 -------------- my_playground/nabla4.py | 494 ------------------------------ my_playground/simple_icon_mesh.py | 125 -------- 4 files changed, 1338 deletions(-) delete mode 100644 my_playground/map_fusion_test.py delete mode 100644 my_playground/my_stuff.py delete mode 100644 my_playground/nabla4.py delete mode 100644 my_playground/simple_icon_mesh.py diff --git a/my_playground/map_fusion_test.py b/my_playground/map_fusion_test.py deleted file mode 100644 index 15f30a0d40..0000000000 --- a/my_playground/map_fusion_test.py +++ /dev/null @@ -1,483 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later -""" -Simple tests top verify the map fusion tests. -""" - -import dace -import copy -from gt4py.next.common import NeighborTable -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners import dace_fieldview as dace_backend -from gt4py.next.type_system import type_specifications as ts -from functools import reduce -import numpy as np - -from typing import Sequence, Any - -from dace.sdfg import nodes as dace_nodes - -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations, # noqa: F401 [unused-import] # For development. -) - -from simple_icon_mesh import ( - IDim, # Dimensions - JDim, - KDim, - EdgeDim, - VertexDim, - CellDim, - ECVDim, - E2C2VDim, - NbCells, # Constants of the size - NbEdges, - NbVertices, - E2C2VDim, # Offsets - E2C2V, - SIZE_TYPE, # Type definitions - E2C2V_connectivity, - E2ECV_connectivity, - make_syms, # Helpers -) - -# For cartesian stuff. -N = 10 -IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -IJFTYPE = ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - - -def _perform_test( - sdfg: dace.SDFG, ref: Any, return_names: Sequence[str] | str, args: dict[str, Any] -) -> dace.SDFG: - unopt_sdfg = copy.deepcopy(sdfg) - - if not isinstance(ref, list): - ref = [ref] - if isinstance(return_names, str): - return_names = [return_names] - - SYMBS = make_syms(**args) - - # Call the unoptimized version of the SDFG - unopt_sdfg(**args, **SYMBS) - unopt_res = [args[name] for name in return_names] - - assert np.allclose(ref, unopt_res), "The unoptimized verification failed." - - # Reset the results - for name in return_names: - args[name][:] = 0 - assert not np.allclose(ref, unopt_res) - - # Now perform the optimization - opt_sdfg = copy.deepcopy(sdfg) - transformations.gt_auto_optimize(opt_sdfg) - opt_sdfg.validate() - opt_sdfg(**args, **SYMBS) - opt_res = [args[name] for name in return_names] - - assert np.allclose(ref, opt_res), "The optimized verification failed." - - return opt_sdfg - - -def _count_nodes( - sdfg: dace.SDFG, - state: dace.SDFGState | None = None, - node_type: Sequence[type] | type = dace_nodes.MapEntry, -) -> int: - states = sdfg.states() if state is None else [state] - found_matches = 0 - for state_nodes in states: - for node in state_nodes.nodes(): - if isinstance(node, node_type): - found_matches += 1 - return found_matches - - -###################### -# TESTS - - -def exclusive_only(): - """Tests the sxclusive set merging mechanism only.""" - - domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") - ) - stencil1 = im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 1.0)), - domain, - ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 2.0)), - domain, - ) - )("x"), - ) - - a = np.random.rand(N) - - testee = itir.Program( - id=f"sum_3fields_1", - function_definitions=[], - params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=stencil1, - domain=domain, - target=itir.SymRef(id="z"), - ) - ], - ) - - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - sdfg.validate() - - assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 3 - assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 - - a = np.random.rand(N) - res1 = np.empty_like(a) - - args = { - "x": a, - "z": res1, - "size": N, - } - return_names = ["z"] - - opt_sdfg = _perform_test( - sdfg=sdfg, - ref=a + 3.0, - return_names="z", - args=args, - ) - - assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 3 - assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 - - -def exclusive_only_2(): - domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "size") - ) - stencil1 = im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - ) - )( - "y", - im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 2.0)), - domain, - ) - )("x"), - ) - - a = np.random.rand(N) - b = np.random.rand(N) - - testee = itir.Program( - id=f"sum_3fields_1", - function_definitions=[], - params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="y", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=stencil1, - domain=domain, - target=itir.SymRef(id="z"), - ) - ], - ) - - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - sdfg.validate() - - assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 4 - assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 - - a = np.random.rand(N) - res1 = np.empty_like(a) - - args = { - "x": a, - "y": b, - "z": res1, - "size": N, - } - return_names = ["z"] - - opt_sdfg = _perform_test( - sdfg=sdfg, - ref=(a + b + 2.0), - return_names="z", - args=args, - ) - - assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 4 - assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 - - -def intermediate_branch(): - sdfg = dace.SDFG("intermediate") - state = sdfg.add_state("state") - - ac: list[nodes.AccessNode] = [] - for i in range(3): - name = "input" if i == 0 else f"output{i-1}" - sdfg.add_array( - name, - shape=(N,), - dtype=dace.float64, - transient=False, # All are global. - ) - ac.append(state.add_access(name)) - sdfg.add_array( - name="tmp", - shape=(N,), - dtype=dace.float64, - transient=True, - ) - ac.append(state.add_access("tmp")) - - state.add_mapped_tasklet( - "first_add", - map_ranges=[("i", f"0:{N}")], - code="__out = __in0 + 1.0", - inputs=dict(__in0=dace.Memlet("input[i]")), - outputs=dict(__out=dace.Memlet("tmp[i]")), - input_nodes=dict(input=ac[0]), - output_nodes=dict(tmp=ac[-1]), - external_edges=True, - ) - - for i in range(2): - state.add_mapped_tasklet( - f"level_{i}_add", - map_ranges=[("i", f"0:{N}")], - code=f"__out = __in0 + {i+3}", - inputs=dict(__in0=dace.Memlet("tmp[i]")), - outputs=dict(__out=dace.Memlet(f"output{i}[i]")), - input_nodes=dict(tmp=ac[-1]), - output_nodes={f"output{i}": ac[1 + i]}, - external_edges=True, - ) - - assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 4 - assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 3 - - a = np.random.rand(N) - ref0 = a + 1 + 3 - ref1 = a + 1 + 4 - - res0 = np.empty_like(a) - res1 = np.empty_like(a) - - args = { - "input": a, - "output0": res0, - "output1": res1, - } - return_names = ["output0", "output1"] - - opt_sdfg = _perform_test( - sdfg=sdfg, - ref=[ref0, ref1], - return_names=return_names, - args=args, - ) - assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 4 - assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 - - -def shifting(): - """Tests what happens if we have a sift.""" - - # Currently the transformer fails to parse the IR. - - IOffset = 3 - - domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), 0, "red_size"), - im.call("named_range")(itir.AxisLiteral(value=JDim.value), 0, "size"), - ) - stencil1 = im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")( - im.plus( - im.deref("a"), - im.deref("b"), - ), - ), - domain, - ) - )( - "y", - im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.deref(im.shift("IDim", IOffset)("a"))), - domain, - ) - )("x"), - ) - - testee = itir.Program( - id=f"shift_test", - function_definitions=[], - params=[ - itir.Sym(id="x", type=IJFTYPE), - itir.Sym(id="y", type=IJFTYPE), - itir.Sym(id="z", type=IJFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), - itir.Sym(id="red_size", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=stencil1, - domain=domain, - target=itir.SymRef(id="z"), - ) - ], - ) - - offset_provider = { - "IDim": IDim, - "JDim": JDim, - } - sdfg = dace_backend.build_sdfg_from_gtir(testee, offset_provider) - sdfg.validate() - - assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 4 - assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 - - x = np.random.rand(N + 2 * IOffset, N) - y = np.random.rand(N, N) - z = np.zeros((N - IOffset, N)) - ref = x[IOffset:N, :] + y[: (N - IOffset), :] - - args = { - "x": x, - "y": y, - "z": z, - "size": N, - "red_size": N - IOffset, - } - return_names = ["z"] - - opt_sdfg = _perform_test( - sdfg=sdfg, - ref=ref, - return_names="z", - args=args, - ) - - assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 4 - assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 - - -def non_zero_start(): - """Tests what happens if there are two maps, that does not start at zero.""" - - dom_start = 3 - domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value), dom_start, "size") - ) - stencil1 = im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 1.0)), - domain, - ) - )( - im.call( - im.call("as_fieldop")( - im.lambda_("a")(im.plus(im.deref("a"), 2.0)), - domain, - ) - )("x"), - ) - - testee = itir.Program( - id=f"non_zero_start_test", - function_definitions=[], - params=[ - itir.Sym(id="x", type=IFTYPE), - itir.Sym(id="z", type=IFTYPE), - itir.Sym(id="size", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=stencil1, - domain=domain, - target=itir.SymRef(id="z"), - ) - ], - ) - - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - sdfg.validate() - - assert _count_nodes(sdfg, node_type=dace_nodes.AccessNode) == 3 - assert _count_nodes(sdfg, node_type=dace_nodes.MapEntry) == 2 - - x = np.random.rand(N) - z = np.zeros_like(x) - - args = { - "x": x, - "z": z, - "size": N, - } - ref = np.zeros_like(x) - ref[dom_start:N] = x[dom_start:N] + 3.0 - return_names = ["z"] - - opt_sdfg = _perform_test( - sdfg=sdfg, - ref=ref, - return_names="z", - args=args, - ) - - assert _count_nodes(opt_sdfg, node_type=dace_nodes.AccessNode) == 3 - assert _count_nodes(opt_sdfg, node_type=dace_nodes.MapEntry) == 1 - - -if "__main__" == __name__: - shifting() - exclusive_only() - exclusive_only_2() - intermediate_branch() - non_zero_start() - print("SUCCESS") diff --git a/my_playground/my_stuff.py b/my_playground/my_stuff.py deleted file mode 100644 index 3728d3b1bc..0000000000 --- a/my_playground/my_stuff.py +++ /dev/null @@ -1,236 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later -""" -Test that ITIR can be lowered to SDFG. - -Note: this test module covers the fieldview flavour of ITIR. -""" - -import copy -from gt4py.next.common import NeighborTable -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners import dace_fieldview as dace_backend -from gt4py.next.type_system import type_specifications as ts -from functools import reduce -import numpy as np - -from simple_icon_mesh import ( - IDim, # Dimensions - JDim, - KDim, - EdgeDim, - VertexDim, - CellDim, - ECVDim, - E2C2VDim, - NbCells, # Constants of the size - NbEdges, - NbVertices, - E2C2VDim, # Offsets - E2C2V, - SIZE_TYPE, # Type definitions - E2C2V_connectivity, - E2ECV_connectivity, - make_syms, # Helpers -) - -# For cartesian stuff. -N = 10 -IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -IJFTYPE = ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - - -###################### -# TESTS - - -def gtir_copy3(): - # We can not use the size symbols inside the domain - # Because the translator complains. - - # Input domain - input_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value, kind=IDim.kind), 0, "org_sizeI"), - im.call("named_range")(itir.AxisLiteral(value=JDim.value, kind=JDim.kind), 0, "org_sizeJ"), - ) - - # Domain for after we have processed the IDim. - first_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value, kind=IDim.kind), 0, "sizeI"), - im.call("named_range")(itir.AxisLiteral(value=JDim.value, kind=JDim.kind), 0, "org_sizeJ"), - ) - - # This is the final domain, or after we have removed the JDim - final_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value=IDim.value, kind=IDim.kind), 0, "sizeI"), - im.call("named_range")(itir.AxisLiteral(value=JDim.value, kind=JDim.kind), 0, "sizeJ"), - ) - - IOffset = 1 - JOffset = 2 - - testee = itir.Program( - id="gtir_copy", - function_definitions=[], - params=[ - itir.Sym(id="x", type=IJFTYPE), - itir.Sym(id="y", type=IJFTYPE), - itir.Sym(id="sizeI", type=SIZE_TYPE), - itir.Sym(id="sizeJ", type=SIZE_TYPE), - itir.Sym(id="org_sizeI", type=SIZE_TYPE), - itir.Sym(id="org_sizeJ", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - # This processed the `JDim`, it is first because its arguments are - # evaluated before, so there the first cutting is happening. - im.call("as_fieldop")( - im.lambda_("a")( - im.deref( - im.shift("JDim", JOffset)("a") # This does not work - ) - ), - final_domain, - ) - )( - # Now here we will process the `IDim` part. - im.call( - im.call("as_fieldop")( - im.lambda_("b")( - im.deref( - im.shift("IDim", IOffset)("b") - # "b" - ) - ), - first_domain, - ) - )("x"), - ), - domain=final_domain, - target=itir.SymRef(id="y"), - ) - ], - ) - - # We only need an offset provider for the translation. - offset_provider = { - "IDim": IDim, - "JDim": JDim, - } - - sdfg = dace_backend.build_sdfg_from_gtir( - testee, - offset_provider, - ) - - output_size_I, output_size_J = 10, 10 - input_size_I, input_size_J = 20, 20 - - a = np.random.rand(input_size_I, input_size_J) - b = np.empty((output_size_I, output_size_J), dtype=np.float64) - - SYMBS = make_syms(x=a, y=b) - - sdfg( - x=a, - y=b, - sizeI=output_size_I, - sizeJ=output_size_J, - org_sizeI=input_size_I, - org_sizeJ=input_size_J, - **SYMBS, - ) - - ref = a[IOffset : (IOffset + output_size_I), JOffset : (JOffset + output_size_J)] - - assert np.all(b == ref) - assert True - - -def gtir_ecv_shift(): - # EdgeDim, E2C2VDim - domain = im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "nedges" - ), - # im.call("named_range")(itir.AxisLiteral(value=E2C2VDim.value, kind=E2C2VDim.kind), 0, 4), - ) - - INPUT_FTYPE = ts.FieldType(dims=[ECVDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - OUTPUT_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - - testee = itir.Program( - id="gtir_shift", - function_definitions=[], - params=[ - itir.Sym(id="x", type=INPUT_FTYPE), - itir.Sym(id="y", type=OUTPUT_FTYPE), - itir.Sym(id="nedges", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - # This processed the `JDim`, it is first because its arguments are - # evaluated before, so there the first cutting is happening. - im.call("as_fieldop")( - im.lambda_("a")(im.deref(im.shift("E2ECV", 0)("a"))), - domain, - ) - )("x"), - domain=domain, - target=itir.SymRef(id="y"), - ) - ], - ) - - offset_provider = { - "E2C2V": E2C2V_connectivity, - "E2ECV": E2ECV_connectivity, - } - - sdfg = dace_backend.build_sdfg_from_gtir( - testee, - offset_provider, - ) - - a = np.random.rand(NbEdges * 4) - b = np.empty((NbEdges,), dtype=np.float64) - - call_args = { - "x": a, - "y": b, - "connectivity_E2C2V": E2C2V_connectivity.table.copy(), - "connectivity_E2ECV": E2ECV_connectivity.table.copy(), - } - - SYMBS = make_syms(**call_args) - - sdfg( - **call_args, - nedges=NbEdges, - **SYMBS, - ) - ref = a[E2ECV_connectivity.table[:, 0]] - - assert np.allclose(ref, b) - assert True - - -if "__main__" == __name__: - # gtir_copy3() - gtir_ecv_shift() diff --git a/my_playground/nabla4.py b/my_playground/nabla4.py deleted file mode 100644 index f3252fe85b..0000000000 --- a/my_playground/nabla4.py +++ /dev/null @@ -1,494 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -""" -Implementation of the Nabla4 Stencil. -""" - -import copy - -from gt4py.next.common import NeighborTable -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.ffront.fbuiltins import Field -from gt4py.next.program_processors.runners import dace_fieldview as dace_backend -from gt4py.next.type_system import type_specifications as ts - -from gt4py.next.program_processors.runners.dace_fieldview import transformations as fw_trans - -from simple_icon_mesh import ( - IDim, # Dimensions - JDim, - KDim, - EdgeDim, - VertexDim, - CellDim, - ECVDim, - E2C2VDim, - NbCells, # Constants of the size - NbEdges, - NbVertices, - E2C2VDim, # Offsets - E2C2V, - SIZE_TYPE, # Type definitions - E2C2V_connectivity, - E2ECV_connectivity, - make_syms, # Helpers -) - -from typing import Sequence, Any -from functools import reduce -import numpy as np - -import dace - -wpfloat = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) -VK_FTYPE = ts.FieldType(dims=[VertexDim, KDim], dtype=wpfloat) -EK_FTYPE = ts.FieldType(dims=[EdgeDim, KDim], dtype=wpfloat) -E_FTYPE = ts.FieldType(dims=[EdgeDim], dtype=wpfloat) -ECV_FTYPE = ts.FieldType(dims=[ECVDim], dtype=wpfloat) - - -def nabla4_np( - u_vert: Field[[EdgeDim, KDim], wpfloat], - v_vert: Field[[EdgeDim, KDim], wpfloat], - primal_normal_vert_v1: Field[[ECVDim], wpfloat], - primal_normal_vert_v2: Field[[ECVDim], wpfloat], - z_nabla2_e: Field[[EdgeDim, KDim], wpfloat], - inv_vert_vert_length: Field[[EdgeDim], wpfloat], - inv_primal_edge_length: Field[[EdgeDim], wpfloat], - # These are the offset providers - E2C2V: NeighborTable, - **kwargs, # Allows to use the same call argument object as for the SDFG -) -> Field[[EdgeDim, KDim], wpfloat]: - primal_normal_vert_v1 = primal_normal_vert_v1.reshape(E2C2V.table.shape) - primal_normal_vert_v2 = primal_normal_vert_v2.reshape(E2C2V.table.shape) - u_vert_e2c2v = u_vert[E2C2V.table] - v_vert_e2c2v = v_vert[E2C2V.table] - - xn_0 = u_vert_e2c2v[:, 2] * primal_normal_vert_v1[:, 2].reshape((-1, 1)) - xn_1 = v_vert_e2c2v[:, 2] * primal_normal_vert_v2[:, 2].reshape((-1, 1)) - xn_2 = u_vert_e2c2v[:, 3] * primal_normal_vert_v1[:, 3].reshape((-1, 1)) - xn_3 = v_vert_e2c2v[:, 3] * primal_normal_vert_v2[:, 3].reshape((-1, 1)) - nabv_norm = xn_0 + xn_1 + xn_2 + xn_3 - - N = nabv_norm - 2 * z_nabla2_e - ell_v2 = inv_vert_vert_length**2 - N_ellv2 = N * ell_v2.reshape((-1, 1)) - - xt_0 = u_vert_e2c2v[:, 0] * primal_normal_vert_v1[:, 0].reshape((-1, 1)) - xt_1 = v_vert_e2c2v[:, 0] * primal_normal_vert_v2[:, 0].reshape((-1, 1)) - xt_2 = u_vert_e2c2v[:, 1] * primal_normal_vert_v1[:, 1].reshape((-1, 1)) - xt_3 = v_vert_e2c2v[:, 1] * primal_normal_vert_v2[:, 1].reshape((-1, 1)) - nabv_tang = xt_0 + xt_1 + xt_2 + xt_3 - - T = nabv_tang - 2 * z_nabla2_e - ell_e2 = inv_primal_edge_length**2 - T_elle2 = T * ell_e2.reshape((-1, 1)) - - return 4 * (N_ellv2 + T_elle2) - - -# Dimension we operate on. -edge_k_domain = im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" - ), - im.call("named_range")(itir.AxisLiteral(value=KDim.value, kind=KDim.kind), 0, "num_k_levels"), -) -edge_domain = im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value=EdgeDim.value, kind=EdgeDim.kind), 0, "num_edges" - ), -) - - -def shift_builder( - vert: str, - vert_idx: int, - primal: str, - primal_idx: int, -) -> itir.FunCall: - """Used to construct the shifting calculations. - - This function generates the IR for the expression: - ``` - vert[E2C2V[:, vert_idx]] * primal[E2ECV[:, primal_idx]] - ``` - """ - return im.call( - im.call("as_fieldop")( - im.lambda_("vert_shifted", "primal_shifted")( - im.multiplies_(im.deref("vert_shifted"), im.deref("primal_shifted")) - ), - edge_k_domain, - ) - )( - # arg: `vert_shifted` - im.call( - im.call("as_fieldop")( - im.lambda_("vert_no_shifted")( - im.deref(im.shift("E2C2V", vert_idx)("vert_no_shifted")) - ), - edge_k_domain, - ) - )( - vert, # arg: `vert_no_shifted` - ), - # end arg: `vert_shifted` - # arg: `primal_shifted` - im.call( - im.call("as_fieldop")( - im.lambda_("primal_no_shifted")( - im.deref(im.shift("E2ECV", primal_idx)("primal_no_shifted")) - ), - edge_domain, - ) - )( - primal, # arg: `primal_no_shifted` - ), - # end arg: `primal_shifted` - ) - - -def build_nambla4_gtir_fieldview( - num_edges: int, - num_k_levels: int, -) -> itir.Program: - """Creates the `nabla4` stencil in most extreme fieldview version as possible.""" - - nabla4prog = itir.Program( - id="nabla4_partial_fieldview", - function_definitions=[], - params=[ - itir.Sym(id="u_vert", type=VK_FTYPE), - itir.Sym(id="v_vert", type=VK_FTYPE), - itir.Sym(id="primal_normal_vert_v1", type=ECV_FTYPE), - itir.Sym(id="primal_normal_vert_v2", type=ECV_FTYPE), - itir.Sym(id="z_nabla2_e", type=EK_FTYPE), - itir.Sym(id="inv_vert_vert_length", type=E_FTYPE), - itir.Sym(id="inv_primal_edge_length", type=E_FTYPE), - itir.Sym(id="nab4", type=EK_FTYPE), - itir.Sym(id="num_edges", type=SIZE_TYPE), - itir.Sym(id="num_k_levels", type=SIZE_TYPE), - ], - declarations=[], - body=[ - itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("NpT", "const_4")( - im.multiplies_(im.deref("NpT"), im.deref("const_4")) - ), - edge_k_domain, - ) - )( - # arg: `NpT` - im.call( - im.call("as_fieldop")( - im.lambda_("N_ell2", "T_ell2")( - im.plus(im.deref("N_ell2"), im.deref("T_ell2")) - ), - edge_k_domain, - ) - )( - # arg: `N_ell2` - im.call( - im.call("as_fieldop")( - im.lambda_("ell_v2", "N")( - im.multiplies_(im.deref("N"), im.deref("ell_v2")) - ), - edge_k_domain, - ) - )( - # arg: `ell_v2` - im.call( - im.call("as_fieldop")( - im.lambda_("ell_v")( - im.multiplies_(im.deref("ell_v"), im.deref("ell_v")) - ), - edge_k_domain, - ) - )( - # arg: `ell_v` - "inv_vert_vert_length" - ), - # end arg: `ell_v2` - # arg: `N` - im.call( - im.call("as_fieldop")( - im.lambda_("xn", "z_nabla2_e2")( - im.minus(im.deref("xn"), im.deref("z_nabla2_e2")) - ), - edge_k_domain, - ) - )( - # arg: `xn` - # u_vert(E2C2V[2]) * primal_normal_vert_v1(E2ECV[2]) || nx_0 - # + v_vert(E2C2V[2]) * primal_normal_vert_v2(E2ECV[2]) || xn_1 - # + u_vert(E2C2V[3]) * primal_normal_vert_v1(E2ECV[3]) || xn_2 - # + v_vert(E2C2V[3]) * primal_normal_vert_v2(E2ECV[3]) || xn_3 - im.call( - im.call("as_fieldop")( - im.lambda_("xn_0_p_1", "xn_2_p_3")( - im.plus(im.deref("xn_0_p_1"), im.deref("xn_2_p_3")) - ), - edge_k_domain, - ) - )( - # arg: `xn_0_p_1` - im.call( - im.call("as_fieldop")( - im.lambda_("xn_0", "xn_1")( - im.plus(im.deref("xn_0"), im.deref("xn_1")) - ), - edge_k_domain, - ) - )( - shift_builder( # arg: `xn_0` - "u_vert", 2, "primal_normal_vert_v1", 2 - ), - shift_builder( # arg: `xn_1` - "v_vert", 2, "primal_normal_vert_v2", 2 - ), - ), - # end arg: `xn_0_p_1` - # arg: `xn_2_p_3` - im.call( - im.call("as_fieldop")( - im.lambda_("xn_2", "xn_3")( - im.plus(im.deref("xn_2"), im.deref("xn_3")) - ), - edge_k_domain, - ) - )( - shift_builder( # arg: `xn_2` - "u_vert", 3, "primal_normal_vert_v1", 3 - ), - shift_builder( # arg: `xn_3` - "v_vert", 3, "primal_normal_vert_v2", 3 - ), - ), - # end arg: `xn_2_p_3` - ), - # end arg: `xn` - # arg: `z_nabla2_e2` - im.call( - im.call("as_fieldop")( - im.lambda_("z_nabla2_e", "const_2")( - im.multiplies_( - im.deref("z_nabla2_e"), im.deref("const_2") - ) - ), - edge_k_domain, - ) - )( - # arg: `z_nabla2_e` - "z_nabla2_e", - # arg: `const_2` - 2.0, - ), - # end arg: `z_nabla2_e2` - ), - # end arg: `N` - ), - # end arg: `N_ell2` - # arg: `T_ell2` - im.call( - im.call("as_fieldop")( - im.lambda_("ell_e2", "T")( - im.multiplies_(im.deref("T"), im.deref("ell_e2")) - ), - edge_k_domain, - ) - )( - # arg: `ell_e2` - im.call( - im.call("as_fieldop")( - im.lambda_("ell_e")( - im.multiplies_(im.deref("ell_e"), im.deref("ell_e")) - ), - edge_k_domain, - ) - )( - # arg: `ell_e` - "inv_primal_edge_length" - ), - # end arg: `ell_e2` - # arg: `T` - im.call( - im.call("as_fieldop")( - im.lambda_("xt", "z_nabla2_e2")( - im.minus(im.deref("xt"), im.deref("z_nabla2_e2")) - ), - edge_k_domain, - ) - )( - # arg: `xt` - # u_vert(E2C2V[0]) * primal_normal_vert_v1(E2ECV[0]) || nx_0 - # + v_vert(E2C2V[0]) * primal_normal_vert_v2(E2ECV[0]) || xt_1 - # + u_vert(E2C2V[1]) * primal_normal_vert_v1(E2ECV[1]) || xt_2 - # + v_vert(E2C2V[1]) * primal_normal_vert_v2(E2ECV[1]) || xt_3 - im.call( - im.call("as_fieldop")( - im.lambda_("xt_0_p_1", "xn_2_p_3")( - im.plus(im.deref("xt_0_p_1"), im.deref("xn_2_p_3")) - ), - edge_k_domain, - ) - )( - # arg: `xt_0_p_1` - im.call( - im.call("as_fieldop")( - im.lambda_("xt_0", "xn_1")( - im.plus(im.deref("xt_0"), im.deref("xn_1")) - ), - edge_k_domain, - ) - )( - shift_builder( # arg: `xt_0` - "u_vert", 0, "primal_normal_vert_v1", 0 - ), - shift_builder( # arg: `xt_1` - "v_vert", 0, "primal_normal_vert_v2", 0 - ), - ), - # end arg: `xt_0_p_1` - # arg: `xt_2_p_3` - im.call( - im.call("as_fieldop")( - im.lambda_("xt_2", "xn_3")( - im.plus(im.deref("xt_2"), im.deref("xn_3")) - ), - edge_k_domain, - ) - )( - shift_builder( # arg: `xt_2` - "u_vert", 1, "primal_normal_vert_v1", 1 - ), - shift_builder( # arg: `xt_3` - "v_vert", 1, "primal_normal_vert_v2", 1 - ), - ), - # end arg: `xt_2_p_3` - ), - # end arg: `xt` - # arg: `z_nabla2_e2` - im.call( - im.call("as_fieldop")( - im.lambda_("z_nabla2_e", "const_2")( - im.multiplies_( - im.deref("z_nabla2_e"), im.deref("const_2") - ) - ), - edge_k_domain, - ) - )( - # arg: `z_nabla2_e` - "z_nabla2_e", - # arg: `const_2` - 2.0, - ), - ), - # end arg: `T` - ), - # end arg: `T_ell2` - ), - # end arg: `NpT` - # arg: `const_4` - 4.0, - ), - domain=edge_k_domain, - target=itir.SymRef(id="nab4"), - ) - ], - ) - - return nabla4prog - - -def verify_nabla4( - version: str, -): - num_edges = NbEdges - num_vertices = NbVertices - num_k_levels = 10 - - if version == "fieldview": - nabla4prog = build_nambla4_gtir_fieldview( - num_edges=num_edges, - num_k_levels=num_k_levels, - ) - - elif version == "inline": - raise NotImplementedError("Inline version is no longer supported.") - - else: - raise ValueError(f"The version `{version}` is now known.") - - offset_provider = { - "E2C2V": E2C2V_connectivity, - "E2ECV": E2ECV_connectivity, - } - - u_vert = np.random.rand(num_vertices, num_k_levels) - v_vert = np.random.rand(num_vertices, num_k_levels) - primal_normal_vert_v1 = np.random.rand(num_edges * 4) - primal_normal_vert_v2 = np.random.rand(num_edges * 4) - - z_nabla2_e = np.random.rand(num_edges, num_k_levels) - inv_vert_vert_length = np.random.rand(num_edges) - inv_primal_edge_length = np.random.rand(num_edges) - nab4 = np.empty((num_edges, num_k_levels), dtype=np.float64) - - sdfg = dace_backend.build_sdfg_from_gtir(nabla4prog, offset_provider) - - call_args = dict( - z_nabla2_e=z_nabla2_e, - inv_vert_vert_length=inv_vert_vert_length, - inv_primal_edge_length=inv_primal_edge_length, - nab4=nab4, - num_edges=num_edges, - num_k_levels=num_k_levels, - u_vert=u_vert, - v_vert=v_vert, - primal_normal_vert_v1=primal_normal_vert_v1, - primal_normal_vert_v2=primal_normal_vert_v2, - ) - call_args.update({f"connectivity_{k}": v.table.copy() for k, v in offset_provider.items()}) - - SYMBS = make_syms(**call_args) - - org_sdfg = copy.deepcopy(sdfg) - - for i in range(2): - sdfg = copy.deepcopy(org_sdfg) - if i != 0: - fw_trans.gt_auto_optimize(sdfg) - - sdfg.view() - - sdfg(**call_args, **SYMBS) - ref = nabla4_np(**call_args, **offset_provider) - assert np.allclose(ref, nab4) - nab4[:] = 0 - if i == 0: - print(f"Version({version} | unoptimized): Succeeded") - else: - print(f"Version({version} | optimized): Succeeded") - - -if "__main__" == __name__: - verify_nabla4("fieldview") diff --git a/my_playground/simple_icon_mesh.py b/my_playground/simple_icon_mesh.py deleted file mode 100644 index de4b8c2a35..0000000000 --- a/my_playground/simple_icon_mesh.py +++ /dev/null @@ -1,125 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Reimplementation of the simple icon grid for testing.""" - -import numpy as np - -from gt4py.next.common import DimensionKind -from gt4py.next.ffront.fbuiltins import Dimension, FieldOffset -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider -from gt4py.next.type_system import type_specifications as ts - -IDim = Dimension("IDim") -JDim = Dimension("JDim") - -KDim = Dimension("K", kind=DimensionKind.VERTICAL) -EdgeDim = Dimension("Edge") -CellDim = Dimension("Cell") -VertexDim = Dimension("Vertex") -ECVDim = Dimension("ECV") -E2C2VDim = Dimension("E2C2V", DimensionKind.LOCAL) - -E2ECV = FieldOffset("E2ECV", source=ECVDim, target=(EdgeDim, E2C2VDim)) -E2C2V = FieldOffset("E2C2V", source=VertexDim, target=(EdgeDim, E2C2VDim)) - -Koff = FieldOffset("Koff", source=KDim, target=(KDim,)) - -NbCells = 18 -NbEdges = 27 -NbVertices = 9 - -SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) - -e2c2v_table = np.asarray( - [ - [0, 1, 4, 6], # 0 - [0, 4, 1, 3], # 1 - [0, 3, 4, 2], # 2 - [1, 2, 5, 7], # 3 - [1, 5, 2, 4], # 4 - [1, 4, 5, 0], # 5 - [2, 0, 3, 8], # 6 - [2, 3, 5, 0], # 7 - [2, 5, 1, 3], # 8 - [3, 4, 0, 7], # 9 - [3, 7, 4, 6], # 10 - [3, 6, 7, 5], # 11 - [4, 5, 8, 1], # 12 - [4, 8, 7, 5], # 13 - [4, 7, 3, 8], # 14 - [5, 3, 6, 2], # 15 - [6, 5, 3, 8], # 16 - [8, 5, 6, 4], # 17 - [6, 7, 3, 1], # 18 - [6, 1, 7, 0], # 19 - [6, 0, 1, 8], # 20 - [7, 8, 2, 4], # 21 - [7, 2, 8, 1], # 22 - [7, 1, 2, 6], # 23 - [8, 6, 0, 5], # 24 - [8, 0, 6, 2], # 25 - [8, 2, 0, 6], # 26 - ] -) - -E2C2V_connectivity = NeighborTableOffsetProvider( - # I do not understand the ordering here? Why is `Edge` the source if you read - # it right to left? - e2c2v_table, - EdgeDim, - VertexDim, - e2c2v_table.shape[1], -) - - -def _make_E2ECV_connectivity(E2C2V_connectivity: NeighborTableOffsetProvider): - # Implementation is adapted from icon's `_get_offset_provider_for_sparse_fields()` - e2c2v_table = E2C2V_connectivity.table - t = np.arange(e2c2v_table.shape[0] * e2c2v_table.shape[1]).reshape(e2c2v_table.shape) - return NeighborTableOffsetProvider(t, EdgeDim, ECVDim, t.shape[1]) - - -E2ECV_connectivity = _make_E2ECV_connectivity(E2C2V_connectivity) - - -def dace_strides( - array: np.ndarray, - name: None | str = None, -) -> tuple[int, ...] | dict[str, int]: - if not hasattr(array, "strides"): - return {} - strides = array.strides - if hasattr(array, "itemsize"): - strides = tuple(stride // array.itemsize for stride in strides) - if name is not None: - strides = {f"__{name}_stride_{i}": stride for i, stride in enumerate(strides)} - return strides - - -def dace_shape( - array: np.ndarray, - name: str, -) -> dict[str, int]: - if not hasattr(array, "shape"): - return {} - return {f"__{name}_size_{i}": size for i, size in enumerate(array.shape)} - - -def make_syms(**kwargs: np.ndarray) -> dict[str, int]: - SYMBS = {} - for name, array in kwargs.items(): - SYMBS.update(**dace_shape(array, name)) - SYMBS.update(**dace_strides(array, name)) - return SYMBS From 0fddb8d3efdf78e3c3242cfcba50048a1f9f26b9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 15 Jul 2024 08:50:24 +0200 Subject: [PATCH 157/235] Added a transformation to bring the map iteration indexes in the correct order. --- .../transformations/__init__.py | 4 + .../transformations/map_orderer.py | 125 ++++++++++++++++++ 2 files changed, 129 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 5bcddf085c..0451c205ec 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -13,6 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from .auto_opt import dace_auto_optimize, gt_auto_optimize +from .map_orderer import MapIterationOrder +from .map_promoter import SerialMapPromoter from .map_seriall_fusion import SerialMapFusion @@ -20,4 +22,6 @@ "dace_auto_optimize", "gt_auto_optimize", "SerialMapFusion", + "SerialMapPromoter", + "MapIterationOrder", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py new file mode 100644 index 0000000000..20bb46ee47 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -0,0 +1,125 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Sequence, Union + +import dace +from dace import properties, transformation +from dace.sdfg import SDFG, SDFGState, nodes + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util + + +@properties.make_properties +class MapIterationOrder(transformation.SingleStateTransformation): + """Modify the order of the iteration variables. + + The transformation modifies the order in which the map variables are processed. + This transformation is restricted in the sense, that it is only possible + to set the "first map" variable, i.e. the one that is associated with the + `x` dimension in a thread block and where the input memory should have stride 1. + + If the transformation modifies then the map variable corresponding to + `self.leading_dim` will be at the correct place, the order of all other + map variable is unspecific. Otherwise the map is unmodified. + + Args: + leading_dim: A GT4Py dimension object that identifies the dimension that + should be used. + + Note: + This transformation does follow the rules outlines [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + + Todo: + - Extend that different dimensions can be specified to be leading + dimensions, with some priority mechanism. + - Maybe also process the parameters to bring them in a canonical order. + """ + + leading_dim = properties.Property( + dtype=gtx_common.Dimension, + allow_none=True, + desc="Dimension that should become the leading dimension.", + ) + + map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + def __init__( + self, + leading_dim: Optional[gtx_common.Dimension] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if leading_dim is not None: + self.leading_dim = leading_dim + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the map can be reordered. + + Essentially the function checks if the selected dimension is inside the map, + and if so, if it is on the right place. + """ + + if self.leading_dim is None: + return False + + map_entry: nodes.MapEntry = self.map_entry + map_params: Sequence[str] = map_entry.map.params + map_var: str = dace_fieldview_util.get_map_variable(self.leading_dim) + + if map_var not in map_params: + return False + if map_params[-1] == map_var: # Already at the end; `-1` is correct! + return False + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the actual parameter reordering. + + The function will move the map variable, that corresponds to + `self.leading_dim` at the end. It will not put it at the front, because + DaCe.codegen processes the variables in revers order and smashes all + the excess parameter into the last CUDA dimension. + """ + map_entry: nodes.MapEntry = self.map_entry + map_params: list[str] = map_entry.map.params + map_var: str = dace_fieldview_util.get_map_variable(self.leading_dim) + + # This implementation will just swap the variable that is currently the last + # with the one that should be the last. + dst_idx = -1 + src_idx = map_params.index(map_var) + + for to_process in [ + map_entry.map.params, + map_entry.map.range.ranges, + map_entry.map.range.tile_sizes, + ]: + assert isinstance(to_process, list) + src_val = to_process[src_idx] + dst_val = to_process[dst_idx] + to_process[dst_idx] = src_val + to_process[src_idx] = dst_val From c4f64a419ef8dc1aad0e06eccf9cfee58c0919ba Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 15 Jul 2024 09:08:45 +0200 Subject: [PATCH 158/235] Updated the auto optimizer. --- .../transformations/auto_opt.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 8b5de0ce31..8969dc6d3e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -20,8 +20,10 @@ from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize -from .map_promoter import SerialMapPromoter -from .map_seriall_fusion import SerialMapFusion +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) def dace_auto_optimize( @@ -89,6 +91,7 @@ def gt_simplify( def gt_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, + leading_dim: Optional[gtx_common.Dimension] = None, **kwargs: Any, ) -> dace.SDFG: """Performs GT4Py specific optimizations in place. @@ -113,11 +116,11 @@ def gt_auto_optimize( if kwargs.get("use_dace_fusion", False): sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) else: - xform = SerialMapFusion() + xform = gtx_transformations.SerialMapFusion() sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) sdfg.apply_transformations_repeated( - [SerialMapPromoter(promote_horizontal=False)], + [gtx_transformations.SerialMapPromoter(promote_horizontal=False)], validate=True, validate_all=True, ) @@ -135,6 +138,15 @@ def gt_auto_optimize( else: raise RuntimeWarning("Optimization of the SDFG did not converged.") + # After we have optimized the SDFG as good as we can, we will now do some + # lower level optimization. + if leading_dim is not None: + sdfg.apply_transformations_once_everywhere( + gtx_transformations.MapIterationOrder( + leading_dim=leading_dim, + ) + ) + # These are the part that we copy from DaCe built in auto optimization. dace_aoptimize.set_fast_implementations(sdfg, device) dace_aoptimize.make_transients_persistent(sdfg, device) From 6143b95af5a3a9e4be27542ec8f8e3dd1f26c4a8 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 17 Jul 2024 11:15:07 +0200 Subject: [PATCH 159/235] Made the `gt_simplify()` function aviable. --- .../runners/dace_fieldview/transformations/__init__.py | 3 ++- .../runners/dace_fieldview/transformations/auto_opt.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 0451c205ec..d9422cc42f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from .auto_opt import dace_auto_optimize, gt_auto_optimize +from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_simplify from .map_orderer import MapIterationOrder from .map_promoter import SerialMapPromoter from .map_seriall_fusion import SerialMapFusion @@ -21,6 +21,7 @@ __all__ = [ "dace_auto_optimize", "gt_auto_optimize", + "gt_simplify", "SerialMapFusion", "SerialMapPromoter", "MapIterationOrder", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 8969dc6d3e..24f31db939 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -149,8 +149,8 @@ def gt_auto_optimize( # These are the part that we copy from DaCe built in auto optimization. dace_aoptimize.set_fast_implementations(sdfg, device) - dace_aoptimize.make_transients_persistent(sdfg, device) dace_aoptimize.move_small_arrays_to_stack(sdfg) + dace_aoptimize.make_transients_persistent(sdfg, device) # Final simplify gt_simplify(sdfg) From 00aa64cb2792ceae0ce49ccee54b8814b391911d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 17 Jul 2024 13:55:25 +0200 Subject: [PATCH 160/235] Fixed a porting bug. --- .../dace_fieldview/transformations/map_seriall_fusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py index 7eafc67e6d..b4fc6b5985 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -303,7 +303,7 @@ def handle_intermediate_set( assert (pre_exit_edge.data.subset.num_elements() > 1) or all( x == 1 for x in new_inter_shape ) - isScalar: bool = False + is_scalar = False new_inter_name, new_inter_desc = sdfg.add_transient( new_inter_name, shape=new_inter_shape, @@ -420,7 +420,7 @@ def handle_intermediate_set( consumer_edge = consumer_tree.edge assert consumer_edge.data.data == inter_name consumer_edge.data.data = new_inter_name - if isScalar: + if is_scalar: consumer_edge.data.src_subset = "0" else: if consumer_edge.data.subset is not None: From b67f0c08c36b3258e68cb6324f848e325883f973 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 17 Jul 2024 13:56:07 +0200 Subject: [PATCH 161/235] Fixed an edge case in the computation of the output partition if teh intermediate node is a scalar only. --- .../dace_fieldview/transformations/map_fusion_helper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 6ef4e473d2..138164bdb4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -466,7 +466,9 @@ def partition_first_outputs( consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) for consumer_node, feed_edge in consumers: # TODO(phimuell): Improve this approximation. - if feed_edge.data.num_elements() == intermediate_size: + if ( + intermediate_size != 1 + ) and feed_edge.data.num_elements() == intermediate_size: return None if consumer_node is map_entry_2: # Dynamic map range. return None From b447c2aa639833301ad8aebccc5472271af7f296 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 17 Jul 2024 13:58:38 +0200 Subject: [PATCH 162/235] Added a map promoter that is able to promote trivial maps that are generated during the GPU mode. --- .../transformations/__init__.py | 3 +- .../transformations/map_promoter.py | 84 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index d9422cc42f..a84d4bcfc8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -14,7 +14,7 @@ from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_simplify from .map_orderer import MapIterationOrder -from .map_promoter import SerialMapPromoter +from .map_promoter import SerialMapPromoter, SerialMapPromoterGPU from .map_seriall_fusion import SerialMapFusion @@ -24,5 +24,6 @@ "gt_simplify", "SerialMapFusion", "SerialMapPromoter", + "SerialMapPromoterGPU", "MapIterationOrder", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index da10af434a..51d272b77e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import copy from typing import Any, Mapping, Optional, Sequence, Union import dace @@ -19,6 +20,12 @@ from dace.sdfg import SDFG, SDFGState, nodes +__all__ = [ + "SerialMapPromoter", + "SerialMapPromoterGPU", +] + + @properties.make_properties class BaseMapPromoter(transformation.SingleStateTransformation): """Base transformation to add certain missing dimension to a map. @@ -300,3 +307,80 @@ def source_map( # The second map will be promoted, so the first is used as source return state.entry_node(self.exit_first_map) + + +@properties.make_properties +class SerialMapPromoterGPU(transformation.SingleStateTransformation): + """Serial Map promoter for empty Maps in case of trivial Maps. + + In CPU mode a Tasklet can be outside of a map, however, this is not + possible in CPU mode. For this reason DaCe wraps every such Tasklet + in a trivial Map. + This function will look for such Maps and promote them, such that they + can be fused with downstream maps. + + Note: + This transformation must be run after the GPU Transformation. + """ + + # Pattern Matching + map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) + access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the promotion is possible. + + The function tests: + - If the top map is a trivial map. + - If a valid partition exists that can be fused at all. + """ + from .map_seriall_fusion import SerialMapFusion + + map_exit_1: nodes.MapExit = self.map_exit1 + map_1: nodes.Map = map_exit_1.map + map_entry_2: nodes.MapEntry = self.map_entry2 + + # Check if the first map is trivial. + if len(map_1.params) != 1: + return False + if map_1.range.num_elements() != 1: + return False + + # Check if the partition exists, if not promotion to fusing is pointless. + # TODO(phimuell): Find the proper way of doing it. + serial_fuser = SerialMapFusion() + output_partition = serial_fuser.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map Promoting. + + The function essentially copies the parameters and the ranges from the + bottom map to the top one. + """ + map_1: nodes.Map = self.map_exit1.map + map_2: nodes.Map = self.map_entry2.map + + map_1.params = copy.deepcopy(map_2.params) + map_1.range = copy.deepcopy(map_2.range) + + return From 090f08d0972f10f8b223946305ffc56fe41a8ba1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 17 Jul 2024 14:02:09 +0200 Subject: [PATCH 163/235] Added a function to turn an SDFG into one that runs on GPU. Essentially it is a modification of the GPU translation that ships with DaCe. --- .../transformations/__init__.py | 3 +- .../transformations/auto_opt.py | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index a84d4bcfc8..0cc65c4f8b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_simplify +from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_gpu_transformation, gt_simplify from .map_orderer import MapIterationOrder from .map_promoter import SerialMapPromoter, SerialMapPromoterGPU from .map_seriall_fusion import SerialMapFusion @@ -22,6 +22,7 @@ "dace_auto_optimize", "gt_auto_optimize", "gt_simplify", + "gt_gpu_transformation", "SerialMapFusion", "SerialMapPromoter", "SerialMapPromoterGPU", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 24f31db939..263987801f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -88,6 +88,49 @@ def gt_simplify( ).apply_pass(sdfg, {}) +def gt_gpu_transformation( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Transform an SDFG into an GPU SDFG. + + The transformations are done in place. + The function will roughly do the same: + - Move all arrays used as input to the GPU. + - Apply the standard DaCe GPU transformation. + - Run `gt_simplify()` (recommended by the DaCe documentation). + - Try to promote trivial maps. + """ + + # Turn all global arrays (which we identify as input) into GPU memory. + # This way the GPU transformation will not create this copying stuff. + for desc in sdfg.arrays.values(): + if desc.transient: + continue + if not isinstance(desc, dace.data.Array): + continue + desc.storage = dace.dtypes.StorageType.GPU_Global + + # Now turn it into a GPU SDFG + sdfg.apply_gpu_transformations( + validate=validate, + validate_all=validate_all, + simplify=False, + ) + + # The documentation recommend to run simplify afterwards + gtx_transformations.gt_simplify(sdfg) + + # Start to promote the maps. + sdfg.apply_transformations_repeated( + [gtx_transformations.SerialMapPromoterGPU()], + validate=validate, + validate_all=validate_all, + ) + return sdfg + + def gt_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, From 04dd63a8a266f6ad314597a336840d819ebd05dc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 17 Jul 2024 14:03:31 +0200 Subject: [PATCH 164/235] Updated the auto optimizer to handle GPU cases. --- .../transformations/auto_opt.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 263987801f..9bc9dc670b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -135,6 +135,9 @@ def gt_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, leading_dim: Optional[gtx_common.Dimension] = None, + gpu: bool = False, + validate: bool = True, + validate_all: bool = False, **kwargs: Any, ) -> dace.SDFG: """Performs GT4Py specific optimizations in place. @@ -142,6 +145,8 @@ def gt_auto_optimize( Args: sdfg: The SDFG that should ve optimized in place. device: The device for which we should optimize. + leading_dim: Leading dimension, where the stride is expected to be 1. + gpu: Optimize for GPU. """ with dace.config.temporary_config(): @@ -153,6 +158,9 @@ def gt_auto_optimize( # Compute the SDFG to see if something has changed. sdfg_hash = sdfg.hash_sdfg() + # Have GPU transformations been performed. + have_gpu_transformations_been_run = False + for _ in range(100): # Due to the structure of the generated SDFG getting rid of Maps, # i.e. fusing them, is the best we can currently do. @@ -160,12 +168,14 @@ def gt_auto_optimize( sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) else: xform = gtx_transformations.SerialMapFusion() - sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) + sdfg.apply_transformations_repeated( + [xform], validate=validate, validate_all=validate_all + ) sdfg.apply_transformations_repeated( [gtx_transformations.SerialMapPromoter(promote_horizontal=False)], - validate=True, - validate_all=True, + validate=validate, + validate_all=validate_all, ) # Maybe running the fusion has opened more opportunities. @@ -176,8 +186,12 @@ def gt_auto_optimize( sdfg_hash = sdfg.hash_sdfg() if old_sdfg_hash == sdfg_hash: + if gpu and (not have_gpu_transformations_been_run): + gt_gpu_transformation(sdfg, validate=validate, validate_all=validate_all) + have_gpu_transformations_been_run = True + sdfg_hash = sdfg.hash_sdfg() + continue break - else: raise RuntimeWarning("Optimization of the SDFG did not converged.") From bb34f44877824a4ca4799367def532fd82e50f42 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 18 Jul 2024 13:26:46 +0200 Subject: [PATCH 165/235] Reorganized the GPU stuff. There is now also a transformation to set the block size it is not perfect but it should remove some dace warnings. --- .../transformations/__init__.py | 12 +- .../transformations/auto_opt.py | 58 +--- .../transformations/gpu_utils.py | 276 ++++++++++++++++++ .../transformations/map_promoter.py | 79 ----- 4 files changed, 299 insertions(+), 126 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 0cc65c4f8b..361d0ffb7f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,9 +12,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_gpu_transformation, gt_simplify +from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_simplify +from .gpu_utils import ( + GPUSetBlockSize, + SerialMapPromoterGPU, + gt_gpu_transformation, + gt_set_gpu_blocksize, +) from .map_orderer import MapIterationOrder -from .map_promoter import SerialMapPromoter, SerialMapPromoterGPU +from .map_promoter import SerialMapPromoter from .map_seriall_fusion import SerialMapFusion @@ -23,8 +29,10 @@ "gt_auto_optimize", "gt_simplify", "gt_gpu_transformation", + "gt_set_gpu_blocksize", "SerialMapFusion", "SerialMapPromoter", "SerialMapPromoterGPU", "MapIterationOrder", + "GPUSetBlockSize", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 9bc9dc670b..dbf0dffb27 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -14,7 +14,7 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Optional +from typing import Any, Optional, Sequence import dace from dace.transformation import dataflow as dace_dataflow @@ -88,54 +88,12 @@ def gt_simplify( ).apply_pass(sdfg, {}) -def gt_gpu_transformation( - sdfg: dace.SDFG, - validate: bool = True, - validate_all: bool = False, -) -> dace.SDFG: - """Transform an SDFG into an GPU SDFG. - - The transformations are done in place. - The function will roughly do the same: - - Move all arrays used as input to the GPU. - - Apply the standard DaCe GPU transformation. - - Run `gt_simplify()` (recommended by the DaCe documentation). - - Try to promote trivial maps. - """ - - # Turn all global arrays (which we identify as input) into GPU memory. - # This way the GPU transformation will not create this copying stuff. - for desc in sdfg.arrays.values(): - if desc.transient: - continue - if not isinstance(desc, dace.data.Array): - continue - desc.storage = dace.dtypes.StorageType.GPU_Global - - # Now turn it into a GPU SDFG - sdfg.apply_gpu_transformations( - validate=validate, - validate_all=validate_all, - simplify=False, - ) - - # The documentation recommend to run simplify afterwards - gtx_transformations.gt_simplify(sdfg) - - # Start to promote the maps. - sdfg.apply_transformations_repeated( - [gtx_transformations.SerialMapPromoterGPU()], - validate=validate, - validate_all=validate_all, - ) - return sdfg - - def gt_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, leading_dim: Optional[gtx_common.Dimension] = None, gpu: bool = False, + gpu_block_size: Optional[Sequence[int | str] | str] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -147,6 +105,7 @@ def gt_auto_optimize( device: The device for which we should optimize. leading_dim: Leading dimension, where the stride is expected to be 1. gpu: Optimize for GPU. + gpu_block_size: The block size that should be used for the GPU. """ with dace.config.temporary_config(): @@ -187,7 +146,12 @@ def gt_auto_optimize( if old_sdfg_hash == sdfg_hash: if gpu and (not have_gpu_transformations_been_run): - gt_gpu_transformation(sdfg, validate=validate, validate_all=validate_all) + gtx_transformations.gt_gpu_transformation( + sdfg, + validate=validate, + validate_all=validate_all, + gpu_block_size=None, # Explicitly not set here. + ) have_gpu_transformations_been_run = True sdfg_hash = sdfg.hash_sdfg() continue @@ -204,6 +168,10 @@ def gt_auto_optimize( ) ) + # After everything we set the GPU block size. + if gpu_block_size is not None: + gtx_transformations.gt_set_gpu_blocksize(sdfg, gpu_block_size) + # These are the part that we copy from DaCe built in auto optimization. dace_aoptimize.set_fast_implementations(sdfg, device) dace_aoptimize.move_small_arrays_to_stack(sdfg) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py new file mode 100644 index 0000000000..dab7e78f70 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -0,0 +1,276 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Functions for turning an SDFG into a GPU SDFG.""" + +import copy +from typing import Any, Optional, Sequence, Union + +import dace +import numpy as np +from dace import properties, transformation +from dace.sdfg import SDFG, SDFGState, nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +__all__ = [ + "SerialMapPromoterGPU", + "GPUSetBlockSize", + "gt_gpu_transformation", + "gt_set_gpu_blocksize", +] + + +def gt_gpu_transformation( + sdfg: dace.SDFG, + promote_serial_maps: bool = True, + gpu_block_size: Optional[Sequence[int | str] | str] = None, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Transform an SDFG into an GPU SDFG. + + The transformations are done in place. + The function will roughly do the same: + - Move all arrays used as input to the GPU. + - Apply the standard DaCe GPU transformation. + - Run `gt_simplify()` (recommended by the DaCe documentation). + - Try to promote trivial maps. + - If given set the GPU block size. + """ + + # Turn all global arrays (which we identify as input) into GPU memory. + # This way the GPU transformation will not create this copying stuff. + for desc in sdfg.arrays.values(): + if desc.transient: + continue + if not isinstance(desc, dace.data.Array): + continue + desc.storage = dace.dtypes.StorageType.GPU_Global + + # Now turn it into a GPU SDFG + sdfg.apply_gpu_transformations( + validate=validate, + validate_all=validate_all, + simplify=False, + ) + + # The documentation recommend to run simplify afterwards + gtx_transformations.gt_simplify(sdfg) + + # Start to promote the maps. + if promote_serial_maps: + sdfg.apply_transformations_repeated( + [gtx_transformations.SerialMapPromoterGPU()], + validate=validate, + validate_all=validate_all, + ) + + # Set the GPU block size if it is known. + if gpu_block_size is not None: + gt_set_gpu_blocksize(sdfg, gpu_block_size) + + return sdfg + + +def gt_set_gpu_blocksize( + sdfg: dace.SDFG, + gpu_block_size: Optional[Sequence[int | str] | str], +) -> Any: + """Set the block sizes of GPU Maps. + + Args: + sdfg: The SDFG to process. + gpu_block_size: The block size to use. + """ + return sdfg.apply_transformations_once_everywhere([GPUSetBlockSize(block_size=gpu_block_size)]) + + +def _gpu_block_parser( + self: "GPUSetBlockSize", + val: Any, +) -> None: + """Used by the setter ob `GPUSetBlockSize.block_size`.""" + org_val = val + if isinstance(val, tuple): + pass + elif isinstance(val, list): + val = tuple(val) + elif isinstance(val, str): + val = tuple(x.replace(" ", "") for x in val.split(",")) + if len(val) == 1: + val = (*val, 1, 1) + elif len(val) == 2: + val = (*val, 1) + elif len(val) != 3: + raise ValueError(f"Can not parse block size '{org_val}': wrong length") + assert all(isinstance(x, (str, int, np.integer)) for x in val) + self._block_size = [int(x) for x in val] + + +def _gpu_block_getter( + self: "GPUSetBlockSize", +) -> tuple[int, int, int]: + """Used as getter in the `GPUSetBlockSize.block_size` property.""" + assert isinstance(self._block_size, tuple) and len(self.block_size) == 3 + assert all(isinstance(x, int) for x in self._block_size) + return self._block_size + + +@properties.make_properties +class GPUSetBlockSize(transformation.SingleStateTransformation): + """Sets the GPU block size on GPU Maps. + + Todo: + Depending on the number of dimensions of a map, there should be different sources. + """ + + block_size = properties.Property( + dtype=None, + allow_none=True, + setter=_gpu_block_parser, + getter=_gpu_block_getter, + desc="Size of the block size a GPU Map should have.", + ) + + map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + def __int__( + self, + block_size: Sequence[int | str] | str, + ) -> None: + self.block_size = block_size + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the block size can be set. + + The function tests: + - If the block size of the map is already set. + - If the map is at global scope. + - If if the schedule of the map is correct. + """ + + scope = graph.scope_dict() + if scope[self.map_entry] is not None: + return False + if self.map_entry.map.schedule not in dace.dtypes.GPU_SCHEDULES: + return False + if self.map_entry.map.gpu_block_size is not None: + return False + return True + + def apply( + self, + graph: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> None: + """Sets the block size.""" + self.map_entry.map.gpu_block_size = self.block_size + + +@properties.make_properties +class SerialMapPromoterGPU(transformation.SingleStateTransformation): + """Serial Map promoter for empty Maps in case of trivial Maps. + + In CPU mode a Tasklet can be outside of a map, however, this is not + possible in CPU mode. For this reason DaCe wraps every such Tasklet + in a trivial Map. + This function will look for such Maps and promote them, such that they + can be fused with downstream maps. + + Note: + This transformation must be run after the GPU Transformation. + """ + + # Pattern Matching + map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) + access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the promotion is possible. + + The function tests: + - If the top map is a trivial map. + - If a valid partition exists that can be fused at all. + """ + from .map_seriall_fusion import SerialMapFusion + + map_exit_1: nodes.MapExit = self.map_exit1 + map_1: nodes.Map = map_exit_1.map + map_entry_2: nodes.MapEntry = self.map_entry2 + + # Check if the first map is trivial. + if len(map_1.params) != 1: + return False + if map_1.range.num_elements() != 1: + return False + + # Check if it is a GPU map + if map_1.schedule not in [ + dace.dtypes.ScheduleType.GPU_Global, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + + # Check if the partition exists, if not promotion to fusing is pointless. + # TODO(phimuell): Find the proper way of doing it. + serial_fuser = SerialMapFusion(only_toplevel_maps=True) + output_partition = serial_fuser.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map Promoting. + + The function essentially copies the parameters and the ranges from the + bottom map to the top one. + """ + map_1: nodes.Map = self.map_exit1.map + map_2: nodes.Map = self.map_entry2.map + + map_1.params = copy.deepcopy(map_2.params) + map_1.range = copy.deepcopy(map_2.range) + + return diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 51d272b77e..b63eda5445 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -12,7 +12,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import copy from typing import Any, Mapping, Optional, Sequence, Union import dace @@ -22,7 +21,6 @@ __all__ = [ "SerialMapPromoter", - "SerialMapPromoterGPU", ] @@ -307,80 +305,3 @@ def source_map( # The second map will be promoted, so the first is used as source return state.entry_node(self.exit_first_map) - - -@properties.make_properties -class SerialMapPromoterGPU(transformation.SingleStateTransformation): - """Serial Map promoter for empty Maps in case of trivial Maps. - - In CPU mode a Tasklet can be outside of a map, however, this is not - possible in CPU mode. For this reason DaCe wraps every such Tasklet - in a trivial Map. - This function will look for such Maps and promote them, such that they - can be fused with downstream maps. - - Note: - This transformation must be run after the GPU Transformation. - """ - - # Pattern Matching - map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) - access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) - - @classmethod - def expressions(cls) -> Any: - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] - - def can_be_applied( - self, - graph: Union[SDFGState, SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the promotion is possible. - - The function tests: - - If the top map is a trivial map. - - If a valid partition exists that can be fused at all. - """ - from .map_seriall_fusion import SerialMapFusion - - map_exit_1: nodes.MapExit = self.map_exit1 - map_1: nodes.Map = map_exit_1.map - map_entry_2: nodes.MapEntry = self.map_entry2 - - # Check if the first map is trivial. - if len(map_1.params) != 1: - return False - if map_1.range.num_elements() != 1: - return False - - # Check if the partition exists, if not promotion to fusing is pointless. - # TODO(phimuell): Find the proper way of doing it. - serial_fuser = SerialMapFusion() - output_partition = serial_fuser.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - if output_partition is None: - return False - - return True - - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: - """Performs the Map Promoting. - - The function essentially copies the parameters and the ranges from the - bottom map to the top one. - """ - map_1: nodes.Map = self.map_exit1.map - map_2: nodes.Map = self.map_entry2.map - - map_1.params = copy.deepcopy(map_2.params) - map_1.range = copy.deepcopy(map_2.range) - - return From 88f52450a9eb3d8396f4b74703e6264e9b7c1ed9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Jul 2024 11:08:43 +0200 Subject: [PATCH 166/235] There is now a k blocking transformation. However it is not yet tested. --- .../transformations/__init__.py | 2 + .../transformations/k_blocking.py | 373 ++++++++++++++++++ 2 files changed, 375 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 361d0ffb7f..e297cd2237 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -19,6 +19,7 @@ gt_gpu_transformation, gt_set_gpu_blocksize, ) +from .k_blocking import KBlocking from .map_orderer import MapIterationOrder from .map_promoter import SerialMapPromoter from .map_seriall_fusion import SerialMapFusion @@ -35,4 +36,5 @@ "SerialMapPromoterGPU", "MapIterationOrder", "GPUSetBlockSize", + "KBlocking", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py new file mode 100644 index 0000000000..f6abc56702 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -0,0 +1,373 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import copy +import functools +from typing import Any, Union + +import dace +from dace import properties, subsets, transformation +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes +from dace.transformation import helpers + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util + + +@properties.make_properties +class KBlocking(transformation.SingleStateTransformation): + """Performs K blocking. + + This transformation takes a multidimensional map and performs k blocking on + one particular dimension, which is identified by `block_dim`, which we also + call `k`. + All dimensions except `k` are unaffected and the transformation replaces it + with `kk` and the range `0:N:B`, where `N` is the original end of the + transformation and `B` is the block size, passed as `blocking_size`. + The transformation will then add an inner sequential map with one + dimension `k = kk:(kk+B)`. + + Furthermore, the map will inspect the neighbours of the old, or outer map. + If the node does not depend on the blocked dimension, the node will be put + between the two maps, thus its content will only be computed once. + """ + + blocking_size = properties.Property( + dtype=int, + allow_none=True, + desc="Size of the inner k Block.", + ) + block_dim = properties.Property( + dtype=str, + allow_none=True, + desc="Which dimension should be blocked.", + ) + + map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + def __init__( + self, + blocking_size: int, + block_dim: Union[gtx_common.Dimension, str], + ) -> None: + super().__init__() + self.blocking_size = blocking_size + if isinstance(block_dim, str): + pass + elif isinstance(block_dim, gtx_common.Dimension): + block_dim = dace_fieldview_util.get_map_variable(block_dim) + self.block_dim = block_dim + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the map can be blocked. + + For this the map: + - Must contain the block dimension. + - Must not be serial. + """ + map_entry: nodes.MapEntry = self.map_entry + map_params: list[str] = map_entry.map.params + map_range: subsets.Range = map_entry.map.range + block_var: str = self.block_dim + + scope = graph.scope_dict() + if scope[map_entry] is not None: + return False + if block_var not in map_entry.map.params: + return False + if map_entry.map.schedule == dace.dtypes.ScheduleType.Sequential: + return False + if self.partition_map_output(map_entry, block_var, graph, sdfg) is None: + return False + if map_range[map_params.index(block_var)][2] != 1: + return False + + return True + + def apply( + self, + graph: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> None: + """Performs the blocking transformation.""" + outer_entry: nodes.MapEntry = self.map_entry + outer_exit: nodes.MapExit = graph.exit_node(outer_entry) + outer_map: nodes.Map = outer_entry.map + map_range: subsets.Range = outer_entry.map.range + map_params: list[str] = outer_entry.map.params + + # This is the name of the iterator we coarsen + block_var: str = self.block_dim + block_idx = map_params.index(block_var) + + # This is the name of the iterator that we use in the outer map for the + # blocked dimension + coarse_block_var = "__coarse_" + block_var + + # Now compute the partitions of the nodes. + independent_nodes, dependent_nodes = self.partition_map_output( # type: ignore[misc] # Guaranteed to be not `None`. + outer_entry, block_var, graph, sdfg + ) + + # Now generate the sequential inner map + rng_start = map_range[block_idx][0] + rng_stop = map_range[block_idx][1] + inner_label = f"inner_{outer_map.label}" + inner_range = { + block_var: subsets.Range.from_string( + f"({coarse_block_var} * {self.blocking_size} + {rng_start}):min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" + ) + } + inner_entry, inner_exit = graph.add_map( + name=inner_label, + ndrange=inner_range, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + # Now we modify the properties of the outer map. + coarse_block_range = subsets.Range.from_string( + f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" + ).ranges[0] + outer_map.params[block_idx] = coarse_block_var + outer_map.range[block_idx] = coarse_block_range + + # Contains the independent nodes that are already relocated. + relocated_nodes: set[nodes.Node] = set() + + # Now we iterate over all the output edges of the outer map and rewire them. + # Note that this only handles the map entry. + for out_edge in list(graph.out_edges(outer_entry)): + edge_dst: nodes.Node = out_edge.dst + edge_conn: str = out_edge.src_conn[4:] + + if edge_dst in dependent_nodes: + # This is the simple case as we just ave to rewire the edge + # and make a connection between the outer and inner map. + helpers.redirect_edge( + state=graph, + edge=out_edge, + new_src=inner_entry, + new_src_conn="OUT_" + edge_conn, + ) + graph.add_edge( + outer_entry, + "OUT_" + edge_conn, + inner_entry, + "IN_" + edge_conn, + copy.deepcopy(out_edge.data), + ) + inner_entry.add_in_connector("IN_" + edge_conn) + inner_entry.add_out_connector("OUT_" + edge_conn) + continue + + elif edge_dst in relocated_nodes: + # See `else` case + continue + + else: + # Relocate the node and make the reconnection. + # Different from the dependent case we will handle the node fully, + # i.e. all of its edges will be processed in one go. + relocated_nodes.add(edge_dst) + + # This is the node serving as the storage to store the independent + # data, and is used within the inner loop. + # This prevents the reloading of data. + assert graph.out_degree(edge_dst) == 1 + caching_node: nodes.AccessNode = next(iter(graph.out_edges(edge_dst))) + + for consumer_edge in list(graph.out_edges(caching_node)): + new_map_conn = inner_exit.next_connector() + helpers.redirect_edge( + state=graph, + edge=consumer_edge, + new_dst=inner_entry, + new_dst_conn="IN_" + new_map_conn, + ) + graph.add_edge( + inner_entry, + "OUT_" + new_map_conn, + consumer_edge.dst, + consumer_edge.dst_conn, + copy.deepcopy(consumer_edge.data), + ) + inner_entry.add_in_connector("IN_" + new_map_conn) + inner_entry.add_out_connector("OUT_" + new_map_conn) + continue + + # Now we have to handle the output of the map. + # There is not many to do just reconnect some edges. + for out_edge in list(graph.in_edges(outer_exit)): + edge_conn = out_edge.dst_conn[3:] + helpers.redirect_edge( + state=graph, + edge=out_edge, + new_dst=inner_exit, + new_dst_conn="IN_" + edge_conn, + ) + graph.add_edge( + inner_exit, + "OUT_" + edge_conn, + outer_exit, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + inner_exit.add_in_connector("IN_" + edge_conn) + inner_exit.add_out_connector("OUT_" + edge_conn) + + # TODO(phimuell): Use a less expensive method. + dace.sdfg.propagation.propagate_memlets_state(sdfg, graph) + + def partition_map_output( + self, + map_entry: nodes.MapEntry, + block_param: str, + state: SDFGState, + sdfg: SDFG, + ) -> tuple[set[nodes.Node], set[nodes.Node]] | None: + """Partition the outputs + + This function computes the partition of the intermediate outputs of the map. + It will compute two set: + - The independent outputs `\mathcal{I}`: + These are output nodes, whose output does not depend on the blocked + dimension. These nodes will be relocated between the outer and + inner map. + - The dependent output `\mathcal{D}`: + These are the output nodes, whose output depend on the blocked dimension. + Thus they will not be relocated between the two maps, but remain in the + inner most scope. + + In case the function fails to compute the partition `None` is returned. + + Args: + map_entry: The map entry node. + block_param: The Map variable that should be blocked. + state: The state on which we operate. + sdfg: The SDFG in which we operate on. + + Note: + - Currently this function only considers the input Memlets and the + `used_symbol` properties of a Tasklet. + - Furthermore only the first level is inspected. + """ + block_independent: set[nodes.Node] = set() # `\mathcal{I}` + block_dependent: set[nodes.Node] = set() # `\mathcal{D}` + + # Find all nodes that are adjacent to the map entry. + nodes_to_partition: set[nodes.Node] = {edge.dst for edge in state.out_edges(map_entry)} + + # Now we examine every node and assign them to a set. + # Note that this is only tentative and we will later inspect the + # outputs of the independent node and reevaluate the classification. + for node in nodes_to_partition: + # Filter out all nodes that we can not (yet) handle. + if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)): + return None + + # Check if we have a strange Tasklet or if it uses the symbol inside it. + if isinstance(node, nodes.Tasklet): + if node.side_effects: + return None + if block_param in node.free_symbols: + block_dependent.add(node) + continue + + # Only one output is allowed + # Might be less important for Tasklets but for AccessNodes. + # TODO(phimuell): Lift this restriction. + if state.out_degree(node) != 1: + block_dependent.add(node) + continue + + # Now we have to understand how the node generates its information. + # for this we have to look at all the edges that feed information to it. + edges: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges(node)) + + # If the node gets information from other nodes than the map entry + # we classify it as a dependent node, although there can be situations + # were it could still be an independent node, but figuring this out + # is too complicated. + if any(edge.src is not map_entry for edge in edges): + block_dependent.add(node) + continue + + # If all edges are empty, i.e. they are only needed to keep the + # node inside the scope, consider it as independent. + if all(edge.data.is_empty() for edge in edges): + block_independent.add(node) + continue + + # Now we have to look at the edges individually. + # If this loop ends normally, i.e. it goes into its `else` + # clause then we classify the node as independent. + for edge in edges: + memlet: dace.Memlet = edge.data + src_subset: subsets.Subset = memlet.src_subset + dst_subset: subsets.Subset | None = memlet.dst_subset + edge_desc: dace.data.Data = sdfg.arrays[memlet.data] + edge_desc_size = functools.reduce(lambda a, b: a * b, edge_desc.shape) + + if memlet.is_empty(): + # Empty Memlets do not impose any restrictions. + continue + if src_subset.num_elements() == edge_desc_size: + # The whole source array is consumed, which is not a problem. + continue + + # Now we have to look at the source and destination set of the Memlet. + subsets_to_inspect: list[subsets.Subset] = [src_subset] + if dst_subset is not None: + subsets_to_inspect.append(dst_subset) + + # If a subset needs the block variable then the node is not + # independent from the block variable. + if any(block_param in subset.free_symbols for subset in subsets_to_inspect): + break + else: + # The loop ended normally, thus we did not found anything that made us + # believe that the node is _not_ an independent node. We will later + # also inspect the output, which might reclassify the node + block_independent.add(node) + + # If the node is not independent then it must be dependent, my dear Watson. + if node not in block_independent: + block_dependent.add(node) + + # We now make a last screening of the independent nodes. + for independent_node in list(block_independent): + for out_edge in state.out_edges(independent_node): + if ( + (not isinstance(out_edge.dst, nodes.AccessNode)) + or (state.in_degree(out_edge.dst) != 1) + or (sdfg.desc(out_edge.dst).lifetime != dace.dtypes.AllocationLifetime.Scope) + ): + block_independent.discard(independent_node) + block_dependent.add(independent_node) + break + + assert not block_dependent.intersection(block_independent) + assert (len(block_independent) + len(block_dependent)) == len(nodes_to_partition) + + return (block_independent, block_dependent) From 73a01c1b0d3fde5df96639cacca273a87f61724b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Jul 2024 14:09:46 +0200 Subject: [PATCH 167/235] Made some fixes to the k blocking stuff. --- .../transformations/k_blocking.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index f6abc56702..6764ab113a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -195,7 +195,12 @@ def apply( # data, and is used within the inner loop. # This prevents the reloading of data. assert graph.out_degree(edge_dst) == 1 - caching_node: nodes.AccessNode = next(iter(graph.out_edges(edge_dst))) + if isinstance(edge_dst, nodes.AccessNode): + caching_node: nodes.AccessNode = edge_dst + else: + caching_node = next(iter(graph.out_edges(edge_dst))).dst + assert graph.in_degree(caching_node) == 1 + assert isinstance(caching_node, nodes.AccessNode) for consumer_edge in list(graph.out_edges(caching_node)): new_map_conn = inner_exit.next_connector() @@ -324,7 +329,7 @@ def partition_map_output( # clause then we classify the node as independent. for edge in edges: memlet: dace.Memlet = edge.data - src_subset: subsets.Subset = memlet.src_subset + src_subset: subsets.Subset | None = memlet.src_subset dst_subset: subsets.Subset | None = memlet.dst_subset edge_desc: dace.data.Data = sdfg.arrays[memlet.data] edge_desc_size = functools.reduce(lambda a, b: a * b, edge_desc.shape) @@ -332,14 +337,16 @@ def partition_map_output( if memlet.is_empty(): # Empty Memlets do not impose any restrictions. continue - if src_subset.num_elements() == edge_desc_size: + if memlet.num_elements() == edge_desc_size: # The whole source array is consumed, which is not a problem. continue # Now we have to look at the source and destination set of the Memlet. - subsets_to_inspect: list[subsets.Subset] = [src_subset] + subsets_to_inspect: list[subsets.Subset] = [] if dst_subset is not None: subsets_to_inspect.append(dst_subset) + if src_subset is not None: + subsets_to_inspect.append(src_subset) # If a subset needs the block variable then the node is not # independent from the block variable. @@ -357,11 +364,16 @@ def partition_map_output( # We now make a last screening of the independent nodes. for independent_node in list(block_independent): + if isinstance(independent_node, nodes.AccessNode): + if state.in_degree(independent_node) != 1: + block_independent.discard(independent_node) + block_dependent.add(independent_node) + continue for out_edge in state.out_edges(independent_node): if ( (not isinstance(out_edge.dst, nodes.AccessNode)) or (state.in_degree(out_edge.dst) != 1) - or (sdfg.desc(out_edge.dst).lifetime != dace.dtypes.AllocationLifetime.Scope) + or (out_edge.dst.desc(sdfg).lifetime != dace.dtypes.AllocationLifetime.Scope) ): block_independent.discard(independent_node) block_dependent.add(independent_node) From dd1242cbab10c4feff043f6f8f3b572285689ba1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Jul 2024 14:49:54 +0200 Subject: [PATCH 168/235] Small fix. --- .../runners/dace_fieldview/transformations/gpu_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index dab7e78f70..e312423e47 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -242,7 +242,7 @@ def can_be_applied( # Check if it is a GPU map if map_1.schedule not in [ - dace.dtypes.ScheduleType.GPU_Global, + dace.dtypes.ScheduleType.GPU_Device, dace.dtypes.ScheduleType.GPU_Default, ]: return False From f3798f3d238c3cb5cc7fc33992a74bb81bf1e274 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Jul 2024 15:50:29 +0200 Subject: [PATCH 169/235] Fixed an error. --- .../transformations/k_blocking.py | 59 ++++++++++++++----- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 6764ab113a..1f8ed3d801 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -159,24 +159,45 @@ def apply( # Note that this only handles the map entry. for out_edge in list(graph.out_edges(outer_entry)): edge_dst: nodes.Node = out_edge.dst - edge_conn: str = out_edge.src_conn[4:] if edge_dst in dependent_nodes: # This is the simple case as we just ave to rewire the edge # and make a connection between the outer and inner map. + assert not out_edge.data.is_empty() + edge_conn: str = out_edge.src_conn[4:] + + # Must be before the handling of the modification below + # Note that this will remove the edge from the SDFG. helpers.redirect_edge( state=graph, edge=out_edge, new_src=inner_entry, new_src_conn="OUT_" + edge_conn, ) - graph.add_edge( - outer_entry, - "OUT_" + edge_conn, - inner_entry, - "IN_" + edge_conn, - copy.deepcopy(out_edge.data), - ) + + # In a valid SDFG one one edge can go to an input connector of + # a map (this is their definition), thus we have to adhere to + # this definition as well. + if "IN_" + edge_conn in inner_entry.in_connectors: + # We have found this edge multiple times already. + # To ensure that there is no error, we will create a new + # Memlet that reads the whole array. + piping_edge = next(graph.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) + data_name = piping_edge.data.data + piping_edge.data = dace.Memlet.from_array( + data_name, sdfg.arrays[data_name], piping_edge.data.wcr + ) + + else: + # This is the first time we found this connection. + # so we just create the edge. + graph.add_edge( + outer_entry, + "OUT_" + edge_conn, + inner_entry, + "IN_" + edge_conn, + copy.deepcopy(out_edge.data), + ) inner_entry.add_in_connector("IN_" + edge_conn) inner_entry.add_out_connector("OUT_" + edge_conn) continue @@ -203,7 +224,7 @@ def apply( assert isinstance(caching_node, nodes.AccessNode) for consumer_edge in list(graph.out_edges(caching_node)): - new_map_conn = inner_exit.next_connector() + new_map_conn = inner_entry.next_connector() helpers.redirect_edge( state=graph, edge=consumer_edge, @@ -310,6 +331,20 @@ def partition_map_output( # for this we have to look at all the edges that feed information to it. edges: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges(node)) + # If all edges are empty, i.e. they are only needed to keep the + # node inside the scope, consider it as independent. If they are + # tied to different scopes, refuse to work. + if all(edge.data.is_empty() for edge in edges): + if not all(edge.src is map_entry for edge in edges): + return None + block_independent.add(node) + continue + + # Currently we do not allow that a node has a mix of empty and non + # empty Memlets, is all or nothing. + if any(edge.data.is_empty() for edge in edges): + return None + # If the node gets information from other nodes than the map entry # we classify it as a dependent node, although there can be situations # were it could still be an independent node, but figuring this out @@ -318,12 +353,6 @@ def partition_map_output( block_dependent.add(node) continue - # If all edges are empty, i.e. they are only needed to keep the - # node inside the scope, consider it as independent. - if all(edge.data.is_empty() for edge in edges): - block_independent.add(node) - continue - # Now we have to look at the edges individually. # If this loop ends normally, i.e. it goes into its `else` # clause then we classify the node as independent. From 863bd5ffec2ad6ec7ea7e5810921c804291964e6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Jul 2024 15:55:27 +0200 Subject: [PATCH 170/235] Now auto optimization also does blocking. --- .../dace_fieldview/transformations/auto_opt.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index dbf0dffb27..9c9d8b2d50 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -96,6 +96,8 @@ def gt_auto_optimize( gpu_block_size: Optional[Sequence[int | str] | str] = None, validate: bool = True, validate_all: bool = False, + block_dim: Optional[gtx_common.Dimension] = None, + blocking_size: int = 10, **kwargs: Any, ) -> dace.SDFG: """Performs GT4Py specific optimizations in place. @@ -172,6 +174,15 @@ def gt_auto_optimize( if gpu_block_size is not None: gtx_transformations.gt_set_gpu_blocksize(sdfg, gpu_block_size) + if gpu and (block_dim is not None): + sdfg.apply_transformations_repeated( + gtx_transformations.KBlocking( + blocking_size=10, + block_dim=block_dim, + ), + validate=True, + ) + # These are the part that we copy from DaCe built in auto optimization. dace_aoptimize.set_fast_implementations(sdfg, device) dace_aoptimize.move_small_arrays_to_stack(sdfg) From ea0da2b6285551396f226dd7fa4a567b1254c13a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 22 Jul 2024 09:17:48 +0200 Subject: [PATCH 171/235] Made some fixes, but it still does not work. --- .../dace_fieldview/transformations/gpu_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index e312423e47..62164d673a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -112,6 +112,8 @@ def _gpu_block_parser( val = tuple(val) elif isinstance(val, str): val = tuple(x.replace(" ", "") for x in val.split(",")) + else: + raise TypeError(f"Does not know how to transform '{type(val).__name__}' into a proper GPU block size.") if len(val) == 1: val = (*val, 1, 1) elif len(val) == 2: @@ -126,9 +128,9 @@ def _gpu_block_getter( self: "GPUSetBlockSize", ) -> tuple[int, int, int]: """Used as getter in the `GPUSetBlockSize.block_size` property.""" - assert isinstance(self._block_size, tuple) and len(self.block_size) == 3 + assert isinstance(self._block_size, (tuple, list)) and len(self._block_size) == 3 assert all(isinstance(x, int) for x in self._block_size) - return self._block_size + return tuple(self._block_size) @properties.make_properties @@ -141,7 +143,8 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): block_size = properties.Property( dtype=None, - allow_none=True, + allow_none=False, + default=(32,1,1), setter=_gpu_block_parser, getter=_gpu_block_getter, desc="Size of the block size a GPU Map should have.", @@ -151,9 +154,11 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): def __int__( self, - block_size: Sequence[int | str] | str, + block_size: Sequence[int | str] | str | None = None, ) -> None: - self.block_size = block_size + super().__init__() + if block_size is not None: + self.block_size = block_size @classmethod def expressions(cls) -> Any: From a95bda2c0369f2e05ca38298bee1e37bfa7be820 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 22 Jul 2024 09:49:36 +0200 Subject: [PATCH 172/235] Made some fixes. --- .../runners/dace_fieldview/transformations/gpu_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 62164d673a..a8e7e88bf3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -113,7 +113,9 @@ def _gpu_block_parser( elif isinstance(val, str): val = tuple(x.replace(" ", "") for x in val.split(",")) else: - raise TypeError(f"Does not know how to transform '{type(val).__name__}' into a proper GPU block size.") + raise TypeError( + f"Does not know how to transform '{type(val).__name__}' into a proper GPU block size." + ) if len(val) == 1: val = (*val, 1, 1) elif len(val) == 2: @@ -144,7 +146,7 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): block_size = properties.Property( dtype=None, allow_none=False, - default=(32,1,1), + default=(32, 1, 1), setter=_gpu_block_parser, getter=_gpu_block_getter, desc="Size of the block size a GPU Map should have.", @@ -152,7 +154,7 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): map_entry = transformation.transformation.PatternNode(nodes.MapEntry) - def __int__( + def __init__( self, block_size: Sequence[int | str] | str | None = None, ) -> None: From 18a85608bcaf02755c5ecfd928e1da136c351149 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 22 Jul 2024 14:43:07 +0200 Subject: [PATCH 173/235] If blocking is applied the name of the outer map is now also changed. This is mostly to see a difference in the output of NCU. --- .../runners/dace_fieldview/transformations/k_blocking.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 1f8ed3d801..6fa971ba92 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -41,6 +41,9 @@ class KBlocking(transformation.SingleStateTransformation): Furthermore, the map will inspect the neighbours of the old, or outer map. If the node does not depend on the blocked dimension, the node will be put between the two maps, thus its content will only be computed once. + + The function will also change the name of the outer map, it will append + `_blocked` to it. """ blocking_size = properties.Property( @@ -151,6 +154,7 @@ def apply( ).ranges[0] outer_map.params[block_idx] = coarse_block_var outer_map.range[block_idx] = coarse_block_range + outer_map.label = f"{outer_map.label}_blocked" # Contains the independent nodes that are already relocated. relocated_nodes: set[nodes.Node] = set() From 5d979c93bc16b09bd9527da5aec912e8ca168a46 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Jul 2024 13:36:03 +0200 Subject: [PATCH 174/235] Implemented the possibility to also set the launch bound stuff. --- .../transformations/auto_opt.py | 11 +++- .../transformations/gpu_utils.py | 63 +++++++++++++++++-- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 9c9d8b2d50..9779f92b79 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -94,6 +94,8 @@ def gt_auto_optimize( leading_dim: Optional[gtx_common.Dimension] = None, gpu: bool = False, gpu_block_size: Optional[Sequence[int | str] | str] = None, + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, validate: bool = True, validate_all: bool = False, block_dim: Optional[gtx_common.Dimension] = None, @@ -108,6 +110,8 @@ def gt_auto_optimize( leading_dim: Leading dimension, where the stride is expected to be 1. gpu: Optimize for GPU. gpu_block_size: The block size that should be used for the GPU. + gpu_launch_bounds: The launch bounds to use. + gpu_launch_factor: The launch factor to use. """ with dace.config.temporary_config(): @@ -172,7 +176,12 @@ def gt_auto_optimize( # After everything we set the GPU block size. if gpu_block_size is not None: - gtx_transformations.gt_set_gpu_blocksize(sdfg, gpu_block_size) + gtx_transformations.gt_set_gpu_blocksize( + sdfg=sdfg, + gpu_block_size=gpu_block_size, + gpu_launch_bounds=gpu_launch_bounds, + gpu_launch_factor=gpu_launch_factor, + ) if gpu and (block_dim is not None): sdfg.apply_transformations_repeated( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index a8e7e88bf3..b5594fda16 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -18,7 +18,6 @@ from typing import Any, Optional, Sequence, Union import dace -import numpy as np from dace import properties, transformation from dace.sdfg import SDFG, SDFGState, nodes @@ -39,6 +38,8 @@ def gt_gpu_transformation( sdfg: dace.SDFG, promote_serial_maps: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, validate: bool = True, validate_all: bool = False, ) -> dace.SDFG: @@ -82,7 +83,12 @@ def gt_gpu_transformation( # Set the GPU block size if it is known. if gpu_block_size is not None: - gt_set_gpu_blocksize(sdfg, gpu_block_size) + gt_set_gpu_blocksize( + sdfg=sdfg, + gpu_block_size=gpu_block_size, + gpu_launch_bounds=gpu_launch_bounds, + gpu_launch_factor=gpu_launch_factor, + ) return sdfg @@ -90,14 +96,23 @@ def gt_gpu_transformation( def gt_set_gpu_blocksize( sdfg: dace.SDFG, gpu_block_size: Optional[Sequence[int | str] | str], + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, ) -> Any: """Set the block sizes of GPU Maps. Args: sdfg: The SDFG to process. gpu_block_size: The block size to use. + gpu_launch_bounds: The launch bounds to use. + gpu_launch_factor: The launch factor to use. """ - return sdfg.apply_transformations_once_everywhere([GPUSetBlockSize(block_size=gpu_block_size)]) + xform = GPUSetBlockSize( + block_size=gpu_block_size, + launch_bounds=gpu_launch_bounds, + launch_factor=gpu_launch_factor, + ) + return sdfg.apply_transformations_once_everywhere([xform]) def _gpu_block_parser( @@ -122,8 +137,13 @@ def _gpu_block_parser( val = (*val, 1) elif len(val) != 3: raise ValueError(f"Can not parse block size '{org_val}': wrong length") - assert all(isinstance(x, (str, int, np.integer)) for x in val) - self._block_size = [int(x) for x in val] + try: + val = [int(x) for x in val] + except ValueError: + raise TypeError( + f"Currently only block sizes convertible to int are supported, you passed '{val}'." + ) from None + self._block_size = val def _gpu_block_getter( @@ -139,6 +159,12 @@ def _gpu_block_getter( class GPUSetBlockSize(transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. + Args: + block_size: The block size that should be used. + launch_bounds: The value for the launch bound that should be used. + launch_factor: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number. + Todo: Depending on the number of dimensions of a map, there should be different sources. """ @@ -152,16 +178,39 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): desc="Size of the block size a GPU Map should have.", ) + launch_bounds = properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property of the map.", + ) + map_entry = transformation.transformation.PatternNode(nodes.MapEntry) def __init__( self, block_size: Sequence[int | str] | str | None = None, + launch_bounds: int | str | None = None, + launch_factor: int | None = None, ) -> None: super().__init__() if block_size is not None: self.block_size = block_size + if launch_factor is not None: + assert launch_bounds is None + self.launch_bounds = str( + int(launch_factor) * self.block_size[0] * self.block_size[1] * self.block_size[2] + ) + elif launch_bounds is None: + self.launch_bounds = None + elif isinstance(launch_bounds, (str, int)): + self.launch_bounds = str(launch_bounds) + else: + raise TypeError( + f"Does not know how to parse '{launch_bounds}' as 'launch_bounds' argument." + ) + @classmethod def expressions(cls) -> Any: return [dace.sdfg.utils.node_path_graph(cls.map_entry)] @@ -195,8 +244,10 @@ def apply( graph: Union[SDFGState, SDFG], sdfg: SDFG, ) -> None: - """Sets the block size.""" + """Modify the map as requested.""" self.map_entry.map.gpu_block_size = self.block_size + if self.launch_bounds is not None: # Note empty string has a meaning in DaCe + self.map_entry.map.gpu_launch_bounds = self.launch_bounds @properties.make_properties From bcd63d3652d3aa328b50fbf17c985e1168cc9d9a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 23 Jul 2024 15:02:12 +0200 Subject: [PATCH 175/235] Fixed a bug in the auto omptimizer. The blocking was always hardcoded to 10. --- .../runners/dace_fieldview/transformations/auto_opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 9779f92b79..8637f2e21a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -186,7 +186,7 @@ def gt_auto_optimize( if gpu and (block_dim is not None): sdfg.apply_transformations_repeated( gtx_transformations.KBlocking( - blocking_size=10, + blocking_size=blocking_size, block_dim=block_dim, ), validate=True, From c165f9f8881e0fec7830e4563a88ce48cb2d77b4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 26 Jul 2024 09:54:50 +0200 Subject: [PATCH 176/235] Restructured and cleaned up the auto omptimizer routine. --- .../transformations/auto_opt.py | 250 ++++++++++++------ 1 file changed, 172 insertions(+), 78 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 8637f2e21a..d2434e3f97 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -29,32 +29,24 @@ def dace_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, + use_gpu_storage: bool = True, **kwargs: Any, ) -> dace.SDFG: """This is a convenient wrapper arround DaCe's `auto_optimize` function. - By default it uses the `CPU` device type. Furthermore, it will first run the - `{In, Out}LocalStorage` transformations of the SDFG. The reason for this is that - empirical observations have shown, that the current auto optimizer has problems - in certain cases and this should prevent some of them. - Args: sdfg: The SDFG that should be optimized in place. device: the device for which optimizations should be done, defaults to CPU. + use_gpu_storage: Assumes that the SDFG input is already on the GPU. + This parameter is `False` in DaCe but here is changed to `True`. kwargs: Are forwarded to the underlying auto optimized exposed by DaCe. """ - - # Now put output storages everywhere to make auto optimizer less likely to fail. - # sdfg.apply_transformations_repeated([InLocalStorage, OutLocalStorage]) # noqa: ERA001 [commented-out-code] - - # Now the optimization. - sdfg = dace_aoptimize(sdfg, device=device, **kwargs) - - # Now the simplification step. - # This should get rid of some of teh additional transients we have added. - sdfg.simplify() - - return sdfg + return dace_aoptimize.auto_optimize( + sdfg, + device=device, + use_gpu_storage=use_gpu_storage, + **kwargs, + ) def gt_simplify( @@ -77,6 +69,8 @@ def gt_simplify( Note: The reason for this function is that we can influence how simplify works. Since some parts in simplify might break things in the SDFG. + However, currently nothing is customized yet, and the function just calls + the simplification pass directly. """ from dace.transformation.passes.simplify import SimplifyPass @@ -90,83 +84,170 @@ def gt_simplify( def gt_auto_optimize( sdfg: dace.SDFG, - device: dace.DeviceType = dace.DeviceType.CPU, + gpu: bool, leading_dim: Optional[gtx_common.Dimension] = None, - gpu: bool = False, + aggressive_fusion: bool = True, + make_persistent: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, - validate: bool = True, - validate_all: bool = False, block_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, + validate: bool = True, + validate_all: bool = False, **kwargs: Any, ) -> dace.SDFG: - """Performs GT4Py specific optimizations in place. + """Performs GT4Py specific optimizations on the SDFG in place. + + The auto optimization works in different phases, that focuses each on + different aspects of the SDFG. The initial SDFG is assumed to have a + very large number of rather simple Maps. + + 1. Some general simplification transformations, beyond classical simplify, + are applied to the SDFG. + 2. In this phase the function tries to reduce the number of maps. This + process mostly relies on the map fusion transformation. If + `aggressive_fusion` is set the function will also promote certain Maps, to + make them fusable. For this it will add dummy dimensions. However, currently + the function will only add horizonal dimensions. + In this phase some optimizations inside the bigger kernels themselves might + be applied as well. + 3. After the function created big kernels it will apply some optimization, + inside the kernels itself. For example fuse maps inside them. + 4. Afterwards it will process the map ranges and iteration order. For this + the function assumes that the dimension indicated by `leading_dim` is the + once with stride one. + 5. If requested the function will now apply blocking, on the dimension indicated + by `leading_dim`. (The reason that it is not done in the kernel optimization + phase is a restriction dictated by the implementation.) + 6. If requested the SDFG will be transformed to GPU. For this the + `gt_gpu_transformation()` function is used, that might apply several other + optimizations. + 7. Afterwards some general transformations to the SDFG are applied. + This includes: + - Use fast implementation for library nodes. + - Move small transients to stack. + - Make transients persistent (if requested). + - Reuse transients. Args: sdfg: The SDFG that should ve optimized in place. - device: The device for which we should optimize. - leading_dim: Leading dimension, where the stride is expected to be 1. - gpu: Optimize for GPU. - gpu_block_size: The block size that should be used for the GPU. - gpu_launch_bounds: The launch bounds to use. - gpu_launch_factor: The launch factor to use. + gpu: Optimize for GPU or CPU. + leading_dim: Leading dimension, indicates where the stride is 1. + aggressive_fusion: Be more aggressive in fusion, will lead to the promotion + of certain maps. + make_persistent: Turn all transients to persistent lifetime, thus they are + allocated over the whole lifetime of the program, even if the kernel exits. + Thus the SDFG can not be called by different threads. + gpu_block_size: The thread block size for maps in GPU mode, currently only + one for all. + block_dim: On which dimension blocking should be applied. + blocking_size: How many elements each block should process. + validate: Perform validation during the steps. + validate_all: Perform extensive validation. + + Todo: + - Make sure that `SDFG.simplify()` is not called indirectly, by temporary + overwriting it with `gt_simplify()`. + - Specify arguments to set the size of GPU thread blocks depending on the + dimensions. I.e. be able to use a different size for 1D than 2D Maps. + - Add a parallel version of Map fusion. + - Implements some model to further guide to determine what we want to fuse. + Something along the line "Fuse if operational intensity goes up, but + not if we have too much internal space (register pressure). + - Create a custom array elimination pass that honor rule 1. """ with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) + dace.Config.set("store_history", value=False) + + # TODO(phimuell): Should there be a zeroth phase, in which we generate + # a chanonical form of the SDFG, for example move all local maps + # to internal serial maps, such that they not block fusion? + # in the JaCe prototype we did that. - # Initial cleaning + # Phase 1: Initial Cleanup gt_simplify(sdfg) + sdfg.apply_transformations_once_everywhere( + [ + dace_dataflow.TrivialMapElimination, + dace_dataflow.MapReduceFusion, + dace_dataflow.MapWCRFusion, + ], + validate=validate, + validate_all=validate_all, + ) # Compute the SDFG to see if something has changed. sdfg_hash = sdfg.hash_sdfg() - # Have GPU transformations been performed. - have_gpu_transformations_been_run = False - + # Phase 2: Kernel Creation + # We will now try to reduce the number of kernels and create big one. + # For this we essentially use map fusion. We do this is a loop because + # after a graph modification followed by simplify new fusing opportunities + # might arise. We use the hash of the SDFG to detect if we have reached a + # fix point. for _ in range(100): - # Due to the structure of the generated SDFG getting rid of Maps, - # i.e. fusing them, is the best we can currently do. - if kwargs.get("use_dace_fusion", False): - sdfg.apply_transformations_repeated([dace_dataflow.MapFusion]) - else: - xform = gtx_transformations.SerialMapFusion() - sdfg.apply_transformations_repeated( - [xform], validate=validate, validate_all=validate_all - ) - + # Use map fusion to reduce their number and to create big kernels + # TODO(phimuell): Use a cost measurement to decide if fusion should be done. + # TODO(phimuell): Add parallel fusion transformation. Should it run after + # or with the serial one? sdfg.apply_transformations_repeated( - [gtx_transformations.SerialMapPromoter(promote_horizontal=False)], + gtx_transformations.SerialMapFusion( + only_toplevel_maps=True, + ), validate=validate, validate_all=validate_all, ) - # Maybe running the fusion has opened more opportunities. - gt_simplify(sdfg) + # Now do some cleanup task, that may enable further fusion opportunities. + # Note for performance reasons simplify is deferred. + phase2_cleanup = [] + phase2_cleanup.append(dace_dataflow.TrivialTaskletElimination()) + + # TODO(phimuell): Should we do this all the time or only once? (probably the later) + # TODO(phimuell): More control what we promote. + phase2_cleanup.append( + gtx_transformations.SerialMapPromoter( + promote_horizontal=False, + ) + ) + + sdfg.apply_transformations_once_everywhere( + phase2_cleanup, + validate=validate, + validate_all=validate_all, + ) - # check if something has changed and if so end it here. + # Use the hash to determine if the transformations did modify the SDFG. + # If not we have optimized the SDFG as much as we could, in this phase. old_sdfg_hash = sdfg_hash sdfg_hash = sdfg.hash_sdfg() - if old_sdfg_hash == sdfg_hash: - if gpu and (not have_gpu_transformations_been_run): - gtx_transformations.gt_gpu_transformation( - sdfg, - validate=validate, - validate_all=validate_all, - gpu_block_size=None, # Explicitly not set here. - ) - have_gpu_transformations_been_run = True - sdfg_hash = sdfg.hash_sdfg() - continue break + + # The SDFG was modified by the transformations above. But no + # transformation could be applied, so we will now call simplify + # and start over. + gt_simplify(sdfg) + else: raise RuntimeWarning("Optimization of the SDFG did not converged.") - # After we have optimized the SDFG as good as we can, we will now do some - # lower level optimization. + # Phase 3: Optimizing the kernels themselves. + # Currently this is only applies fusion inside them. + # TODO(phimuell): Improve. + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion( + only_inner_maps=True, + ), + validate=validate, + validate_all=validate_all, + ) + gt_simplify(sdfg) + + # Phase 4: Iteration Space + # This essentially ensures that the stride 1 dimensions are handled + # by the inner most loop nest (CPU) or x-block (GPU) if leading_dim is not None: sdfg.apply_transformations_once_everywhere( gtx_transformations.MapIterationOrder( @@ -174,30 +255,43 @@ def gt_auto_optimize( ) ) - # After everything we set the GPU block size. - if gpu_block_size is not None: - gtx_transformations.gt_set_gpu_blocksize( - sdfg=sdfg, - gpu_block_size=gpu_block_size, - gpu_launch_bounds=gpu_launch_bounds, - gpu_launch_factor=gpu_launch_factor, - ) - - if gpu and (block_dim is not None): - sdfg.apply_transformations_repeated( + # Phase 5: Apply blocking + if block_dim is not None: + sdfg.apply_transformations_once_everywhere( gtx_transformations.KBlocking( blocking_size=blocking_size, block_dim=block_dim, ), - validate=True, + validate=validate, + validate_all=validate_all, + ) + + # Phase 6: Going to GPU + if gpu: + gpu_launch_factor: Optional[int] = kwargs.get("gpu_launch_factor", None) + gpu_launch_bounds: Optional[int] = kwargs.get("gpu_launch_bounds", None) + gtx_transformations.gt_gpu_transformation( + sdfg, + gpu_block_size=gpu_block_size, + gpu_launch_bounds=gpu_launch_bounds, + gpu_launch_factor=gpu_launch_factor, + validate=validate, + validate_all=validate_all, ) - # These are the part that we copy from DaCe built in auto optimization. + # Phase 7: General Optimizations + # The following operations apply regardless if we have a GPU or CPU. + # The DaCe auto optimizer also uses them. Note that the reuse transient + # is not done by DaCe. + device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU + transient_reuse = dace.transformation.passes.TransientReuse() + + transient_reuse.apply_pass(sdfg, {}) dace_aoptimize.set_fast_implementations(sdfg, device) + # TODO(phimuell): Fix the bug, it is used the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) - dace_aoptimize.make_transients_persistent(sdfg, device) - - # Final simplify - gt_simplify(sdfg) + if make_persistent: + # TODO(phimuell): Allow to also make them `SDFG`. + dace_aoptimize.make_transients_persistent(sdfg, device) return sdfg From 316ba9c34cfbe9a490f62e23dfa94b8be61f37c1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 26 Jul 2024 13:03:16 +0200 Subject: [PATCH 177/235] Fixed a bug in the `get_map_variable()` function. This function is responsible for generating the names of map parameters used in teh DaCe fieldview backend. It contained an error because it had only one underscore in front of `gtx`. Because there was another bug in the optimization pipeline, I did not notice it. --- .../next/program_processors/runners/dace_fieldview/utility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 8f775de3f3..324585adea 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -129,4 +129,4 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: Format map variable name based on the naming convention for application-specific SDFG transformations. """ suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" - return f"i_{dim.value}_gtx_{dim.kind}{suffix}" + return f"i_{dim.value}__gtx_{dim.kind}{suffix}" From a796766607d358309c10cb038ac7c1ffd25a7dc9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 26 Jul 2024 14:51:45 +0200 Subject: [PATCH 178/235] First batch of stuff for review. Only the serial fusion stuff is not tidied up. --- .../transformations/__init__.py | 6 + .../transformations/auto_opt.py | 8 ++ .../transformations/gpu_utils.py | 94 +++++++++---- .../transformations/k_blocking.py | 130 ++++++++++-------- .../transformations/map_orderer.py | 34 ++--- .../transformations/map_promoter.py | 109 +++++++++++---- .../dace_fieldview/transformations/util.py | 16 ++- 7 files changed, 261 insertions(+), 136 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index e297cd2237..e1172c2fce 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,6 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +"""Transformation and optimization pipeline for the DaCe backend in GT4Py. + +Please also see [this HackMD document](https://hackmd.io/@gridtools/rklwk4OIR#Requirements-on-SDFG) +that explains the general structure and requirements on the SDFG. +""" + from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_simplify from .gpu_utils import ( GPUSetBlockSize, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index d2434e3f97..7d7d8e3971 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -208,7 +208,10 @@ def gt_auto_optimize( # TODO(phimuell): More control what we promote. phase2_cleanup.append( gtx_transformations.SerialMapPromoter( + only_toplevel_maps=True, + promote_vertical=True, promote_horizontal=False, + promote_local=False, ) ) @@ -268,6 +271,10 @@ def gt_auto_optimize( # Phase 6: Going to GPU if gpu: + # TODO(phimuell): The GPU function might modify the map iteration order. + # This is because how it is implemented (promotion and + # fusion). However, because of its current state, this + # should not happen, but we have to look into it. gpu_launch_factor: Optional[int] = kwargs.get("gpu_launch_factor", None) gpu_launch_bounds: Optional[int] = kwargs.get("gpu_launch_bounds", None) gtx_transformations.gt_gpu_transformation( @@ -277,6 +284,7 @@ def gt_auto_optimize( gpu_launch_factor=gpu_launch_factor, validate=validate, validate_all=validate_all, + try_removing_trivial_maps=True, ) # Phase 7: General Optimizations diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index b5594fda16..c160fc52a4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -36,32 +36,50 @@ def gt_gpu_transformation( sdfg: dace.SDFG, - promote_serial_maps: bool = True, + try_removing_trivial_maps: bool = True, + use_gpu_storage: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, validate: bool = True, validate_all: bool = False, + **kwargs: Any, ) -> dace.SDFG: """Transform an SDFG into an GPU SDFG. - The transformations are done in place. - The function will roughly do the same: - - Move all arrays used as input to the GPU. - - Apply the standard DaCe GPU transformation. - - Run `gt_simplify()` (recommended by the DaCe documentation). - - Try to promote trivial maps. - - If given set the GPU block size. + The transformation expects a rather optimized SDFG and turn it into an SDFG + capable of running on the GPU. + The function performs the following steps: + - If requested, modify the storage location of the non transient arrays to + life on GPU. + - Call the normal GPU transform function followed by simplify. + - If requested try to remove trivial kernels. + - If given set the `gpu_block_size` parameters of the Maps to the given value. + + Args: + sdfg: The SDFG that should be processed. + try_removing_trivial_maps: Try to get rid of trivial maps by incorporating them. + use_gpu_storage: Assume that the non global memory is already on the GPU. + gpu_block_size: Set the GPU block size of all maps that does not have + one to this value. + + Notes: + The function might modify the order of the iteration variables of some + maps and fuse other Maps. + + Todo: + - Solve the fusing problem. + - Currently only one block size for all maps is given, add more options. """ + # You need guru level or above to use these arguments. + gpu_launch_factor: Optional[int] = kwargs.get("gpu_launch_factor", None) + gpu_launch_bounds: Optional[int] = kwargs.get("gpu_launch_bounds", None) + # Turn all global arrays (which we identify as input) into GPU memory. # This way the GPU transformation will not create this copying stuff. - for desc in sdfg.arrays.values(): - if desc.transient: - continue - if not isinstance(desc, dace.data.Array): - continue - desc.storage = dace.dtypes.StorageType.GPU_Global + if use_gpu_storage: + for desc in sdfg.arrays.values(): + if not (desc.transient or not isinstance(desc, dace.data.Array)): + desc.storage = dace.dtypes.StorageType.GPU_Global # Now turn it into a GPU SDFG sdfg.apply_gpu_transformations( @@ -69,14 +87,29 @@ def gt_gpu_transformation( validate_all=validate_all, simplify=False, ) - - # The documentation recommend to run simplify afterwards + # The documentation recommends to run simplify afterwards gtx_transformations.gt_simplify(sdfg) - # Start to promote the maps. - if promote_serial_maps: + if try_removing_trivial_maps: + # For reasons a Tasklet can not exist outside a Map in a GPU SDFG. The GPU + # transformation will thus adds trivial maps around them, which translate to + # a kernel launch. Our current solution is to promote them and then fuse it. + # NOTE: The current implementation has a flaw, because promotion and fusion + # are two different steps, this is is inefficient. There are some problems + # because the mapped Tasklet might not be fusable at all. However, the real + # problem is, that Map fusion does not guarantee a certain order of Map + # variables. Currently this is not a problem because of the way it is + # implemented. + # TODO(phimuell): Fix the issue described above. + sdfg.apply_transformations_once_everywhere( + gtx_transformations.SerialMapPromoterGPU(), + validate=False, + validate_all=False, + ) sdfg.apply_transformations_repeated( - [gtx_transformations.SerialMapPromoterGPU()], + gtx_transformations.SerialMapFusion( + only_toplevel_maps=True, + ), validate=validate, validate_all=validate_all, ) @@ -159,6 +192,8 @@ def _gpu_block_getter( class GPUSetBlockSize(transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. + It is also possible to set the launch bound. + Args: block_size: The block size that should be used. launch_bounds: The value for the launch bound that should be used. @@ -166,7 +201,7 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): in a block multiplied by this number. Todo: - Depending on the number of dimensions of a map, there should be different sources. + Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. """ block_size = properties.Property( @@ -255,13 +290,18 @@ class SerialMapPromoterGPU(transformation.SingleStateTransformation): """Serial Map promoter for empty Maps in case of trivial Maps. In CPU mode a Tasklet can be outside of a map, however, this is not - possible in CPU mode. For this reason DaCe wraps every such Tasklet - in a trivial Map. - This function will look for such Maps and promote them, such that they - can be fused with downstream maps. + possible in GPU mode. For this reason DaCe wraps such Tasklets in a + trivial Map. + This transformation will look for such Maps and promote them, such + that they can be fused with downstream maps. Note: This transformation must be run after the GPU Transformation. + + Todo: + - The transformation assumes that the upper Map is a trivial Tasklet. + Which should be the majority of all cases. + - Combine this transformation such that it can do serial fusion on its own. """ # Pattern Matching @@ -330,5 +370,3 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: map_1.params = copy.deepcopy(map_2.params) map_1.range = copy.deepcopy(map_2.range) - - return diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 6fa971ba92..bad54693af 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -27,20 +27,22 @@ @properties.make_properties class KBlocking(transformation.SingleStateTransformation): - """Performs K blocking. + """Applies k-Blocking with separation on a Map. - This transformation takes a multidimensional map and performs k blocking on - one particular dimension, which is identified by `block_dim`, which we also - call `k`. - All dimensions except `k` are unaffected and the transformation replaces it - with `kk` and the range `0:N:B`, where `N` is the original end of the - transformation and `B` is the block size, passed as `blocking_size`. - The transformation will then add an inner sequential map with one - dimension `k = kk:(kk+B)`. + This transformation takes a multidimensional Map and performs blocking on a + dimension, that is commonly called "k", but identified with `block_dim`. - Furthermore, the map will inspect the neighbours of the old, or outer map. - If the node does not depend on the blocked dimension, the node will be put - between the two maps, thus its content will only be computed once. + All dimensions except `k` are unaffected by this transformation. In the outer + Map the will replace the `k` range, currently `k = 0:N`, with + `__coarse_k = 0:N:B`, where `N` is the original size of the range and `B` + is the block size, passed as `blocking_size`. The transformation also handles the + case if `N % B != 0`. + The transformation will then create an inner sequential map with + `k = __coarse_k:(__coarse_k + B)`. + + However, before the split the transformation examines all adjacent nodes of + the original Map. If a node does not depend on `k`, then the node will be + put between the two maps, thus its content will only be computed once. The function will also change the name of the outer map, it will append `_blocked` to it. @@ -54,7 +56,7 @@ class KBlocking(transformation.SingleStateTransformation): block_dim = properties.Property( dtype=str, allow_none=True, - desc="Which dimension should be blocked.", + desc="Which dimension should be blocked (must be an exact match).", ) map_entry = transformation.transformation.PatternNode(nodes.MapEntry) @@ -85,9 +87,12 @@ def can_be_applied( ) -> bool: """Test if the map can be blocked. - For this the map: - - Must contain the block dimension. - - Must not be serial. + The test involves: + - Toplevel map. + - The map shall not be serial. + - The block dimension must be present (exact match). + - The map range must have stride one. + - The partition must exists (see `partition_map_output()`). """ map_entry: nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params @@ -101,10 +106,10 @@ def can_be_applied( return False if map_entry.map.schedule == dace.dtypes.ScheduleType.Sequential: return False - if self.partition_map_output(map_entry, block_var, graph, sdfg) is None: - return False if map_range[map_params.index(block_var)][2] != 1: return False + if self.partition_map_output(map_entry, block_var, graph, sdfg) is None: + return False return True @@ -113,7 +118,10 @@ def apply( graph: Union[SDFGState, SDFG], sdfg: SDFG, ) -> None: - """Performs the blocking transformation.""" + """Creates a blocking map. + + Performs the operation described in the doc string. + """ outer_entry: nodes.MapEntry = self.map_entry outer_exit: nodes.MapExit = graph.exit_node(outer_entry) outer_map: nodes.Map = outer_entry.map @@ -133,7 +141,7 @@ def apply( outer_entry, block_var, graph, sdfg ) - # Now generate the sequential inner map + # Generate the sequential inner map rng_start = map_range[block_idx][0] rng_stop = map_range[block_idx][1] inner_label = f"inner_{outer_map.label}" @@ -148,6 +156,8 @@ def apply( schedule=dace.dtypes.ScheduleType.Sequential, ) + # TODO(phimuell): Investigate if we want to prevent unrolling here + # Now we modify the properties of the outer map. coarse_block_range = subsets.Range.from_string( f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" @@ -160,18 +170,18 @@ def apply( relocated_nodes: set[nodes.Node] = set() # Now we iterate over all the output edges of the outer map and rewire them. - # Note that this only handles the map entry. + # Note that this only handles the entry of the Map. for out_edge in list(graph.out_edges(outer_entry)): edge_dst: nodes.Node = out_edge.dst if edge_dst in dependent_nodes: - # This is the simple case as we just ave to rewire the edge + # This is the simple case as we just have to rewire the edge # and make a connection between the outer and inner map. assert not out_edge.data.is_empty() edge_conn: str = out_edge.src_conn[4:] # Must be before the handling of the modification below - # Note that this will remove the edge from the SDFG. + # Note that this will remove the original edge from the SDFG. helpers.redirect_edge( state=graph, edge=out_edge, @@ -179,9 +189,7 @@ def apply( new_src_conn="OUT_" + edge_conn, ) - # In a valid SDFG one one edge can go to an input connector of - # a map (this is their definition), thus we have to adhere to - # this definition as well. + # In a valid SDFG only one edge can go into an input connector of a Map. if "IN_" + edge_conn in inner_entry.in_connectors: # We have found this edge multiple times already. # To ensure that there is no error, we will create a new @@ -202,31 +210,36 @@ def apply( "IN_" + edge_conn, copy.deepcopy(out_edge.data), ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) + inner_entry.add_in_connector("IN_" + edge_conn) + inner_entry.add_out_connector("OUT_" + edge_conn) continue elif edge_dst in relocated_nodes: - # See `else` case + # The node was already fully handled in the `else` clause. continue else: # Relocate the node and make the reconnection. - # Different from the dependent case we will handle the node fully, - # i.e. all of its edges will be processed in one go. + # Different from the dependent case we will handle all the edges + # of the node in one go. relocated_nodes.add(edge_dst) - # This is the node serving as the storage to store the independent - # data, and is used within the inner loop. - # This prevents the reloading of data. - assert graph.out_degree(edge_dst) == 1 + # In order to be useful we have to temporary store the data the + # independent node generates + assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift if isinstance(edge_dst, nodes.AccessNode): + # The independent node is an access node, so we can use it directly. caching_node: nodes.AccessNode = edge_dst else: + # The dependent node is not an access node. For now we will + # just use the next node, with some restriction. + # TODO(phimuell): create an access node in this case instead. caching_node = next(iter(graph.out_edges(edge_dst))).dst assert graph.in_degree(caching_node) == 1 - assert isinstance(caching_node, nodes.AccessNode) + assert isinstance(caching_node, nodes.AccessNode) + # Now rewire the Memlets that leave the caching node to go through + # new inner Map. for consumer_edge in list(graph.out_edges(caching_node)): new_map_conn = inner_entry.next_connector() helpers.redirect_edge( @@ -246,8 +259,9 @@ def apply( inner_entry.add_out_connector("OUT_" + new_map_conn) continue - # Now we have to handle the output of the map. - # There is not many to do just reconnect some edges. + # Handle the Map exits + # This is simple reconnecting, there would be possibilities for improvements + # but we do not use them for now. for out_edge in list(graph.in_edges(outer_exit)): edge_conn = out_edge.dst_conn[3:] helpers.redirect_edge( @@ -276,18 +290,17 @@ def partition_map_output( state: SDFGState, sdfg: SDFG, ) -> tuple[set[nodes.Node], set[nodes.Node]] | None: - """Partition the outputs + """Partition the outputs of the Map. - This function computes the partition of the intermediate outputs of the map. - It will compute two set: + The partition will only look at the direct intermediate outputs of the + Map. The outputs will be two sets, defined as: - The independent outputs `\mathcal{I}`: These are output nodes, whose output does not depend on the blocked - dimension. These nodes will be relocated between the outer and - inner map. + dimension. These nodes can be relocated between the outer and inner map. - The dependent output `\mathcal{D}`: These are the output nodes, whose output depend on the blocked dimension. - Thus they will not be relocated between the two maps, but remain in the - inner most scope. + Thus they can not be relocated between the two maps, but will remain + inside the inner scope. In case the function fails to compute the partition `None` is returned. @@ -308,9 +321,9 @@ def partition_map_output( # Find all nodes that are adjacent to the map entry. nodes_to_partition: set[nodes.Node] = {edge.dst for edge in state.out_edges(map_entry)} - # Now we examine every node and assign them to a set. + # Now we examine every node and assign them to one of the sets. # Note that this is only tentative and we will later inspect the - # outputs of the independent node and reevaluate the classification. + # outputs of the independent node and reevaluate their classification. for node in nodes_to_partition: # Filter out all nodes that we can not (yet) handle. if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)): @@ -324,20 +337,19 @@ def partition_map_output( block_dependent.add(node) continue - # Only one output is allowed - # Might be less important for Tasklets but for AccessNodes. + # An independent node can (for now) only have one output. # TODO(phimuell): Lift this restriction. if state.out_degree(node) != 1: block_dependent.add(node) continue - # Now we have to understand how the node generates its information. - # for this we have to look at all the edges that feed information to it. + # Now we have to understand how the node generates its data. + # For this we have to look at all the edges that feed information to it. edges: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges(node)) # If all edges are empty, i.e. they are only needed to keep the - # node inside the scope, consider it as independent. If they are - # tied to different scopes, refuse to work. + # node inside the scope, consider it as independent. However, they have + # to be associated to the outer map. if all(edge.data.is_empty() for edge in edges): if not all(edge.src is map_entry for edge in edges): return None @@ -345,14 +357,15 @@ def partition_map_output( continue # Currently we do not allow that a node has a mix of empty and non - # empty Memlets, is all or nothing. + # empty Memlets, it is all or nothing. if any(edge.data.is_empty() for edge in edges): return None # If the node gets information from other nodes than the map entry - # we classify it as a dependent node, although there can be situations - # were it could still be an independent node, but figuring this out - # is too complicated. + # we classify it as a dependent node. But there can be situations where + # the node could still be independent, for example if it is connected + # to a independent node, then it could be independent itself. + # TODO(phimuell): Consider independent node as "equal" to the map. if any(edge.src is not map_entry for edge in edges): block_dependent.add(node) continue @@ -387,7 +400,7 @@ def partition_map_output( break else: # The loop ended normally, thus we did not found anything that made us - # believe that the node is _not_ an independent node. We will later + # _think_ that the node is _not_ an independent node. We will later # also inspect the output, which might reclassify the node block_independent.add(node) @@ -396,6 +409,7 @@ def partition_map_output( block_dependent.add(node) # We now make a last screening of the independent nodes. + # TODO(phimuell): Make an iterative process to find the maximal set. for independent_node in list(block_independent): if isinstance(independent_node, nodes.AccessNode): if state.in_degree(independent_node) != 1: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index 20bb46ee47..b6a900ce1a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -26,21 +26,21 @@ class MapIterationOrder(transformation.SingleStateTransformation): """Modify the order of the iteration variables. - The transformation modifies the order in which the map variables are processed. - This transformation is restricted in the sense, that it is only possible - to set the "first map" variable, i.e. the one that is associated with the - `x` dimension in a thread block and where the input memory should have stride 1. + The iteration order, while irrelevant from an SDFG point of view, is highly + relevant in code, and the fastest varying index ("inner most loop" in CPU or + "x block dimension" in GPU) should be associated with the stride 1 dimension + of the array. + This transformation will reorder the map indexes such that this is the case. - If the transformation modifies then the map variable corresponding to - `self.leading_dim` will be at the correct place, the order of all other - map variable is unspecific. Otherwise the map is unmodified. + While the place of the leading dimension is clearly defined, the order of the + other loop indexes, after this transformation is unspecified. Args: leading_dim: A GT4Py dimension object that identifies the dimension that - should be used. + is supposed to have stride 1. Note: - This transformation does follow the rules outlines [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + The transformation does follow the rules outlines [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) Todo: - Extend that different dimensions can be specified to be leading @@ -85,24 +85,26 @@ def can_be_applied( if self.leading_dim is None: return False - map_entry: nodes.MapEntry = self.map_entry map_params: Sequence[str] = map_entry.map.params map_var: str = dace_fieldview_util.get_map_variable(self.leading_dim) if map_var not in map_params: return False - if map_params[-1] == map_var: # Already at the end; `-1` is correct! + if map_params[-1] == map_var: # Already at the correct location return False return True - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + def apply( + self, + graph: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> None: """Performs the actual parameter reordering. - The function will move the map variable, that corresponds to - `self.leading_dim` at the end. It will not put it at the front, because - DaCe.codegen processes the variables in revers order and smashes all - the excess parameter into the last CUDA dimension. + The function will make the map variable, that corresponds to + `self.leading_dim` the last map variable (this is given by the structure of + DaCe's code generator). """ map_entry: nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index b63eda5445..b0b5313b85 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -28,17 +28,23 @@ class BaseMapPromoter(transformation.SingleStateTransformation): """Base transformation to add certain missing dimension to a map. - By adding certain dimension to a map it will became possible to fuse them. - This class acts as a base and the actual matching and checking must be - implemented by a concrete implementation. - - In order to properly work, the parameters of `source_map` must be a strict - superset of the ones of `map_to_promote`. Furthermore, this transformation + By adding certain dimension to a Map, it might became possible to use the Map + in more transformations. This class acts as a base and the actual matching and + checking must be implemented by a concrete implementation. + But it provides some basic check functionality and the actual promotion logic. + + The transformation operates on two Maps, first the "source map". This map + describes the Map that should be used as template. The second one is "map to + promote". After the transformation the "map to promote" will have the same + map parameter than the "source map" has. + + In order to properly work, the parameters of "source map" must be a strict + superset of the ones of "map to promote". Furthermore, this transformation builds upon the structure defined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). Thus it only checks the name of the parameters. To influence what to promote the user must implement the `map_to_promote()` - and `source_map()` must be implemented. They have to return the map entry node. + and `source_map()` function. They have to return the map entry node. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. @@ -49,7 +55,6 @@ class BaseMapPromoter(transformation.SingleStateTransformation): Note: This ignores tiling. - This only works with constant sized maps. """ only_toplevel_maps = properties.Property( @@ -123,6 +128,10 @@ def __init__( self.promote_horizontal = bool(promote_horizontal) if only_inner_maps and only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + if not (self.promote_local or self.promote_vertical or self.promote_horizontal): + raise ValueError( + "You must select at least one class of dimension that should be promoted." + ) def can_be_applied( self, @@ -136,8 +145,7 @@ def can_be_applied( A subclass should call this function before checking anything else. If a subclass has not called this function, the behaviour will be undefined. The function checks: - - If the map to promote is in the right scope (it is not required that - the two maps are in the same scope). + - If the map to promote is in the right scope. - If the parameter of the second map are compatible with each other. - If a dimension would be promoted that should not. """ @@ -147,6 +155,10 @@ def can_be_applied( source_map: nodes.Map = source_map_entry.map # Test the scope of the promotee. + # Because of the nature of the transformation, it is not needed that the + # two maps are in the same scope. However, they should be in the same state + # to ensure that the symbols are the same and all. But this is guaranteed by + # the nature of this transformation (single state). if self.only_inner_maps or self.only_toplevel_maps: scopeDict: Mapping[nodes.Node, Union[nodes.Node, None]] = graph.scope_dict() if self.only_inner_maps and (scopeDict[map_to_promote_entry] is None): @@ -155,37 +167,41 @@ def can_be_applied( return False # Test if the map ranges are compatible with each other. - params_to_promote: list[str] | None = self.missing_map_params( + missing_map_parameters: list[str] | None = self.missing_map_params( map_to_promote=map_to_promote, source_map=source_map, be_strict=True, ) - if not params_to_promote: + if not missing_map_parameters: return False - # Now we must check if there are dimensions that we do not want to promote. - if (not self.promote_local) and any( - param.endswith("__gtx_localdim") for param in params_to_promote - ): - return False - if (not self.promote_vertical) and any( - param.endswith("__gtx_vertical") for param in params_to_promote - ): - return False - if (not self.promote_horizontal) and any( - param.endswith("__gtx_horizontal") for param in params_to_promote - ): + # We now know which dimensions we have to add to the promotee map. + # Now we must test if we are also allowed to make that promotion in the first place. + dimension_identifier: list[str] = [] + if self.promote_local: + dimension_identifier.append("__gtx_localdim") + if self.promote_vertical: + dimension_identifier.append("__gtx_vertical") + if self.promote_horizontal: + dimension_identifier.append("__gtx_horizontal") + if not dimension_identifier: return False + for missing_map_param in missing_map_parameters: + if not any( + missing_map_param.endswith(dim_identifier) + for dim_identifier in dimension_identifier + ): + return False return True def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: - """Performs the Map Promoting. + """Performs the actual Map promoting. Add all parameters that `self.source_map` has but `self.map_to_promote` lacks to `self.map_to_promote` the range of these new dimensions is taken from the source map. - The order of the parameters of these new dimensions is undetermined. + The order of the parameters the Map has after the promotion is unspecific. """ map_to_promote: nodes.Map = self.map_to_promote(state=graph, sdfg=sdfg).map source_map: nodes.Map = self.source_map(state=graph, sdfg=sdfg).map @@ -255,7 +271,19 @@ def missing_map_params( @properties.make_properties class SerialMapPromoter(BaseMapPromoter): - """This class promotes serial maps, such that they can be fused.""" + """Promote a map such that it can be fused serially. + + A condition for fusing serial Maps is that they cover the same range. This + transformation is able to promote a Map, i.e. adding the missing dimensions, + such that the maps can be fused. + For more information see the `BaseMapPromoter` class. + + Notes: + The transformation does not perform the fusing on its one. + + Todo: + The map should do the fusing on its own directly. + """ # Pattern Matching exit_first_map = transformation.transformation.PatternNode(nodes.MapExit) @@ -279,6 +307,33 @@ def expressions(cls) -> Any: ), ] + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the Maps really can be fused.""" + from .map_seriall_fusion import SerialMapFusion + + if not super().can_be_applied(graph, expr_index, sdfg, permissive): + return False + + # Check if the partition exists, if not promotion to fusing is pointless. + # TODO(phimuell): Find the proper way of doing it. + serial_fuser = SerialMapFusion(only_toplevel_maps=True) + output_partition = serial_fuser.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=self.exit_first_map, + map_entry_2=self.entry_second_map, + ) + if output_partition is None: + return False + + return True + def map_to_promote( self, state: dace.SDFGState, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index f8596106c7..9e4f09c722 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Common functionality for the transformations.""" +"""Common functionality for the transformations/optimization pipeline.""" from typing import Iterable @@ -40,11 +40,12 @@ def all_nodes_between( end: nodes.Node, reverse: bool = False, ) -> set[nodes.Node] | None: - """Returns all nodes that are reachable from `begin` but bound by `end`. + """Find all nodes that are reachable from `begin` but bound by `end`. - What the function does is, that it starts a DFS starting at `begin`, which is - not part of the returned set, every edge that goes to `end` will be considered - to not exists. + Essentially the function starts a DFS at `begin`, which is never part of the + returned set, if at a node an edge is found that lead to `end`, the function + will ignore this edge. However, it will find every node that is reachable + from `begin` that is reachable by a path that does not visit `end`. In case `end` is never found the function will return `None`. If `reverse` is set to `True` the function will start exploring at `end` and @@ -99,8 +100,8 @@ def find_downstream_consumers( ) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Find all downstream connectors of `begin`. - A consumer, in this sense, is any node that is neither an entry nor an exit - node. The function returns a set storing the pairs, the first element is the + A consumer, in for this function, is any node that is neither an entry nor + an exit node. The function returns a set of pairs, the first element is the node that acts as consumer and the second is the edge that leads to it. By setting `only_tasklets` the nodes the function finds are only Tasklets. @@ -149,6 +150,7 @@ def find_downstream_consumers( target_conn = curr_edge.dst_conn[3:] new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) to_visit.extend(new_edges) + del new_edges else: if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): continue From 882ad44e9f6fd36e705bf08d4843daeb9e0fb85c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 26 Jul 2024 15:49:03 +0200 Subject: [PATCH 179/235] Also checked the map fusion helper stuff. --- .../transformations/map_fusion_helper.py | 125 +++++++++++------- 1 file changed, 75 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 138164bdb4..f4943949c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -23,19 +23,19 @@ from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes from dace.transformation import helpers -from . import util +from gt4py.next.program_processors.runners.dace_fieldview.transformations import util @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): """Contains common part of the map fusion for parallel and serial map fusion. - The transformation assumes that the SDFG obeys the principals outlined in [this - HackMD document](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). The main advantage of this structure is, that it is rather easy to determine - if a transient can be used. This check, performed by `is_interstate_transient()`, - is speed up by cashing some computation, thus such an object should not be used - after interstate optimizations were applied to the SDFG. + if a transient is used anywhere else. This check, performed by + `is_interstate_transient()`. It is further speeded up by cashing some computation, + thus such an object should not be used after interstate optimizations were applied + to the SDFG. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. @@ -88,25 +88,32 @@ def can_be_fused( sdfg: dace.SDFG, permissive: bool = False, ) -> bool: - """Performs some checks if the maps can be fused. + """Performs basic checks if the maps can be fused. - Essentially, this function only checks constrains that are common between - the serial and parallel map fusion process. It tests: + This function only checks constrains that are common between serial and + parallel map fusion process. It tests: - The scope of the maps. - The scheduling of the maps. - The map parameters. However, for performance reasons, the function does not check if the node decomposition exists. + + Args: + map_entry_1: The entry of the first (in serial case the top) map. + map_exit_2: The entry of the second (in serial case the bottom) map. + graph: The SDFGState in which the maps are located. + sdfg: The SDFG itself. + permissive: Currently unused. """ if self.only_inner_maps and self.only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") - # ensure that both have the same schedule + # Ensure that both have the same schedule if map_entry_1.map.schedule != map_entry_2.map.schedule: return False - # Fusing is only possible if our two entries are in the same scope. + # Fusing is only possible if the two entries are in the same scope. scope = graph.scope_dict() if scope[map_entry_1] != scope[map_entry_2]: return False @@ -119,7 +126,7 @@ def can_be_fused( elif util.is_nested_sdfg(sdfg): return False - # We will now check if there exists a remapping that we can use. + # We will now check if there exists a "remapping" that we can use. if not self.map_parameter_compatible( map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg ): @@ -136,12 +143,15 @@ def relocate_nodes( ) -> None: """Move the connectors and edges from `from_node` to `to_nodes` node. - Note: - - This function dos not remove the `from_node` but it will have degree - zero and have no connectors. - - If this function fails, the SDFG is in an invalid state. - - Usually this function should be called twice per Map scope, once for the - entry node and once for the exit node. + This function will only rewire the edges, it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entries and then for the exits. + + Args: + from_node: Node from which the edges should be removed. + to_node: Node to which the edges should reconnect. + state: The state in which the operation happens. + sdfg: The SDFG that is modified. """ # Now we relocate empty Memlets, from the `from_node` to the `to_node` @@ -158,9 +168,9 @@ def relocate_nodes( state.remove_edge(empty_edge) empty_targets.add(empty_edge.dst) - # We now determine if which connections we have to migrate - # We only consider the in edges, for Map exits it does not matter, but for - # Map entries, we need it for the dynamic map range feature. + # We now determine if which connections we have to migrate. We only consider + # the in edges, for Map exits it does not matter, but for Map entries, + # we need it for the dynamic map range feature. for edge_to_move in list(state.in_edges(from_node)): assert isinstance(edge_to_move.dst_conn, str) @@ -168,10 +178,9 @@ def relocate_nodes( # Dynamic Map Range # The connector name simply defines a variable name that is used, # inside the Map scope to define a variable. We handle it directly. - assert isinstance(from_node, nodes.MapEntry) dmr_symbol = edge_to_move.dst_conn - # TODO(phimuell): Check if the symbol is really unused. + # TODO(phimuell): Check if the symbol is really unused in the target scope. if dmr_symbol in to_node.in_connectors: raise NotImplementedError( f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" @@ -188,8 +197,8 @@ def relocate_nodes( # There is no other edge that we have to consider, so we just end here continue - # We have a Passthrough connection, i.e. there exists a `OUT_` connector - # thus we now have to migrate the two edges. + # We have a Passthrough connection, i.e. there exists a matching `OUT_` + # connector thus we now have to migrate the two edges. old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix new_conn = to_node.next_connector(old_conn) @@ -214,7 +223,7 @@ def map_parameter_compatible( state: Union[SDFGState, SDFG], sdfg: SDFG, ) -> bool: - """Checks if `map_1` is compatible with `map_1`. + """Checks if the parameters of `map_1` are compatible with `map_1`. The check follows the following rules: - The names of the map variables must be the same, i.e. no renaming @@ -239,7 +248,7 @@ def map_parameter_compatible( for pname in params_1: idx_1 = param_dim_map_1[pname] idx_2 = param_dim_map_2[pname] - # TODO(phimuell): do we need to call simplify + # TODO(phimuell): do we need to call simplify? if range_1[idx_1] != range_2[idx_2]: return False @@ -249,13 +258,15 @@ def is_interstate_transient( self, transient: Union[str, nodes.AccessNode], sdfg: dace.SDFG, - state: dace.SDFGState | None = None, + state: dace.SDFGState, ) -> bool: """Tests if `transient` is an interstate transient, an can not be removed. Essentially this function checks if a transient might be needed in a different state in the SDFG, because it transmit information from - one state to the other. If only the name of the transient is passed, + one state to the other. + + If only the name of the transient is passed, then the function will only check if it is used in another state. If the access node and the state are passed the function will also check if it is used inside the state. @@ -285,11 +296,11 @@ def is_interstate_transient( # If a scalar is not a source node then it is not included in this set. # Thus we do not have to look for it, instead we will check for them # explicitly. - def decent(node: nodes.Node, graph: Any) -> bool: + def go_deeper(node: nodes.Node, graph: Any) -> bool: return not isinstance(node, nodes.NestedSDFG) shared_sdfg_transients = set() - # TODO(phimuell): use `sdfg.all_nodes_recursive(decent)` if it is available. + # TODO(phimuell): use `sdfg.all_nodes_recursive(go_deeper)` if it is available. for state in sdfg.all_states(): shared_sdfg_transients.update( filter( @@ -300,18 +311,28 @@ def decent(node: nodes.Node, graph: Any) -> bool: ) self.shared_transients[sdfg] = shared_sdfg_transients - if isinstance(transient, nodes.AccessNode): - if state is not None: - # Rule 8: Used within the state. - if state.out_degree(transient) > 1: - return True - transient = transient.data + if isinstance(transient, str): + name = transient + matching_access_nodes = [node for node in state.data_nodes() if node.data == name] + # Rule 8: There is only one access node per state for data. + assert len(matching_access_nodes) == 1 + transient = matching_access_nodes[0] + + else: + assert isinstance(transient, nodes.AccessNode) + name = transient.data - desc: data.Data = sdfg.arrays[transient] + desc: data.Data = sdfg.arrays[name] if not desc.transient: return True if isinstance(desc, data.Scalar): + return True # Scalars can not be removed by fusion anyway. + + # Rule 8: If degree larger than one then it is used within the state. + if state.out_degree(transient) > 1: return True + + # Now we check if it is used in a different state. return transient in shared_sdfg_transients def partition_first_outputs( @@ -337,14 +358,14 @@ def partition_first_outputs( These edges exits the first map and does not enter the second map. These outputs will be simply be moved to the output of the second map. - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit and enters an access node, from + Edges in this set leaves the first map exit, enters an access node, from where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not needed anywhere else, thus it will + referenced by this access node is not used anywhere else, thus it can be removed. - Shared Intermediate Set `\mathbb{S}`: These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and are recreated - as output of the second map. + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. Returns: If such a decomposition exists the function will return the three sets @@ -398,8 +419,8 @@ def partition_first_outputs( # # The following tests are _after_ we have determined if we have a pure - # output node for a reason, as this allows us to handle more exotic - # pure node cases, as handling them is essentially rerouting an edge. + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge. # In case the intermediate has more than one entry, all must come from the # first map, otherwise we can not fuse them. Currently we restrict this @@ -436,7 +457,7 @@ def partition_first_outputs( # the case `b = a + 1; return b + 2`, where we have arrays. In this # example only a single element must be available to the second map. # However, this is hard to check so we will make a simplification. - # First we will not check it at the producer, but at the consumer point. + # First, we will not check it at the producer, but at the consumer point. # There we assume if the consumer does _not consume the whole_ # intermediate array, then we can decompose the intermediate, by setting # the map iteration index to zero and recover the shape, see @@ -450,11 +471,12 @@ def partition_first_outputs( if len(downstream_nodes) == 0: # There is nothing between intermediate node and the entry of the # second map, thus the edge belongs either in `\mathbb{S}` or - # `\mathbb{E}`, to which one depends on how it is used. + # `\mathbb{E}`. # This is a very special situation, i.e. the access node has many # different connections to the second map entry, this is a special - # case that we do not handle, instead simplify should be called. + # case that we do not handle. + # TODO(phimuell): Handle this case. if state.out_degree(intermediate_node) != 1: return None @@ -484,9 +506,9 @@ def partition_first_outputs( continue else: - # These is not only a single connection from the intermediate node to - # the second map, but the intermediate has more connection, thus - # the node might belong to the shared outputs. Of the many different + # There is not only a single connection from the intermediate node to + # the second map, but the intermediate has more connections, thus + # the node might belong to the shared output. Of the many different # possibilities, we only consider a single case: # - The intermediate has a single connection to the second map, that # fulfills the restriction outlined above. @@ -516,4 +538,7 @@ def partition_first_outputs( continue assert exclusive_outputs or shared_outputs or pure_outputs + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) return (pure_outputs, exclusive_outputs, shared_outputs) From a0bf2639b51d0ac6d13bf24efb2ec06739ed0301 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Sun, 28 Jul 2024 14:25:57 +0200 Subject: [PATCH 180/235] Made the reuse of transients optional and disabled it. I know from NOAA that they use it, but we should run more tests on it. Probably write our own. --- .../dace_fieldview/transformations/auto_opt.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 7d7d8e3971..d4f6efe73b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -91,6 +91,7 @@ def gt_auto_optimize( gpu_block_size: Optional[Sequence[int | str] | str] = None, block_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, + reuse_transients: bool = False, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -126,7 +127,7 @@ def gt_auto_optimize( - Use fast implementation for library nodes. - Move small transients to stack. - Make transients persistent (if requested). - - Reuse transients. + - If requested reuse transients. Args: sdfg: The SDFG that should ve optimized in place. @@ -141,6 +142,7 @@ def gt_auto_optimize( one for all. block_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. + reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -155,6 +157,7 @@ def gt_auto_optimize( not if we have too much internal space (register pressure). - Create a custom array elimination pass that honor rule 1. """ + device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) @@ -291,10 +294,13 @@ def gt_auto_optimize( # The following operations apply regardless if we have a GPU or CPU. # The DaCe auto optimizer also uses them. Note that the reuse transient # is not done by DaCe. - device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU - transient_reuse = dace.transformation.passes.TransientReuse() + if reuse_transients: + # TODO(phimuell): Investigate if we should enable it, it makes stuff + # harder for the compiler. Maybe write our own that + # only consider big transients and not small ones (~60B) + transient_reuse = dace.transformation.passes.TransientReuse() + transient_reuse.apply_pass(sdfg, {}) - transient_reuse.apply_pass(sdfg, {}) dace_aoptimize.set_fast_implementations(sdfg, device) # TODO(phimuell): Fix the bug, it is used the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) From 37392fd431cf4c845cb6131a7a68cc5e0e0fed1e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 09:43:58 +0200 Subject: [PATCH 181/235] First PR candidate for the optimization pipeline. It is all very basic with a lot of ToDos. The tests are not yet there (there are, but they are not that suitable). I will add them at a later stage. --- .../transformations/k_blocking.py | 17 +++-- .../transformations/map_fusion_helper.py | 62 +++++++++---------- .../transformations/map_orderer.py | 12 ++-- .../transformations/map_seriall_fusion.py | 47 +++++--------- 4 files changed, 62 insertions(+), 76 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index bad54693af..2227c82729 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -14,7 +14,7 @@ import copy import functools -from typing import Any, Union +from typing import Any, Optional, Union import dace from dace import properties, subsets, transformation @@ -63,16 +63,18 @@ class KBlocking(transformation.SingleStateTransformation): def __init__( self, - blocking_size: int, - block_dim: Union[gtx_common.Dimension, str], + blocking_size: Optional[int] = None, + block_dim: Optional[Union[gtx_common.Dimension, str]] = None, ) -> None: super().__init__() - self.blocking_size = blocking_size if isinstance(block_dim, str): pass elif isinstance(block_dim, gtx_common.Dimension): block_dim = dace_fieldview_util.get_map_variable(block_dim) - self.block_dim = block_dim + if block_dim is not None: + self.block_dim = block_dim + if blocking_size is not None: + self.blocking_size = blocking_size @classmethod def expressions(cls) -> Any: @@ -94,6 +96,11 @@ def can_be_applied( - The map range must have stride one. - The partition must exists (see `partition_map_output()`). """ + if self.block_dim is None: + raise ValueError("The blocking dimension was not specified.") + elif self.blocking_size is None: + raise ValueError("The blocking size was not specified.") + map_entry: nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params map_range: subsets.Range = map_entry.map.range diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index f4943949c1..265e64965b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -28,7 +28,7 @@ @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): - """Contains common part of the map fusion for parallel and serial map fusion. + """Contains common part of the fusion for parallel and serial Map fusion. The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). The main advantage of this structure is, that it is rather easy to determine @@ -46,7 +46,7 @@ class MapFusionHelper(transformation.SingleStateTransformation): dtype=bool, default=False, allow_none=False, - desc="Only perform fusing if the Maps are on the top level.", + desc="Only perform fusing if the Maps are in the top level.", ) only_inner_maps = properties.Property( dtype=bool, @@ -91,7 +91,7 @@ def can_be_fused( """Performs basic checks if the maps can be fused. This function only checks constrains that are common between serial and - parallel map fusion process. It tests: + parallel map fusion process, which includes: - The scope of the maps. - The scheduling of the maps. - The map parameters. @@ -123,6 +123,7 @@ def can_be_fused( elif self.only_toplevel_maps: if scope[map_entry_1] is not None: return False + # TODO(phimuell): Figuring out why this is here. elif util.is_nested_sdfg(sdfg): return False @@ -145,7 +146,7 @@ def relocate_nodes( This function will only rewire the edges, it does not remove the nodes themselves. Furthermore, this function should be called twice per Map, - once for the entries and then for the exits. + once for the entry and then for the exit. Args: from_node: Node from which the edges should be removed. @@ -168,9 +169,8 @@ def relocate_nodes( state.remove_edge(empty_edge) empty_targets.add(empty_edge.dst) - # We now determine if which connections we have to migrate. We only consider - # the in edges, for Map exits it does not matter, but for Map entries, - # we need it for the dynamic map range feature. + # We now determine which edges we have to migrate, for this we are looking at + # the incoming edges, because this allows us also to detect dynamic map ranges. for edge_to_move in list(state.in_edges(from_node)): assert isinstance(edge_to_move.dst_conn, str) @@ -197,19 +197,18 @@ def relocate_nodes( # There is no other edge that we have to consider, so we just end here continue - # We have a Passthrough connection, i.e. there exists a matching `OUT_` - # connector thus we now have to migrate the two edges. + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix new_conn = to_node.next_connector(old_conn) to_node.add_in_connector("IN_" + new_conn) - from_node.remove_in_connector("IN_" + old_conn) for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - from_node.remove_out_connector("OUT_" + old_conn) to_node.add_out_connector("OUT_" + new_conn) for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) assert state.in_degree(from_node) == 0 assert len(from_node.in_connectors) == 0 @@ -223,7 +222,7 @@ def map_parameter_compatible( state: Union[SDFGState, SDFG], sdfg: SDFG, ) -> bool: - """Checks if the parameters of `map_1` are compatible with `map_1`. + """Checks if the parameters of `map_1` are compatible with `map_2`. The check follows the following rules: - The names of the map variables must be the same, i.e. no renaming @@ -236,7 +235,8 @@ def map_parameter_compatible( params_2: Sequence[str] = map_2.params # The maps are only fuseable if we have an exact match in the parameter names - # this is because we do not do any renaming. + # this is because we do not do any renaming. This is in accordance with the + # rules. if set(params_1) != set(params_2): return False @@ -265,11 +265,11 @@ def is_interstate_transient( Essentially this function checks if a transient might be needed in a different state in the SDFG, because it transmit information from one state to the other. + If only the name of the data container is passed the function will + first look for an corresponding access node. - If only the name of the transient is passed, - then the function will only check if it is used in another state. - If the access node and the state are passed the function will also - check if it is used inside the state. + The set of these "interstate transients" is computed once per SDFG. + The result is then cached internally for later reuse. Args: transient: The transient that should be checked. @@ -292,16 +292,10 @@ def is_interstate_transient( shared_sdfg_transients: set[str] = self.shared_transients[sdfg] else: - # SDFG is not known so we have to compute it. - # If a scalar is not a source node then it is not included in this set. - # Thus we do not have to look for it, instead we will check for them - # explicitly. - def go_deeper(node: nodes.Node, graph: Any) -> bool: - return not isinstance(node, nodes.NestedSDFG) - + # SDFG is not known so we have to compute the set. shared_sdfg_transients = set() - # TODO(phimuell): use `sdfg.all_nodes_recursive(go_deeper)` if it is available. for state in sdfg.all_states(): + # TODO(phimuell): Use `all_nodes_recursive()` once it is available. shared_sdfg_transients.update( filter( lambda node: isinstance(node, nodes.AccessNode) @@ -371,7 +365,7 @@ def partition_first_outputs( If such a decomposition exists the function will return the three sets mentioned above in the same order. In case the decomposition does not exist, i.e. the maps can not be fused - serially, the function returns `None`. + the function returns `None`. Args: state: The in which the two maps are located. @@ -397,11 +391,6 @@ def partition_first_outputs( return None processed_inter_nodes.add(intermediate_node) - # Empty Memlets are currently not supported. - # However, they are much more important in entry nodes. - if out_edge.data.is_empty(): - return None - # Now let's look at all nodes that are downstream of the intermediate node. # This, among other things, will tell us, how we have to handle this node. downstream_nodes = util.all_nodes_between( @@ -416,17 +405,22 @@ def partition_first_outputs( if downstream_nodes is None: pure_outputs.add(out_edge) continue - # # The following tests are _after_ we have determined if we have a pure # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge. + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None # In case the intermediate has more than one entry, all must come from the # first map, otherwise we can not fuse them. Currently we restrict this # even further by saying that it has only one incoming Memlet. if state.in_degree(intermediate_node) != 1: - # TODO(phimuell): handle this case. return None # It can happen that multiple edges converges at the `IN_` connector diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index b6a900ce1a..d42ff06edc 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -49,7 +49,7 @@ class MapIterationOrder(transformation.SingleStateTransformation): """ leading_dim = properties.Property( - dtype=gtx_common.Dimension, + dtype=str, allow_none=True, desc="Dimension that should become the leading dimension.", ) @@ -58,12 +58,14 @@ class MapIterationOrder(transformation.SingleStateTransformation): def __init__( self, - leading_dim: Optional[gtx_common.Dimension] = None, + leading_dim: Optional[Union[gtx_common.Dimension, str]] = None, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - if leading_dim is not None: + if isinstance(leading_dim, gtx_common.Dimension): + self.leading_dim = dace_fieldview_util.get_map_variable(leading_dim) + elif leading_dim is not None: self.leading_dim = leading_dim @classmethod @@ -87,7 +89,7 @@ def can_be_applied( return False map_entry: nodes.MapEntry = self.map_entry map_params: Sequence[str] = map_entry.map.params - map_var: str = dace_fieldview_util.get_map_variable(self.leading_dim) + map_var: str = self.leading_dim if map_var not in map_params: return False @@ -108,7 +110,7 @@ def apply( """ map_entry: nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params - map_var: str = dace_fieldview_util.get_map_variable(self.leading_dim) + map_var: str = self.leading_dim # This implementation will just swap the variable that is currently the last # with the one that should be the last. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py index b4fc6b5985..66be18e6c0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py @@ -52,8 +52,6 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): Notes: - This transformation modifies more nodes than it matches! - - The consolidate edge transformation (part of simplify) is probably - harmful to the applicability of this transformation. """ map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) @@ -73,7 +71,8 @@ def expressions(cls) -> Any: The transformation matches the exit node of the top Map that is connected to an access node that again is connected to the entry node of the second Map. An important note is, that the transformation operates not just on these nodes, - but more or less anything that has an outgoing connection of the first Map. + but more or less anything that has an outgoing connection of the first Map, + and is connected to the second map. """ return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] @@ -131,7 +130,7 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non """ # NOTE: `self.map_*` actually stores the ID of the node. # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here. + # Thus we have to save them here, this is a known behaviour in DaCe. assert isinstance(graph, dace.SDFGState) assert isinstance(self.map_exit1, nodes.MapExit) assert isinstance(self.map_entry2, nodes.MapEntry) @@ -149,7 +148,6 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non assert output_partition is not None # Make MyPy happy. pure_outputs, exclusive_outputs, shared_outputs = output_partition - # Handling the outputs if len(exclusive_outputs) != 0: self.handle_intermediate_set( intermediate_outputs=exclusive_outputs, @@ -210,10 +208,8 @@ def handle_intermediate_set( The function is able to handle both the shared and exclusive intermediate output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode is that the intermediate node will be fully removed from + in exclusive mode the intermediate nodes will be fully removed from the SDFG. While in shared mode the intermediate node will be preserved. - However, the function just performs some rewiring of the outputs and - manipulation of the intermediate node set. Args: intermediate_outputs: The set of outputs, that should be processed. @@ -226,10 +222,7 @@ def handle_intermediate_set( Notes: Before the transformation the `state` does not be to be valid and - after this function has run the state is invalid. - The function is static and the map nodes have to be explicitly passed - because the `self.map_*` properties are modified by the modification. - This is a known behaviour in DaCe. + after this function has run the state is (most likely) invalid. Todo: Rewrite using `MemletTree`. @@ -244,20 +237,18 @@ def handle_intermediate_set( # Note that this is still not enough as the dimensionality might be different. memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - # Now we will iterate over all intermediate edge and process them. + # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. for out_edge in intermediate_outputs: - # This is the intermediate node that, depending on the mode, we want - # to get rid off, in shared mode it materialize, but now at the - # exit of the second Map, in exclusive mode it will be removed. + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. inter_node: nodes.AccessNode = out_edge.dst inter_name = inter_node.data inter_desc = inter_node.desc(sdfg) inter_shape = inter_desc.shape - # Now we will determine the shape of the new intermediate, which has some - # issue. The size of this temporary is given by the Memlet that - # goes into the first map exit. + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. pre_exit_edges = list( state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) ) @@ -268,7 +259,7 @@ def handle_intermediate_set( # Over approximation will leave us with some unneeded size one dimensions. # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[int] = [] + squeezed_dims: list[int] = [] # These are the dimensions we removed. new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. for dim, (proposed_dim_size, full_dim_size) in enumerate( zip(new_inter_shape_raw, inter_shape) @@ -434,14 +425,6 @@ def handle_intermediate_set( map_entry_2.remove_in_connector(in_conn_name) map_entry_2.remove_out_connector(out_conn_name) - # TODO: Apply this modification to Memlets - # for neighbor in state.all_edges(local_node): - # for e in state.memlet_tree(neighbor): - # if e.data.data == local_name: - # continue # noqa: ERA001 - # e.data.data = local_name # noqa: ERA001 - # e.data.subset.offset(old_edge.data.subset, negative=True) # noqa: ERA001 - if is_exclusive_set: # In exclusive mode the old intermediate node is no longer needed. assert state.degree(inter_node) == 1 @@ -459,10 +442,10 @@ def handle_intermediate_set( state.remove_edge(pre_exit_edge) map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - # This is the Memlet that goes from the map internal intermediate temporary node to the Map output. - # This will essentially restore or preserve the output for the intermediate node. - # It is important that we use the data that `preExitEdge` was used. - # On CPU it works but for some reasons not on GPU. + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. new_exit_memlet = copy.deepcopy(pre_exit_edge.data) assert new_exit_memlet.data == inter_name new_exit_memlet.subset = pre_exit_edge.data.dst_subset From 5c92c760ce54ebf22d3926c406d6158b5d62b369 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 13:33:59 +0200 Subject: [PATCH 182/235] Made a small fix in the test function if teh intermnediate was correct. --- .../transformations/map_fusion_helper.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 265e64965b..bd6328c396 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -290,18 +290,19 @@ def is_interstate_transient( # See if we have already computed the set if sdfg in self.shared_transients: shared_sdfg_transients: set[str] = self.shared_transients[sdfg] - else: # SDFG is not known so we have to compute the set. shared_sdfg_transients = set() - for state in sdfg.all_states(): + for state_to_scan in sdfg.all_states(): # TODO(phimuell): Use `all_nodes_recursive()` once it is available. shared_sdfg_transients.update( - filter( - lambda node: isinstance(node, nodes.AccessNode) - and sdfg.arrays[node.data].transient, - itertools.chain(state.source_nodes(), state.sink_nodes()), - ) + [ + node.data + for node in itertools.chain( + state_to_scan.source_nodes(), state_to_scan.sink_nodes() + ) + if isinstance(node, nodes.AccessNode) and sdfg.arrays[node.data].transient + ] ) self.shared_transients[sdfg] = shared_sdfg_transients @@ -311,7 +312,6 @@ def is_interstate_transient( # Rule 8: There is only one access node per state for data. assert len(matching_access_nodes) == 1 transient = matching_access_nodes[0] - else: assert isinstance(transient, nodes.AccessNode) name = transient.data @@ -327,7 +327,7 @@ def is_interstate_transient( return True # Now we check if it is used in a different state. - return transient in shared_sdfg_transients + return name in shared_sdfg_transients def partition_first_outputs( self, From f4f5ae5b2c869f25b83056dc78ab05112b7df0ab Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 13:36:06 +0200 Subject: [PATCH 183/235] Added the first series of tests for teh serial map fusion. --- .../transformation_tests/__init__.py | 14 + .../transformation_tests/test_map_fusion.py | 400 ++++++++++++++++++ 2 files changed, 414 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/__init__.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/__init__.py new file mode 100644 index 0000000000..67bee9d721 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/__init__.py @@ -0,0 +1,14 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py new file mode 100644 index 0000000000..bd0a7fe066 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py @@ -0,0 +1,400 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Sequence, Union, overload, Literal, Generator + +import pytest +import dace +import copy +import numpy as np +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +@pytest.fixture(autouse=True) +def _set_dace_settings() -> Generator[None, None, None]: + """Enables the correct settings in DaCe.""" + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + yield + + +@overload +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: Literal[False], +) -> int: ... + + +@overload +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: Literal[True], +) -> list[dace_nodes.Node]: ... + + +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: bool = False, +) -> Union[int, list[dace_nodes.Node]]: + """Counts the number of nodes in of a particular type in `graph`. + + If `graph` is an SDFGState then only count the nodes inside this state, + but if `graph` is an SDFG count in all states. + + Args: + graph: The graph to scan. + node_type: The type or sequence of types of nodes to look for. + """ + + states = graph.states() if isinstance(graph, dace.SDFG) else [graph] + found_nodes: list[dace_nodes.Node] = [] + for state_nodes in states: + for node in state_nodes.nodes(): + if isinstance(node, node_type): + found_nodes.append(node) + if return_nodes: + return found_nodes + return len(found_nodes) + + +def _make_serial_sdfg_1( + N: str | int, +) -> dace.SDFG: + """Create the "serial_1_sdfg". + + This is an SDFG with a single state containing two maps. It has the input + `a` and the output `b`, each two dimensional arrays, with size `0:N`. + The first map adds 1 to the input and writes it into `tmp`. The second map + adds another 3 to `tmp` and writes it back inside `b`. + + Args: + N: The size of the arrays. + """ + shape = (N, N) + sdfg = dace.SDFG("serial_1_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in ["a", "b", "tmp"]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + tmp = state.add_access("tmp") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in0 + 1.0", + outputs={"__out": dace.Memlet("tmp[__i0, __i1]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="second_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp": tmp}, + inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, + code="__out = __in0 + 3.0", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + + return sdfg + + +def _make_serial_sdfg_2( + N: str | int, +) -> dace.SDFG: + """Create the "serial_2_sdfg". + + The generated SDFG uses `a` and input and has two outputs `b := a + 4` and + `c := a - 4`. There is a top map with a single Single Tasklet, that has + two outputs, the first one computes `a + 1` and stores that in `tmp_1`. + The second output computes `a - 1` and stores it `tmp_2`. + Below the top map are two (parallel) map, one compute `b := tmp_1 + 3`, while + the other compute `c := tmp_2 - 3`. This means that there are two map fusions. + The main important thing is that, the second map fusion will involve a pure + fusion (because the processing order is indeterministic, one does not know + which one in advance). + + Args: + N: The size of the arrays. + """ + shape = (N, N) + sdfg = dace.SDFG("serial_2_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in ["a", "b", "c", "tmp_1", "tmp_2"]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["tmp_1"].transient = True + sdfg.arrays["tmp_2"].transient = True + tmp_1 = state.add_access("tmp_1") + tmp_2 = state.add_access("tmp_2") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, + code="__out0 = __in0 + 1.0\n__out1 = __in0 - 1.0", + outputs={ + "__out0": dace.Memlet("tmp_1[__i0, __i1]"), + "__out1": dace.Memlet("tmp_2[__i0, __i1]"), + }, + output_nodes={ + "tmp_1": tmp_1, + "tmp_2": tmp_2, + }, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp_1": tmp_1}, + inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, + code="__out = __in0 + 3.0", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + state.add_mapped_tasklet( + name="second_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp_2": tmp_2}, + inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, + code="__out = __in0 - 3.0", + outputs={"__out": dace.Memlet("c[__i0, __i1]")}, + external_edges=True, + ) + + return sdfg + + +def test_exclusive_itermediate(): + """Tests if the exclusive intermediate branch works.""" + N = 10 + sdfg = _make_serial_sdfg_1(N) + + # Now apply the optimizations. + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert "tmp" not in sdfg.arrays + + # Test if the intermediate is a scalar + intermediate_nodes: list[dace_nodes.Node] = [ + node + for node in _count_nodes(sdfg, dace_nodes.AccessNode, True) + if node.data not in ["a", "b"] + ] + assert len(intermediate_nodes) == 1 + assert all(isinstance(node.desc(sdfg), dace.data.Scalar) for node in intermediate_nodes) + + a = np.random.rand(N, N) + b = np.empty_like(a) + ref = a + 4.0 + sdfg(a=a, b=b) + + assert np.allclose(b, ref) + + +def test_shared_itermediate(): + """Tests the shared intermediate path. + + The function uses the `_make_serial_sdfg_1()` SDFG. However, it promotes `tmp` + to a global, and it thus became a shared intermediate, i.e. will survive. + """ + N = 10 + sdfg = _make_serial_sdfg_1(N) + sdfg.arrays["tmp"].transient = False + + # Now apply the optimizations. + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert "tmp" in sdfg.arrays + + # Test if the intermediate is a scalar + intermediate_nodes: list[dace_nodes.Node] = [ + node + for node in _count_nodes(sdfg, dace_nodes.AccessNode, True) + if node.data not in ["a", "b", "tmp"] + ] + assert len(intermediate_nodes) == 1 + assert all(isinstance(node.desc(sdfg), dace.data.Scalar) for node in intermediate_nodes) + + a = np.random.rand(N, N) + b = np.empty_like(a) + tmp = np.empty_like(a) + + ref_b = a + 4.0 + ref_tmp = a + 1.0 + sdfg(a=a, b=b, tmp=tmp) + + assert np.allclose(b, ref_b) + assert np.allclose(tmp, ref_tmp) + + +def test_pure_output_node(): + """Tests the path of a pure intermediate.""" + N = 10 + sdfg = _make_serial_sdfg_2(N) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 3 + + # The first fusion will only bring it down to two maps. + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 1 + + a = np.random.rand(N, N) + b = np.empty_like(a) + c = np.empty_like(a) + ref_b = a + 4.0 + ref_c = a - 4.0 + sdfg(a=a, b=b, c=c) + + assert np.allclose(b, ref_b) + assert np.allclose(c, ref_c) + + +def test_array_intermediate(): + """Tests the correct working if we have more than scalar intermediate. + + The test used `_make_serial_sdfg_1()` to get an SDFG and then call `MapExpansion`. + Map fusion is then called only outer maps, thus the intermediate node, must + be an array. + """ + N = 10 + sdfg = _make_serial_sdfg_1(N) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 4 + + # Now perform the fusion + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(only_toplevel_maps=True), + validate=True, + validate_all=True, + ) + map_entries = _count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + + scope = next(iter(sdfg.states())).scope_dict() + assert len(map_entries) == 3 + top_maps = [map_entry for map_entry in map_entries if scope[map_entry] is None] + assert len(top_maps) == 1 + top_map = top_maps[0] + assert sum(scope[map_entry] is top_map for map_entry in map_entries) == 2 + + # Find the access node that is the new intermediate node. + inner_access_nodes: list[dace_nodes.AccessNode] = [ + node for node in _count_nodes(sdfg, dace_nodes.AccessNode, True) if scope[node] is not None + ] + assert len(inner_access_nodes) == 1 + inner_access_node = inner_access_nodes[0] + inner_desc: dace.data.Data = inner_access_node.desc(sdfg) + assert inner_desc.shape == (N,) + + a = np.random.rand(N, N) + b = np.empty_like(a) + ref_b = a + 4.0 + sdfg(a=a, b=b) + + assert np.allclose(ref_b, b) + + +def test_interstate_transient(): + """Tests if an interstate transient is handled properly. + + This function uses the SDFG generated by `_make_serial_sdfg_2()`. It adds a second + state to SDFG in which `tmp_1` is read from and the result is written in `d` (new + variable). Thus `tmp_1` can not be removed. + """ + N = 10 + sdfg = _make_serial_sdfg_2(N) + assert _count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert sdfg.number_of_nodes() == 1 + + # Now add the new state and the new output. + sdfg.add_datadesc("d", copy.deepcopy(sdfg.arrays["b"])) + head_state = next(iter(sdfg.states())) + new_state = sdfg.add_state_after(head_state) + + new_state.add_mapped_tasklet( + name="first_computation_second_state", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, + code="__out = __in0 + 9.0", + outputs={"__out": dace.Memlet("d[__i0, __i1]")}, + external_edges=True, + ) + + # Now apply the transformation + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert "tmp_1" in sdfg.arrays + assert "tmp_2" not in sdfg.arrays + assert sdfg.number_of_nodes() == 2 + assert _count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert _count_nodes(new_state, dace_nodes.MapEntry) == 1 + + a = np.random.rand(N, N) + b = np.empty_like(a) + c = np.empty_like(a) + d = np.empty_like(a) + ref_b = a + 4.0 + ref_c = a - 4.0 + ref_d = a + 10.0 + + sdfg(a=a, b=b, c=c, d=d) + assert np.allclose(ref_b, b) + assert np.allclose(ref_c, c) + assert np.allclose(ref_d, d) From 6d1275764122522dc680f817cbef52799b032337 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 14:13:50 +0200 Subject: [PATCH 184/235] Made some small modifications to the map fusion test. --- .../transformation_tests/conftest.py | 35 +++++++++++++++++++ .../transformation_tests/test_map_fusion.py | 10 +----- 2 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py new file mode 100644 index 0000000000..c0adf27a21 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py @@ -0,0 +1,35 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Sequence, Union, overload, Literal, Generator + +import pytest +import dace +import copy +import numpy as np +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +@pytest.fixture(autouse=True) +def _set_dace_settings() -> Generator[None, None, None]: + """Enables the correct settings in DaCe.""" + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py index bd0a7fe066..e24b30af15 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, Sequence, Union, overload, Literal, Generator +from typing import Any, Optional, Sequence, Union, Literal, overload import pytest import dace @@ -27,14 +27,6 @@ ) -@pytest.fixture(autouse=True) -def _set_dace_settings() -> Generator[None, None, None]: - """Enables the correct settings in DaCe.""" - with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=True) - yield - - @overload def _count_nodes( graph: Union[dace.SDFG, dace.SDFGState], From e590a07400eadc3a7191ae60deaddd2fcb9cf6a1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 14:14:51 +0200 Subject: [PATCH 185/235] Added a test for the blocking. --- .../transformation_tests/test_k_blocking.py | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_k_blocking.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_k_blocking.py new file mode 100644 index 0000000000..e576cea76f --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_k_blocking.py @@ -0,0 +1,150 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Callable +import dace +import copy +import numpy as np + +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np.ndarray]]: + """Creates a simple SDFG. + + The k blocking transformation can be applied to the SDFG, however no node + can be taken out. This is because how it is constructed. However, applying + some simplistic transformations this can be done. + """ + sdfg = dace.SDFG("only_dependent") + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + _, a = sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + _, b = sdfg.add_array("b", ("N",), dace.float64, transient=False) + _, c = sdfg.add_array("c", ("N", "M"), dace.float64, transient=False) + state.add_mapped_tasklet( + name="comp", + map_ranges=dict(i=f"0:N", j=f"0:M"), + inputs=dict(__in0=dace.Memlet("a[i, j]"), __in1=dace.Memlet("b[i]")), + outputs=dict(__out=dace.Memlet("c[i, j]")), + code="__out = __in0 + __in1", + external_edges=True, + ) + return sdfg, lambda a, b: a + b.reshape((-1, 1)) + + +def test_only_dependent(): + """Just applying the transformation to the SDFG. + + Because all of nodes (which is only a Tasklet) inside the map scope are + "dependent", see the transformation for explanation of terminology, the + transformation will only add an inner map. + """ + sdfg, reff = _get_simple_sdfg() + + N, M = 100, 10 + a = np.random.rand(N, M) + b = np.random.rand(N) + c = np.zeros_like(a) + ref = reff(a, b) + + # Apply the transformation + sdfg.apply_transformations_repeated( + gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + validate=True, + validate_all=True, + ) + + assert len(sdfg.states()) == 1 + state = sdfg.states()[0] + source_nodes = state.source_nodes() + assert len(source_nodes) == 2 + assert all(isinstance(x, dace_nodes.AccessNode) for x in source_nodes) + source_node = source_nodes[0] # Unspecific which one it is, but it does not matter. + assert state.out_degree(source_node) == 1 + outer_map: dace_nodes.MapEntry = next(iter(state.out_edges(source_node))).dst + assert isinstance(outer_map, dace_nodes.MapEntry) + assert state.in_degree(outer_map) == 2 + assert state.out_degree(outer_map) == 2 + assert len(outer_map.map.params) == 2 + assert "j" not in outer_map.map.params + assert all(isinstance(x.dst, dace_nodes.MapEntry) for x in state.out_edges(outer_map)) + inner_map: dace_nodes.MapEntry = next(iter(state.out_edges(outer_map))).dst + assert len(inner_map.map.params) == 1 + assert inner_map.map.params[0] == "j" + assert inner_map.map.schedule == dace.dtypes.ScheduleType.Sequential + + sdfg(a=a, b=b, c=c, N=N, M=M) + assert np.allclose(ref, c) + + +def test_intermediate_access_node(): + """Test the lifting out, version "AccessNode". + + The Tasklet of the SDFG generated by `_get_simple_sdfg()` has to be inside the + inner most loop because one of its input Memlet depends on `j`. However, + one of its input, `b[i]` does not. Instead of connecting `b` directly with the + Tasklet, this test will store `b[i]` inside a temporary inside the Map. + This access node is independent of `j` and can thus be moved out of the inner + most scope. + """ + sdfg, reff = _get_simple_sdfg() + + N, M = 100, 10 + a = np.random.rand(N, M) + b = np.random.rand(N) + c = np.zeros_like(a) + ref = reff(a, b) + + # Now make a small modification is such that the transformation does something. + state = sdfg.states()[0] + sdfg.add_scalar("tmp", dace.float64, transient=True) + + tmp = state.add_access("tmp") + edge = next( + e for e in state.edges() if isinstance(e.src, dace_nodes.MapEntry) and e.data.data == "b" + ) + state.add_edge(edge.src, edge.src_conn, tmp, None, copy.deepcopy(edge.data)) + state.add_edge(tmp, None, edge.dst, edge.dst_conn, dace.Memlet("tmp[0]")) + state.remove_edge(edge) + + # Test if after the modification the SDFG still works + sdfg(a=a, b=b, c=c, N=N, M=M) + assert np.allclose(ref, c) + + # Apply the transformation. + sdfg.apply_transformations_repeated( + gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + validate=True, + validate_all=True, + ) + + # Inspect if the SDFG was modified correctly. + # We only inspect `tmp` which now has to be between the two maps. + assert state.in_degree(tmp) == 1 + assert state.out_degree(tmp) == 1 + top_node = next(iter(state.in_edges(tmp))).src + bottom_node = next(iter(state.out_edges(tmp))).dst + assert isinstance(top_node, dace_nodes.MapEntry) + assert isinstance(bottom_node, dace_nodes.MapEntry) + assert bottom_node is not top_node + + c[:] = 0 + sdfg(a=a, b=b, c=c, N=N, M=M) + assert np.allclose(ref, c) From 7e99d98e08dec0b64cd405f5e925c1c5725509cf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 15:09:40 +0200 Subject: [PATCH 186/235] Addressed Edoardo's comments. --- .../transformations/__init__.py | 2 +- .../transformations/auto_opt.py | 43 +++++++++---------- .../transformations/gpu_utils.py | 3 +- .../transformations/map_promoter.py | 2 +- ...seriall_fusion.py => map_serial_fusion.py} | 0 5 files changed, 25 insertions(+), 25 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/transformations/{map_seriall_fusion.py => map_serial_fusion.py} (100%) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index e1172c2fce..2d8eb3fcf7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -28,7 +28,7 @@ from .k_blocking import KBlocking from .map_orderer import MapIterationOrder from .map_promoter import SerialMapPromoter -from .map_seriall_fusion import SerialMapFusion +from .map_serial_fusion import SerialMapFusion __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index d4f6efe73b..a390c5308a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -115,7 +115,7 @@ def gt_auto_optimize( inside the kernels itself. For example fuse maps inside them. 4. Afterwards it will process the map ranges and iteration order. For this the function assumes that the dimension indicated by `leading_dim` is the - once with stride one. + one with stride one. 5. If requested the function will now apply blocking, on the dimension indicated by `leading_dim`. (The reason that it is not done in the kernel optimization phase is a restriction dictated by the implementation.) @@ -127,10 +127,10 @@ def gt_auto_optimize( - Use fast implementation for library nodes. - Move small transients to stack. - Make transients persistent (if requested). - - If requested reuse transients. + - Apply DaCe's `TransientReuse` transformation (if requested). Args: - sdfg: The SDFG that should ve optimized in place. + sdfg: The SDFG that should be optimized in place. gpu: Optimize for GPU or CPU. leading_dim: Leading dimension, indicates where the stride is 1. aggressive_fusion: Be more aggressive in fusion, will lead to the promotion @@ -147,15 +147,15 @@ def gt_auto_optimize( validate_all: Perform extensive validation. Todo: - - Make sure that `SDFG.simplify()` is not called indirectly, by temporary + - Make sure that `SDFG.simplify()` is not called indirectly, by temporarily overwriting it with `gt_simplify()`. - Specify arguments to set the size of GPU thread blocks depending on the dimensions. I.e. be able to use a different size for 1D than 2D Maps. - Add a parallel version of Map fusion. - - Implements some model to further guide to determine what we want to fuse. + - Implement some model to further guide to determine what we want to fuse. Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - - Create a custom array elimination pass that honor rule 1. + - Create a custom array elimination pass that honors rule 1. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU @@ -165,8 +165,7 @@ def gt_auto_optimize( # TODO(phimuell): Should there be a zeroth phase, in which we generate # a chanonical form of the SDFG, for example move all local maps - # to internal serial maps, such that they not block fusion? - # in the JaCe prototype we did that. + # to internal serial maps, such that they do not block fusion? # Phase 1: Initial Cleanup gt_simplify(sdfg) @@ -180,15 +179,16 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Compute the SDFG to see if something has changed. + # Compute the SDFG hash to see if something has changed. sdfg_hash = sdfg.hash_sdfg() # Phase 2: Kernel Creation - # We will now try to reduce the number of kernels and create big one. - # For this we essentially use map fusion. We do this is a loop because + # We will now try to reduce the number of kernels and create large Maps/kernels. + # For this we essentially use Map fusion. We do this is a loop because # after a graph modification followed by simplify new fusing opportunities # might arise. We use the hash of the SDFG to detect if we have reached a # fix point. + # TODO(phimuell): Find a better upper bound for the starvation protection. for _ in range(100): # Use map fusion to reduce their number and to create big kernels # TODO(phimuell): Use a cost measurement to decide if fusion should be done. @@ -208,7 +208,7 @@ def gt_auto_optimize( phase2_cleanup.append(dace_dataflow.TrivialTaskletElimination()) # TODO(phimuell): Should we do this all the time or only once? (probably the later) - # TODO(phimuell): More control what we promote. + # TODO(phimuell): Add a criteria to decide if we should promote or not. phase2_cleanup.append( gtx_transformations.SerialMapPromoter( only_toplevel_maps=True, @@ -231,17 +231,15 @@ def gt_auto_optimize( if old_sdfg_hash == sdfg_hash: break - # The SDFG was modified by the transformations above. But no - # transformation could be applied, so we will now call simplify - # and start over. + # The SDFG was modified by the transformations above. The SDFG was + # modified. Call Simplify and try again to further optimize. gt_simplify(sdfg) else: - raise RuntimeWarning("Optimization of the SDFG did not converged.") + raise RuntimeWarning("Optimization of the SDFG did not converge.") # Phase 3: Optimizing the kernels themselves. - # Currently this is only applies fusion inside them. - # TODO(phimuell): Improve. + # Currently this only applies fusion inside Maps. sdfg.apply_transformations_repeated( gtx_transformations.SerialMapFusion( only_inner_maps=True, @@ -295,17 +293,18 @@ def gt_auto_optimize( # The DaCe auto optimizer also uses them. Note that the reuse transient # is not done by DaCe. if reuse_transients: - # TODO(phimuell): Investigate if we should enable it, it makes stuff - # harder for the compiler. Maybe write our own that + # TODO(phimuell): Investigate if we should enable it, it may make things + # harder for the compiler. Maybe write our own to # only consider big transients and not small ones (~60B) transient_reuse = dace.transformation.passes.TransientReuse() transient_reuse.apply_pass(sdfg, {}) + # Set the implementation of the library nodes. dace_aoptimize.set_fast_implementations(sdfg, device) - # TODO(phimuell): Fix the bug, it is used the tile value and not the stack array value. + # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) if make_persistent: - # TODO(phimuell): Allow to also make them `SDFG`. + # TODO(phimuell): Allow to also to set the lifetime to `SDFG`. dace_aoptimize.make_transients_persistent(sdfg, device) return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index c160fc52a4..5624eae4ac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -101,6 +101,7 @@ def gt_gpu_transformation( # variables. Currently this is not a problem because of the way it is # implemented. # TODO(phimuell): Fix the issue described above. + # TODO(phimuell): Maybe we should fuse trivial GPU maps before we do anything. sdfg.apply_transformations_once_everywhere( gtx_transformations.SerialMapPromoterGPU(), validate=False, @@ -326,7 +327,7 @@ def can_be_applied( - If the top map is a trivial map. - If a valid partition exists that can be fused at all. """ - from .map_seriall_fusion import SerialMapFusion + from .map_serial_fusion import SerialMapFusion map_exit_1: nodes.MapExit = self.map_exit1 map_1: nodes.Map = map_exit_1.map diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index b0b5313b85..8def62c1c7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -315,7 +315,7 @@ def can_be_applied( permissive: bool = False, ) -> bool: """Tests if the Maps really can be fused.""" - from .map_seriall_fusion import SerialMapFusion + from .map_serial_fusion import SerialMapFusion if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py similarity index 100% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_seriall_fusion.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py From bd35c6d86708a540cc9df829b433dd2a618a9a16 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 15:31:13 +0200 Subject: [PATCH 187/235] Forgot to apply some of Edoardo's suggestions. --- .../transformations/gpu_utils.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 5624eae4ac..3e251705ae 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -43,23 +43,23 @@ def gt_gpu_transformation( validate_all: bool = False, **kwargs: Any, ) -> dace.SDFG: - """Transform an SDFG into an GPU SDFG. + """Transform an SDFG into a GPU SDFG. The transformation expects a rather optimized SDFG and turn it into an SDFG capable of running on the GPU. The function performs the following steps: - - If requested, modify the storage location of the non transient arrays to - life on GPU. + - If requested, modify the storage location of the non transient arrays such + that they reside in GPU memory. - Call the normal GPU transform function followed by simplify. - If requested try to remove trivial kernels. - - If given set the `gpu_block_size` parameters of the Maps to the given value. + - If specified, set the `gpu_block_size` parameters of the Maps to the given value. Args: sdfg: The SDFG that should be processed. try_removing_trivial_maps: Try to get rid of trivial maps by incorporating them. use_gpu_storage: Assume that the non global memory is already on the GPU. - gpu_block_size: Set the GPU block size of all maps that does not have - one to this value. + gpu_block_size: Set to true when the SDFG array arguments are already allocated + on GPU global memory. This will avoid the data copy from host to GPU memory. Notes: The function might modify the order of the iteration variables of some @@ -78,7 +78,7 @@ def gt_gpu_transformation( # This way the GPU transformation will not create this copying stuff. if use_gpu_storage: for desc in sdfg.arrays.values(): - if not (desc.transient or not isinstance(desc, dace.data.Array)): + if isinstance(desc, dace.data.Array) and not desc.transient: desc.storage = dace.dtypes.StorageType.GPU_Global # Now turn it into a GPU SDFG @@ -91,9 +91,10 @@ def gt_gpu_transformation( gtx_transformations.gt_simplify(sdfg) if try_removing_trivial_maps: - # For reasons a Tasklet can not exist outside a Map in a GPU SDFG. The GPU - # transformation will thus adds trivial maps around them, which translate to - # a kernel launch. Our current solution is to promote them and then fuse it. + # Because of DaCe's design a Tasklet can not exist outside a Map in a GPU SDFG. + # The GPU transformation will thus add trivial maps around them, which + # translate to a kernel launch. Our current solution is to promote them and + # then fuse it. # NOTE: The current implementation has a flaw, because promotion and fusion # are two different steps, this is is inefficient. There are some problems # because the mapped Tasklet might not be fusable at all. However, the real From 0da8ae2d7485fe39a6d92bd2a67de9521d1ceb19 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 15:21:49 +0200 Subject: [PATCH 188/235] Added myself to the list of authors. --- AUTHORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS.md b/AUTHORS.md index 6c76e5759e..ad73e63978 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -22,6 +22,7 @@ - Madonna, Alberto. ETH Zurich - CSCS - Mariotti, Kean. ETH Zurich - CSCS - Müller, Christoph. MeteoSwiss +- Müller, Philip. ETH Zurich - CSCS - Osuna, Carlos. MeteoSwiss - Paone, Edoardo. ETH Zurich - CSCS - Röthlin, Matthias. MeteoSwiss From 888fb553f036686413378bd857163d2fb566c0a0 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 16:22:40 +0200 Subject: [PATCH 189/235] Added an utility module. --- .../transformation_tests/test_map_fusion.py | 77 +++++-------------- .../transformation_tests/util.py | 65 ++++++++++++++++ 2 files changed, 84 insertions(+), 58 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/util.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py index e24b30af15..7e1e4a6eac 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py @@ -25,48 +25,7 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, ) - - -@overload -def _count_nodes( - graph: Union[dace.SDFG, dace.SDFGState], - node_type: tuple[type, ...] | type, - return_nodes: Literal[False], -) -> int: ... - - -@overload -def _count_nodes( - graph: Union[dace.SDFG, dace.SDFGState], - node_type: tuple[type, ...] | type, - return_nodes: Literal[True], -) -> list[dace_nodes.Node]: ... - - -def _count_nodes( - graph: Union[dace.SDFG, dace.SDFGState], - node_type: tuple[type, ...] | type, - return_nodes: bool = False, -) -> Union[int, list[dace_nodes.Node]]: - """Counts the number of nodes in of a particular type in `graph`. - - If `graph` is an SDFGState then only count the nodes inside this state, - but if `graph` is an SDFG count in all states. - - Args: - graph: The graph to scan. - node_type: The type or sequence of types of nodes to look for. - """ - - states = graph.states() if isinstance(graph, dace.SDFG) else [graph] - found_nodes: list[dace_nodes.Node] = [] - for state_nodes in states: - for node in state_nodes.nodes(): - if isinstance(node, node_type): - found_nodes.append(node) - if return_nodes: - return found_nodes - return len(found_nodes) +from . import util def _make_serial_sdfg_1( @@ -197,19 +156,19 @@ def test_exclusive_itermediate(): sdfg = _make_serial_sdfg_1(N) # Now apply the optimizations. - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( gtx_transformations.SerialMapFusion(), validate=True, validate_all=True, ) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" not in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in _count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b"] ] assert len(intermediate_nodes) == 1 @@ -234,19 +193,19 @@ def test_shared_itermediate(): sdfg.arrays["tmp"].transient = False # Now apply the optimizations. - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( gtx_transformations.SerialMapFusion(), validate=True, validate_all=True, ) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in _count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b", "tmp"] ] assert len(intermediate_nodes) == 1 @@ -268,7 +227,7 @@ def test_pure_output_node(): """Tests the path of a pure intermediate.""" N = 10 sdfg = _make_serial_sdfg_2(N) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 # The first fusion will only bring it down to two maps. sdfg.apply_transformations( @@ -276,13 +235,13 @@ def test_pure_output_node(): validate=True, validate_all=True, ) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( gtx_transformations.SerialMapFusion(), validate=True, validate_all=True, ) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -304,9 +263,9 @@ def test_array_intermediate(): """ N = 10 sdfg = _make_serial_sdfg_1(N) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 4 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 4 # Now perform the fusion sdfg.apply_transformations( @@ -314,7 +273,7 @@ def test_array_intermediate(): validate=True, validate_all=True, ) - map_entries = _count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + map_entries = util._count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) scope = next(iter(sdfg.states())).scope_dict() assert len(map_entries) == 3 @@ -325,7 +284,9 @@ def test_array_intermediate(): # Find the access node that is the new intermediate node. inner_access_nodes: list[dace_nodes.AccessNode] = [ - node for node in _count_nodes(sdfg, dace_nodes.AccessNode, True) if scope[node] is not None + node + for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + if scope[node] is not None ] assert len(inner_access_nodes) == 1 inner_access_node = inner_access_nodes[0] @@ -349,7 +310,7 @@ def test_interstate_transient(): """ N = 10 sdfg = _make_serial_sdfg_2(N) - assert _count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 assert sdfg.number_of_nodes() == 1 # Now add the new state and the new output. @@ -375,8 +336,8 @@ def test_interstate_transient(): assert "tmp_1" in sdfg.arrays assert "tmp_2" not in sdfg.arrays assert sdfg.number_of_nodes() == 2 - assert _count_nodes(head_state, dace_nodes.MapEntry) == 1 - assert _count_nodes(new_state, dace_nodes.MapEntry) == 1 + assert util._count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert util._count_nodes(new_state, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/util.py new file mode 100644 index 0000000000..bcc730091a --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/util.py @@ -0,0 +1,65 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Union, Literal, overload + +import dace +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +__all__ = [ + "_count_nodes", +] + + +@overload +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: Literal[False], +) -> int: ... + + +@overload +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: Literal[True], +) -> list[dace_nodes.Node]: ... + + +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: bool = False, +) -> Union[int, list[dace_nodes.Node]]: + """Counts the number of nodes in of a particular type in `graph`. + + If `graph` is an SDFGState then only count the nodes inside this state, + but if `graph` is an SDFG count in all states. + + Args: + graph: The graph to scan. + node_type: The type or sequence of types of nodes to look for. + """ + + states = graph.states() if isinstance(graph, dace.SDFG) else [graph] + found_nodes: list[dace_nodes.Node] = [] + for state_nodes in states: + for node in state_nodes.nodes(): + if isinstance(node, node_type): + found_nodes.append(node) + if return_nodes: + return found_nodes + return len(found_nodes) From 0767d6f5cb40ddbd5881849197520e81698481c7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 16:23:05 +0200 Subject: [PATCH 190/235] Made it possible to extend the applicability of teh map promotion transformation. --- .../transformations/map_promoter.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 8def62c1c7..9331b983ed 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -84,6 +84,11 @@ class BaseMapPromoter(transformation.SingleStateTransformation): default=False, desc="If `True` promote horizontal dimensions.", ) + promote_all = properties.Property( + dtype=bool, + default=False, + desc="If `True` perform any promotion. Takes precedence over all other selectors.", + ) def map_to_promote( self, @@ -112,6 +117,7 @@ def __init__( promote_local: Optional[bool] = None, promote_vertical: Optional[bool] = None, promote_horizontal: Optional[bool] = None, + promote_all: Optional[bool] = None, *args: Any, **kwargs: Any, ) -> None: @@ -126,9 +132,19 @@ def __init__( self.promote_vertical = bool(promote_vertical) if promote_horizontal is not None: self.promote_horizontal = bool(promote_horizontal) + if promote_all is not None: + self.promote_all = bool(promote_all) + self.promote_horizontal = False + self.promote_vertical = False + self.promote_local = False if only_inner_maps and only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") - if not (self.promote_local or self.promote_vertical or self.promote_horizontal): + if not ( + self.promote_local + or self.promote_vertical + or self.promote_horizontal + or self.promote_all + ): raise ValueError( "You must select at least one class of dimension that should be promoted." ) @@ -177,21 +193,22 @@ def can_be_applied( # We now know which dimensions we have to add to the promotee map. # Now we must test if we are also allowed to make that promotion in the first place. - dimension_identifier: list[str] = [] - if self.promote_local: - dimension_identifier.append("__gtx_localdim") - if self.promote_vertical: - dimension_identifier.append("__gtx_vertical") - if self.promote_horizontal: - dimension_identifier.append("__gtx_horizontal") - if not dimension_identifier: - return False - for missing_map_param in missing_map_parameters: - if not any( - missing_map_param.endswith(dim_identifier) - for dim_identifier in dimension_identifier - ): + if not self.promote_all: + dimension_identifier: list[str] = [] + if self.promote_local: + dimension_identifier.append("__gtx_localdim") + if self.promote_vertical: + dimension_identifier.append("__gtx_vertical") + if self.promote_horizontal: + dimension_identifier.append("__gtx_horizontal") + if not dimension_identifier: return False + for missing_map_param in missing_map_parameters: + if not any( + missing_map_param.endswith(dim_identifier) + for dim_identifier in dimension_identifier + ): + return False return True From c396200f64eb9eabcbbac2ee365a5a1bc16330f5 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 29 Jul 2024 16:23:40 +0200 Subject: [PATCH 191/235] Added a test for the map promotion. --- .../test_serial_map_promoter.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_serial_map_promoter.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_serial_map_promoter.py new file mode 100644 index 0000000000..886d4888b6 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_serial_map_promoter.py @@ -0,0 +1,95 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Callable +import dace +import copy +import numpy as np + +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) +from . import util + + +def test_serial_map_promotion(): + """Tests the serial Map promotion transformation.""" + N = 10 + shape_1d = (N,) + shape_2d = (N, N) + sdfg = dace.SDFG("serial_promotable") + state = sdfg.add_state(is_start_block=True) + + # 1D Arrays + for name in ["a", "tmp"]: + sdfg.add_array( + name=name, + shape=shape_1d, + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + # 2D Arrays + for name in ["b", "c"]: + sdfg.add_array( + name=name, + shape=shape_2d, + dtype=dace.float64, + transient=False, + ) + tmp = state.add_access("tmp") + + _, map_entry_1d, _ = state.add_mapped_tasklet( + name="one_d_map", + map_ranges=[("__i0", f"0:{N}")], + inputs={"__in0": dace.Memlet("a[__i0]")}, + code="__out = __in0 + 1.0", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + ) + + _, map_entry_2d, _ = state.add_mapped_tasklet( + name="two_d_map", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp": tmp}, + inputs={"__in0": dace.Memlet("tmp[__i0]"), "__in1": dace.Memlet("b[__i0, __i1]")}, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("c[__i0, __i1]")}, + external_edges=True, + ) + + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert len(map_entry_1d.map.params) == 1 + assert len(map_entry_2d.map.params) == 2 + + sdfg.view() + # Now apply the promotion + sdfg.apply_transformations( + gtx_transformations.SerialMapPromoter( + promote_all=True, + ), + validate=True, + validate_all=True, + ) + + sdfg.view() + + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert len(map_entry_1d.map.params) == 2 + assert len(map_entry_2d.map.params) == 2 + assert set(map_entry_1d.map.params) == set(map_entry_2d.map.params) From 8e471f197501ea6c838cfc58ef4a93ce59b42f40 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 30 Jul 2024 09:57:47 +0200 Subject: [PATCH 192/235] Reorganized the tests. --- .../runners_tests/{ => dace}/transformation_tests/__init__.py | 0 .../runners_tests/{ => dace}/transformation_tests/conftest.py | 3 ++- .../runners_tests/dace/transformation_tests/dace_fieldview | 1 + .../{ => dace}/transformation_tests/test_k_blocking.py | 0 .../{ => dace}/transformation_tests/test_map_fusion.py | 0 .../transformation_tests/test_serial_map_promoter.py | 0 .../runners_tests/{ => dace}/transformation_tests/util.py | 0 7 files changed, 3 insertions(+), 1 deletion(-) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/{ => dace}/transformation_tests/__init__.py (100%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/{ => dace}/transformation_tests/conftest.py (90%) create mode 120000 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/{ => dace}/transformation_tests/test_k_blocking.py (100%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/{ => dace}/transformation_tests/test_map_fusion.py (100%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/{ => dace}/transformation_tests/test_serial_map_promoter.py (100%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/{ => dace}/transformation_tests/util.py (100%) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py similarity index 100% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/__init__.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py similarity index 90% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py index c0adf27a21..9beab4df46 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py @@ -29,7 +29,8 @@ @pytest.fixture(autouse=True) def _set_dace_settings() -> Generator[None, None, None]: - """Enables the correct settings in DaCe.""" + """Customizes DaCe settings during the tests.""" with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) + dace.Config.set("compiler", "allow_view_arguments", value=True) yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview new file mode 120000 index 0000000000..ef394e0d58 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview @@ -0,0 +1 @@ +/home/quint_essent/git/1_CSCS/gt4py/src/gt4py/next/program_processors/runners/dace_fieldview \ No newline at end of file diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py similarity index 100% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_k_blocking.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py similarity index 100% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_map_fusion.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py similarity index 100% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/test_serial_map_promoter.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py similarity index 100% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/transformation_tests/util.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py From 28fcb840980ccee5db0b0818c99459530415c2cf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 30 Jul 2024 09:58:05 +0200 Subject: [PATCH 193/235] Modified teh first step of teh auto optimizer. I now use reapeated, because the two transformations could themsleves introduces trivial maps. --- .../runners/dace_fieldview/transformations/auto_opt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index a390c5308a..68bac85d9e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -169,9 +169,10 @@ def gt_auto_optimize( # Phase 1: Initial Cleanup gt_simplify(sdfg) - sdfg.apply_transformations_once_everywhere( + sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, + # TODO(phimuell): Investigate if these two are appropriate. dace_dataflow.MapReduceFusion, dace_dataflow.MapWCRFusion, ], From e8829c6281f8c69ddeb4bae34701099168cf4a12 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 30 Jul 2024 13:43:38 +0200 Subject: [PATCH 194/235] Added a test to ensure that fusion does not skrew up with indirect access Tasklets. --- .../transformation_tests/test_map_fusion.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py index 7e1e4a6eac..a89d144e88 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py @@ -150,6 +150,73 @@ def _make_serial_sdfg_2( return sdfg +def _make_serial_sdfg_3( + N_input: str | int, + N_output: str | int, +) -> dace.SDFG: + """Creates a serial SDFG that has an indirect access Tasklet in the second map. + + The SDFG has three inputs `a`, `b` and `idx`. The first two are 1 dimensional + arrays, and the second is am array containing integers. + The top map computes `a + b` and stores that in `tmp`. + The second map then uses the elements of `idx` to make indirect accesses into + `tmp`, which are stored inside `c`. + + Args: + N_input: The length of `a` and `b`. + N_output: The length of `c` and `idx`. + """ + input_shape = (N_input,) + output_shape = (N_output,) + + sdfg = dace.SDFG("serial_3_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name, shape in [ + ("a", input_shape), + ("b", input_shape), + ("c", output_shape), + ("idx", output_shape), + ("tmp", input_shape), + ]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.int32 if name == "idx" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + tmp = state.add_access("tmp") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N_input}")], + inputs={ + "__in0": dace.Memlet("a[__i0]"), + "__in1": dace.Memlet("b[__i0]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="indirect_access", + map_ranges=[("__i0", f"0:{N_output}")], + input_nodes={"tmp": tmp}, + inputs={ + "__index": dace.Memlet("idx[__i0]"), + "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), + }, + code="__out = __array[__index]", + outputs={"__out": dace.Memlet("c[__i0]")}, + external_edges=True, + ) + + return sdfg + + def test_exclusive_itermediate(): """Tests if the exclusive intermediate branch works.""" N = 10 @@ -351,3 +418,40 @@ def test_interstate_transient(): assert np.allclose(ref_b, b) assert np.allclose(ref_c, c) assert np.allclose(ref_d, d) + + +def test_indirect_access(): + """Tests if indirect accesses are handled. + + Indirect accesses, a Tasklet dereferences the array, can not be fused, because + the array is accessed by the Tasklet. + """ + N_input = 100 + N_output = 1000 + a = np.random.rand(N_input) + b = np.random.rand(N_input) + c = np.empty(N_output) + idx = np.random.randint(low=0, high=N_input, size=N_output, dtype=np.int32) + sdfg = _make_serial_sdfg_3(N_input=N_input, N_output=N_output) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + + def _ref(a, b, idx): + tmp = a + b + return tmp[idx] + + ref = _ref(a, b, idx) + + sdfg(a=a, b=b, idx=idx, c=c) + assert np.allclose(ref, c) + + # Now "apply" the transformation + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + + c[:] = -1.0 + sdfg(a=a, b=b, idx=idx, c=c) + assert np.allclose(ref, c) From 9373629ddde081da3b0518c1be1fbcae95fe8325 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 30 Jul 2024 13:52:29 +0200 Subject: [PATCH 195/235] Added a todo for a test. --- .../dace/transformation_tests/test_map_fusion.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py index a89d144e88..e38032dd11 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py @@ -455,3 +455,8 @@ def _ref(a, b, idx): c[:] = -1.0 sdfg(a=a, b=b, idx=idx, c=c) assert np.allclose(ref, c) + + +def test_indirect_access_2(): + # TODO(phimuell): Index should be computed and that map should be fusable. + pass From 63a51126188ba0e4c460aa049b64ec8f5c3557c6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 30 Jul 2024 14:49:21 +0200 Subject: [PATCH 196/235] Addressed Edoardo's comment. --- .../runners/dace_fieldview/transformations/gpu_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 3e251705ae..350f696326 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -154,14 +154,14 @@ def _gpu_block_parser( self: "GPUSetBlockSize", val: Any, ) -> None: - """Used by the setter ob `GPUSetBlockSize.block_size`.""" + """Used by the setter of `GPUSetBlockSize.block_size`.""" org_val = val if isinstance(val, tuple): pass elif isinstance(val, list): val = tuple(val) elif isinstance(val, str): - val = tuple(x.replace(" ", "") for x in val.split(",")) + val = tuple(x.strip() for x in val.split(",")) else: raise TypeError( f"Does not know how to transform '{type(val).__name__}' into a proper GPU block size." From 5a2c12ce1d4e96c7b0beb16b54f67c92ca049c31 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 07:50:22 +0200 Subject: [PATCH 197/235] Added the possibility to controll the iteration order also from teh outside. --- .../transformations/__init__.py | 17 ++++---- .../transformations/auto_opt.py | 43 +++++++++++++++++-- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 2d8eb3fcf7..c0f30f535f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -18,7 +18,7 @@ that explains the general structure and requirements on the SDFG. """ -from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_simplify +from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_set_iteration_order, gt_simplify from .gpu_utils import ( GPUSetBlockSize, SerialMapPromoterGPU, @@ -32,15 +32,16 @@ __all__ = [ + "GPUSetBlockSize", + "KBlocking", + "MapIterationOrder", + "SerialMapFusion", + "SerialMapPromoter", + "SerialMapPromoterGPU", "dace_auto_optimize", "gt_auto_optimize", - "gt_simplify", "gt_gpu_transformation", + "gt_set_iteration_order", "gt_set_gpu_blocksize", - "SerialMapFusion", - "SerialMapPromoter", - "SerialMapPromoterGPU", - "MapIterationOrder", - "GPUSetBlockSize", - "KBlocking", + "gt_simplify", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 68bac85d9e..f13117f50e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -26,6 +26,14 @@ ) +__all__ = [ + "dace_auto_optimize", + "gt_simplify", + "gt_set_iteration_order", + "gt_auto_optimize", +] + + def dace_auto_optimize( sdfg: dace.SDFG, device: dace.DeviceType = dace.DeviceType.CPU, @@ -82,6 +90,32 @@ def gt_simplify( ).apply_pass(sdfg, {}) +def gt_set_iteration_order( + sdfg: dace.SDFG, + leading_dim: gtx_common.Dimension, + validate: bool = True, + validate_all: bool = False, +) -> Any: + """Set the iteration order of the Maps correctly. + + Modifies the order of the Map parameters such that `leading_dim` + is the fastest varying one, the order of the other dimensions in + a Map is unspecific. `leading_dim` should be the dimensions were + the stride is one. + + Args: + sdfg: The SDFG to process. + leading_dim: The leading dimensions. + validate: Perform validation during the steps. + validate_all: Perform extensive validation. + """ + return sdfg.apply_transformations_once_everywhere( + gtx_transformations.MapIterationOrder( + leading_dim=leading_dim, + ) + ) + + def gt_auto_optimize( sdfg: dace.SDFG, gpu: bool, @@ -254,10 +288,11 @@ def gt_auto_optimize( # This essentially ensures that the stride 1 dimensions are handled # by the inner most loop nest (CPU) or x-block (GPU) if leading_dim is not None: - sdfg.apply_transformations_once_everywhere( - gtx_transformations.MapIterationOrder( - leading_dim=leading_dim, - ) + gt_set_iteration_order( + sdfg=sdfg, + leading_dim=leading_dim, + validate=validate, + validate_all=validate_all, ) # Phase 5: Apply blocking From bcfbd68c0b0aec85ba3a71ae3c7fd4c6ad25ad15 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 09:11:59 +0200 Subject: [PATCH 198/235] Clarified some buggy behaviour inside the GPU transformation function. --- .../transformations/gpu_utils.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 350f696326..b528d25316 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -63,7 +63,9 @@ def gt_gpu_transformation( Notes: The function might modify the order of the iteration variables of some - maps and fuse other Maps. + maps. + In addition it might fuse Maps together that should not be fused. To prevent + that you should set `try_removing_trivial_maps` to `False`. Todo: - Solve the fusing problem. @@ -71,8 +73,11 @@ def gt_gpu_transformation( """ # You need guru level or above to use these arguments. - gpu_launch_factor: Optional[int] = kwargs.get("gpu_launch_factor", None) - gpu_launch_bounds: Optional[int] = kwargs.get("gpu_launch_bounds", None) + gpu_launch_factor: Optional[int] = kwargs.pop("gpu_launch_factor", None) + gpu_launch_bounds: Optional[int] = kwargs.pop("gpu_launch_bounds", None) + assert ( + len(kwargs) == 0 + ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" # Turn all global arrays (which we identify as input) into GPU memory. # This way the GPU transformation will not create this copying stuff. @@ -91,18 +96,15 @@ def gt_gpu_transformation( gtx_transformations.gt_simplify(sdfg) if try_removing_trivial_maps: - # Because of DaCe's design a Tasklet can not exist outside a Map in a GPU SDFG. - # The GPU transformation will thus add trivial maps around them, which - # translate to a kernel launch. Our current solution is to promote them and - # then fuse it. - # NOTE: The current implementation has a flaw, because promotion and fusion - # are two different steps, this is is inefficient. There are some problems - # because the mapped Tasklet might not be fusable at all. However, the real - # problem is, that Map fusion does not guarantee a certain order of Map - # variables. Currently this is not a problem because of the way it is - # implemented. + # A Tasklet, outside of a Map, that writes into an array on GPU can not work + # `sdfg.appyl_gpu_transformations()` puts Map around it (if said Tasklet + # would write into a Scalar that then goes into a GPU Map, nothing would + # happen. So we might end up with lot of these trivial Maps, that results + # in a single kernel launch. To prevent this we will try to fuse them. + # NOTE: The current implementation has a bug, because promotion and fusion + # are two different steps. Because of this the function will implicitly + # fuse everything together it can find. # TODO(phimuell): Fix the issue described above. - # TODO(phimuell): Maybe we should fuse trivial GPU maps before we do anything. sdfg.apply_transformations_once_everywhere( gtx_transformations.SerialMapPromoterGPU(), validate=False, From 03f4b1ad310c8594c0ccf263221886121decfd9b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 09:15:59 +0200 Subject: [PATCH 199/235] Inside a Map there can not be a library node for fusion. --- .../dace_fieldview/transformations/map_fusion_helper.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index bd6328c396..0ab9259292 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -477,7 +477,7 @@ def partition_first_outputs( # Certain nodes need more than one element as input. As explained # above, in this situation we assume that we can naturally decompose # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range. + # Furthermore, it can not be a dynamic map range or a library node. intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) for consumer_node, feed_edge in consumers: @@ -488,6 +488,9 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range. return None + if isinstance(consumer_node, nodes.LibraryNode): + # TODO(phimuell): Allow some library nodes. + return None # Note that "remove" has a special meaning here, regardless of the # output of the check function, from within the second map we remove @@ -520,6 +523,9 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range return None + if isinstance(consumer_node, nodes.LibraryNode): + # TODO(phimuell): Allow some library nodes. + return None else: # Ensure that there is no path that leads to the second map. after_intermdiate_node = util.all_nodes_between( From fd2366fd0e1a73d568f59570baa6791dac4b7838 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 14:01:18 +0200 Subject: [PATCH 200/235] Applied Edoardo's comments. --- .../transformations/auto_opt.py | 1 + .../transformations/k_blocking.py | 4 +- .../transformations/map_promoter.py | 10 ++-- .../transformations/map_serial_fusion.py | 56 +++++++++---------- .../dace_fieldview/transformations/util.py | 55 +++++++++--------- 5 files changed, 65 insertions(+), 61 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index f13117f50e..6c249340f9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -190,6 +190,7 @@ def gt_auto_optimize( Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - Create a custom array elimination pass that honors rule 1. + - Check if a pipeline could be used to speed up some computations. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 2227c82729..165a3acafd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -33,7 +33,7 @@ class KBlocking(transformation.SingleStateTransformation): dimension, that is commonly called "k", but identified with `block_dim`. All dimensions except `k` are unaffected by this transformation. In the outer - Map the will replace the `k` range, currently `k = 0:N`, with + Map will be replace the `k` range, currently `k = 0:N`, with `__coarse_k = 0:N:B`, where `N` is the original size of the range and `B` is the block size, passed as `blocking_size`. The transformation also handles the case if `N % B != 0`. @@ -231,7 +231,7 @@ def apply( # of the node in one go. relocated_nodes.add(edge_dst) - # In order to be useful we have to temporary store the data the + # In order to be useful we have to temporarily store the data the # independent node generates assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift if isinstance(edge_dst, nodes.AccessNode): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 9331b983ed..2f7aff7d9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -36,7 +36,7 @@ class BaseMapPromoter(transformation.SingleStateTransformation): The transformation operates on two Maps, first the "source map". This map describes the Map that should be used as template. The second one is "map to promote". After the transformation the "map to promote" will have the same - map parameter than the "source map" has. + map parameter as the "source map" has. In order to properly work, the parameters of "source map" must be a strict superset of the ones of "map to promote". Furthermore, this transformation @@ -52,6 +52,8 @@ class BaseMapPromoter(transformation.SingleStateTransformation): promote_vertical: If `True` promote vertical dimensions; `True` by default. promote_local: If `True` promote local dimensions; `True` by default. promote_horizontal: If `True` promote horizontal dimensions; `False` by default. + promote_all: Do not impose any restriction on what to promote. The only + reasonable value is `True` or `None`. Note: This ignores tiling. @@ -311,9 +313,9 @@ class SerialMapPromoter(BaseMapPromoter): def expressions(cls) -> Any: """Get the match expressions. - The function generates two different match expression. The first match - describes the case where the top map must be promoted, while the second - case is the second/lower map must be promoted. + The function generates two match expressions. The first match describes + the case where the top map must be promoted, while the second case is + the second/lower map must be promoted. """ return [ dace.sdfg.utils.node_path_graph( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index 66be18e6c0..a17bcf4bd1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -35,12 +35,12 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): Things that are improved, compared to the native DaCe implementation: - Nested Maps. - - Temporary arrays and the correct propagation of their Memelts. + - Temporary arrays and the correct propagation of their Memlets. - Top Maps that have multiple outputs. Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewriting the connections - appropriate. + and the entry of the lower or second map and then rewrites the connections + appropriately. This transformation assumes that an SDFG obeys the structure that is outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that @@ -70,9 +70,9 @@ def expressions(cls) -> Any: The transformation matches the exit node of the top Map that is connected to an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on these nodes, - but more or less anything that has an outgoing connection of the first Map, - and is connected to the second map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. """ return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] @@ -86,7 +86,7 @@ def can_be_applied( """Tests if the matched Maps can be merged. The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constrains. + - The `can_be_fused()` of the base succeed, which checks some basic constraints. - The decomposition exists and at least one of the intermediate sets is not empty. """ @@ -102,7 +102,8 @@ def can_be_applied( return False # Two maps can be serially fused if the node decomposition exists and - # there at least one of the intermediate output sets is not empty. + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, @@ -111,7 +112,8 @@ def can_be_applied( ) if output_partition is None: return False - if not (output_partition[1] or output_partition[2]): + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): return False return True @@ -216,12 +218,12 @@ def handle_intermediate_set( state: The state in which the map is processed. sdfg: The SDFG that should be optimized. map_exit_1: The exit of the first/top map. - map_entry_1: The entry of the second map. + map_entry_2: The entry of the second map. map_exit_2: The exit of the second map. is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. Notes: - Before the transformation the `state` does not be to be valid and + Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. Todo: @@ -303,7 +305,7 @@ def handle_intermediate_set( ) new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) - # New we will reroute the output Memlet, thus it will no longer going + # New we will reroute the output Memlet, thus it will no longer pass # through the Map exit but through the newly created intermediate. # we will delete the previous edge later. pre_exit_memlet: dace.Memlet = pre_exit_edge.data @@ -314,11 +316,11 @@ def handle_intermediate_set( assert pre_exit_memlet.data == inter_name new_pre_exit_memlet.data = new_inter_name - # Now we have to fix the subset of the Memlet. - # Before the subset of the Memlet dependent on the Map variables, + # Now we have to modify the subset of the Memlet. + # Before the subset of the Memlet was dependent on the Map variables, # however, this is no longer the case, as we removed them. This change # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the bellow is correct. + # NOTE: Assert above ensures that the below is correct. new_pre_exit_memlet.replace(memlet_repl) if is_scalar: new_pre_exit_memlet.subset = "0" @@ -350,14 +352,14 @@ def handle_intermediate_set( producer_edge.data.replace(memlet_repl) if is_scalar: producer_edge.data.dst_subset = "0" - else: - if producer_edge.data.dst_subset is not None: - producer_edge.data.dst_subset.pop(squeezed_dims) + elif producer_edge.data.dst_subset is not None: + producer_edge.data.dst_subset.pop(squeezed_dims) # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the temporary - # in the second map. We do this by finding the input connectors on the - # map entry, such that we know where we have to reroute inside the Map. + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. # NOTE: Assumes that map (if connected is the direct neighbour). conn_names: set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): @@ -386,7 +388,7 @@ def handle_intermediate_set( # Memlet and the correctness of the code below. new_inner_memlet = copy.deepcopy(inner_edge.data) new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not chenge the direction. + new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. # Now remove the old edge, that started the second map entry. # Also add the new edge that started at the new intermediate. @@ -402,9 +404,8 @@ def handle_intermediate_set( # Now we do subset modification to ensure that nothing failed. if is_scalar: new_inner_memlet.src_subset = "0" - else: - if new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.pop(squeezed_dims) + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.pop(squeezed_dims) # Now clean the Memlets of that tree to use the new intermediate node. for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): @@ -413,9 +414,8 @@ def handle_intermediate_set( consumer_edge.data.data = new_inter_name if is_scalar: consumer_edge.data.src_subset = "0" - else: - if consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) + elif consumer_edge.data.subset is not None: + consumer_edge.data.subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. # We will now delete the edges that brought the data. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index 9e4f09c722..897aaeecab 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -14,17 +14,19 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Iterable +from typing import Iterable, Union import dace -from dace.sdfg import graph as dace_graph, nodes +from dace.sdfg import graph as dace_graph, nodes as dace_nodes -def is_nested_sdfg(sdfg: dace.SDFG) -> bool: - """Tests if `sdfg` is a neseted sdfg.""" +def is_nested_sdfg( + sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], +) -> bool: + """Tests if `sdfg` is a NestedSDFG.""" if isinstance(sdfg, dace.SDFGState): sdfg = sdfg.parent - if isinstance(sdfg, dace.nodes.NestedSDFG): + if isinstance(sdfg, dace_nodes.NestedSDFG): return True elif isinstance(sdfg, dace.SDFG): if sdfg.parent_nsdfg_node is not None: @@ -36,17 +38,17 @@ def is_nested_sdfg(sdfg: dace.SDFG) -> bool: def all_nodes_between( graph: dace.SDFG | dace.SDFGState, - begin: nodes.Node, - end: nodes.Node, + begin: dace_nodes.Node, + end: dace_nodes.Node, reverse: bool = False, -) -> set[nodes.Node] | None: +) -> set[dace_nodes.Node] | None: """Find all nodes that are reachable from `begin` but bound by `end`. - Essentially the function starts a DFS at `begin`, which is never part of the - returned set, if at a node an edge is found that lead to `end`, the function - will ignore this edge. However, it will find every node that is reachable - from `begin` that is reachable by a path that does not visit `end`. - In case `end` is never found the function will return `None`. + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end`, this edge is ignored. It will thus found any node that is reachable + from `begin` by a path that does not involve `end`. The returned set will + never contain `end` nor `begin`. In case `end` is never found the function + will return `None`. If `reverse` is set to `True` the function will start exploring at `end` and follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. @@ -58,12 +60,11 @@ def all_nodes_between( reverse: Perform a backward DFS. Notes: - - The returned set will never contain the node `begin`. - The returned set will also contain the nodes of path that starts at `begin` and ends at a node that is not `end`. """ - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: if reverse: return (edge.src for edge in graph.in_edges(node)) return (edge.dst for edge in graph.out_edges(node)) @@ -71,12 +72,12 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: if reverse: begin, end = end, begin - to_visit: list[nodes.Node] = [begin] - seen: set[nodes.Node] = set() + to_visit: list[dace_nodes.Node] = [begin] + seen: set[dace_nodes.Node] = set() found_end: bool = False while len(to_visit) > 0: - n: nodes.Node = to_visit.pop() + n: dace_nodes.Node = to_visit.pop() if n == end: found_end = True continue @@ -94,10 +95,10 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: def find_downstream_consumers( state: dace.SDFGState, - begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], only_tasklets: bool = False, reverse: bool = False, -) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Find all downstream connectors of `begin`. A consumer, in for this function, is any node that is neither an entry nor @@ -123,17 +124,17 @@ def find_downstream_consumers( else: to_visit = list(state.out_edges(begin)) seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() while len(to_visit) != 0: curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst + next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst if curr_edge in seen: continue seen.add(curr_edge) - if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): + if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): if reverse: target_conn = curr_edge.src_conn[4:] new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) @@ -141,7 +142,7 @@ def find_downstream_consumers( # In forward mode a Map entry could also mean the definition of a # dynamic map range. if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( - next_node, nodes.MapEntry + next_node, dace_nodes.MapEntry ): # This edge defines a dynamic map range, which is a consumer if not only_tasklets: @@ -152,7 +153,7 @@ def find_downstream_consumers( to_visit.extend(new_edges) del new_edges else: - if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): + if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): continue found.add((next_node, curr_edge)) @@ -161,9 +162,9 @@ def find_downstream_consumers( def find_upstream_producers( state: dace.SDFGState, - begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], only_tasklets: bool = False, -) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" return find_downstream_consumers( state=state, From 5ed2a8f55c0ec438b84b5ad63e1ec0d2f7a19633 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 14:06:05 +0200 Subject: [PATCH 201/235] Applied another change. --- .../runners/dace_fieldview/transformations/k_blocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 165a3acafd..94f49f9b40 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -93,7 +93,7 @@ def can_be_applied( - Toplevel map. - The map shall not be serial. - The block dimension must be present (exact match). - - The map range must have stride one. + - The map range must have step size of 1. - The partition must exists (see `partition_map_output()`). """ if self.block_dim is None: From 8e97cd67b9c910a2547fd8b1c03a379706cf6c25 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 2 Aug 2024 12:14:36 +0200 Subject: [PATCH 202/235] Removed stray symlink. --- .../runners_tests/dace/transformation_tests/dace_fieldview | 1 - 1 file changed, 1 deletion(-) delete mode 120000 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview deleted file mode 120000 index ef394e0d58..0000000000 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/dace_fieldview +++ /dev/null @@ -1 +0,0 @@ -/home/quint_essent/git/1_CSCS/gt4py/src/gt4py/next/program_processors/runners/dace_fieldview \ No newline at end of file From 46c549bcd0ece18b44ad3f32e45303ee0e076114 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Aug 2024 14:38:35 +0200 Subject: [PATCH 203/235] Updated the licence header. --- .../dace_fieldview/transformations/__init__.py | 12 +++--------- .../dace_fieldview/transformations/auto_opt.py | 12 +++--------- .../dace_fieldview/transformations/gpu_utils.py | 12 +++--------- .../dace_fieldview/transformations/k_blocking.py | 12 +++--------- .../transformations/map_fusion_helper.py | 12 +++--------- .../dace_fieldview/transformations/map_orderer.py | 12 +++--------- .../dace_fieldview/transformations/map_promoter.py | 12 +++--------- .../transformations/map_serial_fusion.py | 12 +++--------- .../runners/dace_fieldview/transformations/util.py | 12 +++--------- .../dace/transformation_tests/__init__.py | 12 +++--------- .../dace/transformation_tests/conftest.py | 12 +++--------- .../dace/transformation_tests/test_k_blocking.py | 12 +++--------- .../dace/transformation_tests/test_map_fusion.py | 12 +++--------- .../transformation_tests/test_serial_map_promoter.py | 12 +++--------- .../runners_tests/dace/transformation_tests/util.py | 12 +++--------- 15 files changed, 45 insertions(+), 135 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index c0f30f535f..4b80710c9b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause """Transformation and optimization pipeline for the DaCe backend in GT4Py. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 6c249340f9..50d003a93f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause """Fast access to the auto optimization on DaCe.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index b528d25316..1fecb2d342 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause """Functions for turning an SDFG into a GPU SDFG.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 94f49f9b40..a8b7ab7487 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause import copy import functools diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 0ab9259292..220a6edc1e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause """Implements Helper functionaliyies for map fusion""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index d42ff06edc..31fd32031e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Any, Optional, Sequence, Union diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 2f7aff7d9a..394aed2624 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Any, Mapping, Optional, Sequence, Union diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index a17bcf4bd1..f51f9b2ee4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause """Implements the seriall map fusing transformation.""" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index 897aaeecab..f40749de23 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause """Common functionality for the transformations/optimization pipeline.""" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py index 67bee9d721..abf4c3e24c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py @@ -1,14 +1,8 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py index 9beab4df46..72e76a63e2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Any, Optional, Sequence, Union, overload, Literal, Generator diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py index e576cea76f..91d76ebd39 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Callable import dace diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py index e38032dd11..8d8a108765 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Any, Optional, Sequence, Union, Literal, overload diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py index 886d4888b6..5c9f555582 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Callable import dace diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py index bcc730091a..739582d5d9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py @@ -1,16 +1,10 @@ # GT4Py - GridTools Framework # -# Copyright (c) 2014-2023, ETH Zurich +# Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause from typing import Union, Literal, overload From 27d8ea62b7b08b7ee0d6d30e0f57cce931e7d8dc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Aug 2024 14:55:52 +0200 Subject: [PATCH 204/235] This should make the names a bit more consistent. --- .../transformations/gpu_utils.py | 34 +++--- .../transformations/k_blocking.py | 80 +++++++------ .../transformations/map_fusion_helper.py | 113 +++++++++++------- .../transformations/map_orderer.py | 20 ++-- .../transformations/map_promoter.py | 66 +++++----- .../transformations/map_serial_fusion.py | 60 ++++++---- 6 files changed, 208 insertions(+), 165 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 1fecb2d342..afd12a072c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -12,8 +12,8 @@ from typing import Any, Optional, Sequence, Union import dace -from dace import properties, transformation -from dace.sdfg import SDFG, SDFGState, nodes +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, @@ -186,8 +186,8 @@ def _gpu_block_getter( return tuple(self._block_size) -@properties.make_properties -class GPUSetBlockSize(transformation.SingleStateTransformation): +@dace_properties.make_properties +class GPUSetBlockSize(dace_transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. It is also possible to set the launch bound. @@ -202,7 +202,7 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. """ - block_size = properties.Property( + block_size = dace_properties.Property( dtype=None, allow_none=False, default=(32, 1, 1), @@ -211,14 +211,14 @@ class GPUSetBlockSize(transformation.SingleStateTransformation): desc="Size of the block size a GPU Map should have.", ) - launch_bounds = properties.Property( + launch_bounds = dace_properties.Property( dtype=str, allow_none=True, default=None, desc="Set the launch bound property of the map.", ) - map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, @@ -283,8 +283,8 @@ def apply( self.map_entry.map.gpu_launch_bounds = self.launch_bounds -@properties.make_properties -class SerialMapPromoterGPU(transformation.SingleStateTransformation): +@dace_properties.make_properties +class SerialMapPromoterGPU(dace_transformation.SingleStateTransformation): """Serial Map promoter for empty Maps in case of trivial Maps. In CPU mode a Tasklet can be outside of a map, however, this is not @@ -303,9 +303,9 @@ class SerialMapPromoterGPU(transformation.SingleStateTransformation): """ # Pattern Matching - map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) - access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) + map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: @@ -326,9 +326,9 @@ def can_be_applied( """ from .map_serial_fusion import SerialMapFusion - map_exit_1: nodes.MapExit = self.map_exit1 - map_1: nodes.Map = map_exit_1.map - map_entry_2: nodes.MapEntry = self.map_entry2 + map_exit_1: dace_nodes.MapExit = self.map_exit1 + map_1: dace_nodes.Map = map_exit_1.map + map_entry_2: dace_nodes.MapEntry = self.map_entry2 # Check if the first map is trivial. if len(map_1.params) != 1: @@ -363,8 +363,8 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: The function essentially copies the parameters and the ranges from the bottom map to the top one. """ - map_1: nodes.Map = self.map_exit1.map - map_2: nodes.Map = self.map_entry2.map + map_1: dace_nodes.Map = self.map_exit1.map + map_2: dace_nodes.Map = self.map_entry2.map map_1.params = copy.deepcopy(map_2.params) map_1.range = copy.deepcopy(map_2.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index a8b7ab7487..1e8ded1c1b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -11,16 +11,20 @@ from typing import Any, Optional, Union import dace -from dace import properties, subsets, transformation -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes -from dace.transformation import helpers +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes +from dace.transformation import helpers as dace_helpers from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util +from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util -@properties.make_properties -class KBlocking(transformation.SingleStateTransformation): +@dace_properties.make_properties +class KBlocking(dace_transformation.SingleStateTransformation): """Applies k-Blocking with separation on a Map. This transformation takes a multidimensional Map and performs blocking on a @@ -42,18 +46,18 @@ class KBlocking(transformation.SingleStateTransformation): `_blocked` to it. """ - blocking_size = properties.Property( + blocking_size = dace_properties.Property( dtype=int, allow_none=True, desc="Size of the inner k Block.", ) - block_dim = properties.Property( + block_dim = dace_properties.Property( dtype=str, allow_none=True, desc="Which dimension should be blocked (must be an exact match).", ) - map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, @@ -64,7 +68,7 @@ def __init__( if isinstance(block_dim, str): pass elif isinstance(block_dim, gtx_common.Dimension): - block_dim = dace_fieldview_util.get_map_variable(block_dim) + block_dim = gtx_dace_fieldview_util.get_map_variable(block_dim) if block_dim is not None: self.block_dim = block_dim if blocking_size is not None: @@ -95,9 +99,9 @@ def can_be_applied( elif self.blocking_size is None: raise ValueError("The blocking size was not specified.") - map_entry: nodes.MapEntry = self.map_entry + map_entry: dace_nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params - map_range: subsets.Range = map_entry.map.range + map_range: dace_subsets.Range = map_entry.map.range block_var: str = self.block_dim scope = graph.scope_dict() @@ -123,10 +127,10 @@ def apply( Performs the operation described in the doc string. """ - outer_entry: nodes.MapEntry = self.map_entry - outer_exit: nodes.MapExit = graph.exit_node(outer_entry) - outer_map: nodes.Map = outer_entry.map - map_range: subsets.Range = outer_entry.map.range + outer_entry: dace_nodes.MapEntry = self.map_entry + outer_exit: dace_nodes.MapExit = graph.exit_node(outer_entry) + outer_map: dace_nodes.Map = outer_entry.map + map_range: dace_subsets.Range = outer_entry.map.range map_params: list[str] = outer_entry.map.params # This is the name of the iterator we coarsen @@ -147,7 +151,7 @@ def apply( rng_stop = map_range[block_idx][1] inner_label = f"inner_{outer_map.label}" inner_range = { - block_var: subsets.Range.from_string( + block_var: dace_subsets.Range.from_string( f"({coarse_block_var} * {self.blocking_size} + {rng_start}):min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" ) } @@ -160,7 +164,7 @@ def apply( # TODO(phimuell): Investigate if we want to prevent unrolling here # Now we modify the properties of the outer map. - coarse_block_range = subsets.Range.from_string( + coarse_block_range = dace_subsets.Range.from_string( f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" ).ranges[0] outer_map.params[block_idx] = coarse_block_var @@ -168,12 +172,12 @@ def apply( outer_map.label = f"{outer_map.label}_blocked" # Contains the independent nodes that are already relocated. - relocated_nodes: set[nodes.Node] = set() + relocated_nodes: set[dace_nodes.Node] = set() # Now we iterate over all the output edges of the outer map and rewire them. # Note that this only handles the entry of the Map. for out_edge in list(graph.out_edges(outer_entry)): - edge_dst: nodes.Node = out_edge.dst + edge_dst: dace_nodes.Node = out_edge.dst if edge_dst in dependent_nodes: # This is the simple case as we just have to rewire the edge @@ -183,7 +187,7 @@ def apply( # Must be before the handling of the modification below # Note that this will remove the original edge from the SDFG. - helpers.redirect_edge( + dace_helpers.redirect_edge( state=graph, edge=out_edge, new_src=inner_entry, @@ -228,22 +232,22 @@ def apply( # In order to be useful we have to temporarily store the data the # independent node generates assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift - if isinstance(edge_dst, nodes.AccessNode): + if isinstance(edge_dst, dace_nodes.AccessNode): # The independent node is an access node, so we can use it directly. - caching_node: nodes.AccessNode = edge_dst + caching_node: dace_nodes.AccessNode = edge_dst else: # The dependent node is not an access node. For now we will # just use the next node, with some restriction. # TODO(phimuell): create an access node in this case instead. caching_node = next(iter(graph.out_edges(edge_dst))).dst assert graph.in_degree(caching_node) == 1 - assert isinstance(caching_node, nodes.AccessNode) + assert isinstance(caching_node, dace_nodes.AccessNode) # Now rewire the Memlets that leave the caching node to go through # new inner Map. for consumer_edge in list(graph.out_edges(caching_node)): new_map_conn = inner_entry.next_connector() - helpers.redirect_edge( + dace_helpers.redirect_edge( state=graph, edge=consumer_edge, new_dst=inner_entry, @@ -265,7 +269,7 @@ def apply( # but we do not use them for now. for out_edge in list(graph.in_edges(outer_exit)): edge_conn = out_edge.dst_conn[3:] - helpers.redirect_edge( + dace_helpers.redirect_edge( state=graph, edge=out_edge, new_dst=inner_exit, @@ -286,11 +290,11 @@ def apply( def partition_map_output( self, - map_entry: nodes.MapEntry, + map_entry: dace_nodes.MapEntry, block_param: str, state: SDFGState, sdfg: SDFG, - ) -> tuple[set[nodes.Node], set[nodes.Node]] | None: + ) -> tuple[set[dace_nodes.Node], set[dace_nodes.Node]] | None: """Partition the outputs of the Map. The partition will only look at the direct intermediate outputs of the @@ -316,22 +320,22 @@ def partition_map_output( `used_symbol` properties of a Tasklet. - Furthermore only the first level is inspected. """ - block_independent: set[nodes.Node] = set() # `\mathcal{I}` - block_dependent: set[nodes.Node] = set() # `\mathcal{D}` + block_independent: set[dace_nodes.Node] = set() # `\mathcal{I}` + block_dependent: set[dace_nodes.Node] = set() # `\mathcal{D}` # Find all nodes that are adjacent to the map entry. - nodes_to_partition: set[nodes.Node] = {edge.dst for edge in state.out_edges(map_entry)} + nodes_to_partition: set[dace_nodes.Node] = {edge.dst for edge in state.out_edges(map_entry)} # Now we examine every node and assign them to one of the sets. # Note that this is only tentative and we will later inspect the # outputs of the independent node and reevaluate their classification. for node in nodes_to_partition: # Filter out all nodes that we can not (yet) handle. - if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)): + if not isinstance(node, (dace_nodes.Tasklet, dace_nodes.AccessNode)): return None # Check if we have a strange Tasklet or if it uses the symbol inside it. - if isinstance(node, nodes.Tasklet): + if isinstance(node, dace_nodes.Tasklet): if node.side_effects: return None if block_param in node.free_symbols: @@ -376,8 +380,8 @@ def partition_map_output( # clause then we classify the node as independent. for edge in edges: memlet: dace.Memlet = edge.data - src_subset: subsets.Subset | None = memlet.src_subset - dst_subset: subsets.Subset | None = memlet.dst_subset + src_subset: dace_subsets.Subset | None = memlet.src_subset + dst_subset: dace_subsets.Subset | None = memlet.dst_subset edge_desc: dace.data.Data = sdfg.arrays[memlet.data] edge_desc_size = functools.reduce(lambda a, b: a * b, edge_desc.shape) @@ -389,7 +393,7 @@ def partition_map_output( continue # Now we have to look at the source and destination set of the Memlet. - subsets_to_inspect: list[subsets.Subset] = [] + subsets_to_inspect: list[dace_subsets.Subset] = [] if dst_subset is not None: subsets_to_inspect.append(dst_subset) if src_subset is not None: @@ -412,14 +416,14 @@ def partition_map_output( # We now make a last screening of the independent nodes. # TODO(phimuell): Make an iterative process to find the maximal set. for independent_node in list(block_independent): - if isinstance(independent_node, nodes.AccessNode): + if isinstance(independent_node, dace_nodes.AccessNode): if state.in_degree(independent_node) != 1: block_independent.discard(independent_node) block_dependent.add(independent_node) continue for out_edge in state.out_edges(independent_node): if ( - (not isinstance(out_edge.dst, nodes.AccessNode)) + (not isinstance(out_edge.dst, dace_nodes.AccessNode)) or (state.in_degree(out_edge.dst) != 1) or (out_edge.dst.desc(sdfg).lifetime != dace.dtypes.AllocationLifetime.Scope) ): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 220a6edc1e..e8433f5cea 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -13,15 +13,26 @@ from typing import Any, Optional, Sequence, Union import dace -from dace import data, properties, subsets, transformation -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes -from dace.transformation import helpers +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import ( + SDFG, + SDFGState, + graph as dace_graph, + nodes as dace_nodes, + validation as dace_validation, +) +from dace.transformation import helpers as dace_helpers from gt4py.next.program_processors.runners.dace_fieldview.transformations import util -@properties.make_properties -class MapFusionHelper(transformation.SingleStateTransformation): +@dace_properties.make_properties +class MapFusionHelper(dace_transformation.SingleStateTransformation): """Contains common part of the fusion for parallel and serial Map fusion. The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). @@ -36,19 +47,19 @@ class MapFusionHelper(transformation.SingleStateTransformation): only_toplevel_maps: Only consider Maps that are at the top. """ - only_toplevel_maps = properties.Property( + only_toplevel_maps = dace_properties.Property( dtype=bool, default=False, allow_none=False, desc="Only perform fusing if the Maps are in the top level.", ) - only_inner_maps = properties.Property( + only_inner_maps = dace_properties.Property( dtype=bool, default=False, allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - shared_transients = properties.DictProperty( + shared_transients = dace_properties.DictProperty( key_type=SDFG, value_type=set[str], default=None, @@ -76,8 +87,8 @@ def expressions(cls) -> bool: def can_be_fused( self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, @@ -131,8 +142,8 @@ def can_be_fused( @staticmethod def relocate_nodes( - from_node: Union[nodes.MapExit, nodes.MapEntry], - to_node: Union[nodes.MapExit, nodes.MapEntry], + from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], + to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], state: SDFGState, sdfg: SDFG, ) -> None: @@ -141,6 +152,8 @@ def relocate_nodes( This function will only rewire the edges, it does not remove the nodes themselves. Furthermore, this function should be called twice per Map, once for the entry and then for the exit. + While it does not remove the node themselves if guarantees that the + `from_node` has degree zero. Args: from_node: Node from which the edges should be removed. @@ -151,13 +164,13 @@ def relocate_nodes( # Now we relocate empty Memlets, from the `from_node` to the `to_node` for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): - helpers.redirect_edge(state, empty_edge, new_src=to_node) + dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): - helpers.redirect_edge(state, empty_edge, new_dst=to_node) + dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) # We now ensure that there is only one empty Memlet from the `to_node` to any other node. # Although it is allowed, we try to prevent it. - empty_targets: set[nodes.Node] = set() + empty_targets: set[dace_nodes.Node] = set() for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): if empty_edge.dst in empty_targets: state.remove_edge(empty_edge) @@ -185,7 +198,7 @@ def relocate_nodes( raise RuntimeError( # Might fail because of out connectors. f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." ) - helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) # There is no other edge that we have to consider, so we just end here @@ -197,22 +210,35 @@ def relocate_nodes( to_node.add_in_connector("IN_" + new_conn) for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): - helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) to_node.add_out_connector("OUT_" + new_conn) for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + dace_helpers.redirect_edge( + state, e, new_src=to_node, new_src_conn="OUT_" + new_conn + ) from_node.remove_in_connector("IN_" + old_conn) from_node.remove_out_connector("OUT_" + old_conn) - assert state.in_degree(from_node) == 0 + # Check if we succeeded. + if state.out_degree(from_node) != 0: + raise dace_validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise dace_validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) assert len(from_node.in_connectors) == 0 - assert state.out_degree(from_node) == 0 assert len(from_node.out_connectors) == 0 @staticmethod def map_parameter_compatible( - map_1: nodes.Map, - map_2: nodes.Map, + map_1: dace_nodes.Map, + map_2: dace_nodes.Map, state: Union[SDFGState, SDFG], sdfg: SDFG, ) -> bool: @@ -223,9 +249,9 @@ def map_parameter_compatible( is performed. - The ranges must be the same. """ - range_1: subsets.Range = map_1.range + range_1: dace_subsets.Range = map_1.range params_1: Sequence[str] = map_1.params - range_2: subsets.Range = map_2.range + range_2: dace_subsets.Range = map_2.range params_2: Sequence[str] = map_2.params # The maps are only fuseable if we have an exact match in the parameter names @@ -250,7 +276,7 @@ def map_parameter_compatible( def is_interstate_transient( self, - transient: Union[str, nodes.AccessNode], + transient: Union[str, dace_nodes.AccessNode], sdfg: dace.SDFG, state: dace.SDFGState, ) -> bool: @@ -276,10 +302,10 @@ def is_interstate_transient( """ # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access nodes. + # the set of such transients is partially given by all source access dace_nodes. # Because of rule 3 we also include all scalars in this set, as an over # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink nodes. + # we also include the sink dace_nodes. # See if we have already computed the set if sdfg in self.shared_transients: @@ -295,7 +321,8 @@ def is_interstate_transient( for node in itertools.chain( state_to_scan.source_nodes(), state_to_scan.sink_nodes() ) - if isinstance(node, nodes.AccessNode) and sdfg.arrays[node.data].transient + if isinstance(node, dace_nodes.AccessNode) + and sdfg.arrays[node.data].transient ] ) self.shared_transients[sdfg] = shared_sdfg_transients @@ -307,13 +334,13 @@ def is_interstate_transient( assert len(matching_access_nodes) == 1 transient = matching_access_nodes[0] else: - assert isinstance(transient, nodes.AccessNode) + assert isinstance(transient, dace_nodes.AccessNode) name = transient.data - desc: data.Data = sdfg.arrays[name] + desc: dace_data.Data = sdfg.arrays[name] if not desc.transient: return True - if isinstance(desc, data.Scalar): + if isinstance(desc, dace_data.Scalar): return True # Scalars can not be removed by fusion anyway. # Rule 8: If degree larger than one then it is used within the state. @@ -327,8 +354,8 @@ def partition_first_outputs( self, state: SDFGState, sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, + map_exit_1: dace_nodes.MapExit, + map_entry_2: dace_nodes.MapEntry, ) -> Union[ tuple[ set[dace_graph.MultiConnectorEdge[dace.Memlet]], @@ -373,11 +400,11 @@ def partition_first_outputs( shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() # Set of intermediate nodes that we have already processed. - processed_inter_nodes: set[nodes.Node] = set() + processed_inter_nodes: set[dace_nodes.Node] = set() # Now scan all output edges of the first exit and classify them for out_edge in state.out_edges(map_exit_1): - intermediate_node: nodes.Node = out_edge.dst + intermediate_node: dace_nodes.Node = out_edge.dst # We already processed the node, this should indicate that we should # run simplify again, or we should start implementing this case. @@ -431,13 +458,13 @@ def partition_first_outputs( # everything else we do not know how to handle. It is important that # we do not test for non transient data here, because they can be # handled has shared intermediates. - if not isinstance(intermediate_node, nodes.AccessNode): + if not isinstance(intermediate_node, dace_nodes.AccessNode): return None - intermediate_desc: data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, data.View): + intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) + if isinstance(intermediate_desc, dace_data.View): return None - # There are some restrictions we have on intermediate nodes. The first one + # There are some restrictions we have on intermediate dace_nodes. The first one # is that we do not allow WCR, this is because they need special handling # which is currently not implement (the DaCe transformation has this # restriction as well). The second one is that we can reduce the @@ -482,8 +509,8 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range. return None - if isinstance(consumer_node, nodes.LibraryNode): - # TODO(phimuell): Allow some library nodes. + if isinstance(consumer_node, dace_nodes.LibraryNode): + # TODO(phimuell): Allow some library dace_nodes. return None # Note that "remove" has a special meaning here, regardless of the @@ -517,8 +544,8 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range return None - if isinstance(consumer_node, nodes.LibraryNode): - # TODO(phimuell): Allow some library nodes. + if isinstance(consumer_node, dace_nodes.LibraryNode): + # TODO(phimuell): Allow some library dace_nodes. return None else: # Ensure that there is no path that leads to the second map. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index 31fd32031e..f7d447fdc6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -9,15 +9,15 @@ from typing import Any, Optional, Sequence, Union import dace -from dace import properties, transformation -from dace.sdfg import SDFG, SDFGState, nodes +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import utility as dace_fieldview_util +from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util -@properties.make_properties -class MapIterationOrder(transformation.SingleStateTransformation): +@dace_properties.make_properties +class MapIterationOrder(dace_transformation.SingleStateTransformation): """Modify the order of the iteration variables. The iteration order, while irrelevant from an SDFG point of view, is highly @@ -42,13 +42,13 @@ class MapIterationOrder(transformation.SingleStateTransformation): - Maybe also process the parameters to bring them in a canonical order. """ - leading_dim = properties.Property( + leading_dim = dace_properties.Property( dtype=str, allow_none=True, desc="Dimension that should become the leading dimension.", ) - map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, @@ -58,7 +58,7 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) if isinstance(leading_dim, gtx_common.Dimension): - self.leading_dim = dace_fieldview_util.get_map_variable(leading_dim) + self.leading_dim = gtx_dace_fieldview_util.get_map_variable(leading_dim) elif leading_dim is not None: self.leading_dim = leading_dim @@ -81,7 +81,7 @@ def can_be_applied( if self.leading_dim is None: return False - map_entry: nodes.MapEntry = self.map_entry + map_entry: dace_nodes.MapEntry = self.map_entry map_params: Sequence[str] = map_entry.map.params map_var: str = self.leading_dim @@ -102,7 +102,7 @@ def apply( `self.leading_dim` the last map variable (this is given by the structure of DaCe's code generator). """ - map_entry: nodes.MapEntry = self.map_entry + map_entry: dace_nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params map_var: str = self.leading_dim diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 394aed2624..2f0d4ee261 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -9,8 +9,12 @@ from typing import Any, Mapping, Optional, Sequence, Union import dace -from dace import properties, subsets, transformation -from dace.sdfg import SDFG, SDFGState, nodes +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes __all__ = [ @@ -18,8 +22,8 @@ ] -@properties.make_properties -class BaseMapPromoter(transformation.SingleStateTransformation): +@dace_properties.make_properties +class BaseMapPromoter(dace_transformation.SingleStateTransformation): """Base transformation to add certain missing dimension to a map. By adding certain dimension to a Map, it might became possible to use the Map @@ -53,34 +57,34 @@ class BaseMapPromoter(transformation.SingleStateTransformation): This ignores tiling. """ - only_toplevel_maps = properties.Property( + only_toplevel_maps = dace_properties.Property( dtype=bool, default=False, allow_none=False, desc="Only perform fusing if the Maps are on the top level.", ) - only_inner_maps = properties.Property( + only_inner_maps = dace_properties.Property( dtype=bool, default=False, allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - promote_vertical = properties.Property( + promote_vertical = dace_properties.Property( dtype=bool, default=True, desc="If `True` promote vertical dimensions.", ) - promote_local = properties.Property( + promote_local = dace_properties.Property( dtype=bool, default=True, desc="If `True` promote local dimensions.", ) - promote_horizontal = properties.Property( + promote_horizontal = dace_properties.Property( dtype=bool, default=False, desc="If `True` promote horizontal dimensions.", ) - promote_all = properties.Property( + promote_all = dace_properties.Property( dtype=bool, default=False, desc="If `True` perform any promotion. Takes precedence over all other selectors.", @@ -90,7 +94,7 @@ def map_to_promote( self, state: dace.SDFGState, sdfg: dace.SDFG, - ) -> nodes.MapEntry: + ) -> dace_nodes.MapEntry: """Returns the map entry that should be promoted.""" raise NotImplementedError(f"{type(self).__name__} must implement 'map_to_promote'.") @@ -98,7 +102,7 @@ def source_map( self, state: dace.SDFGState, sdfg: dace.SDFG, - ) -> nodes.MapEntry: + ) -> dace_nodes.MapEntry: """Returns the map entry that is used as source/template.""" raise NotImplementedError(f"{type(self).__name__} must implement 'source_map'.") @@ -161,10 +165,10 @@ def can_be_applied( - If the parameter of the second map are compatible with each other. - If a dimension would be promoted that should not. """ - map_to_promote_entry: nodes.MapEntry = self.map_to_promote(state=graph, sdfg=sdfg) - map_to_promote: nodes.Map = map_to_promote_entry.map - source_map_entry: nodes.MapEntry = self.source_map(state=graph, sdfg=sdfg) - source_map: nodes.Map = source_map_entry.map + map_to_promote_entry: dace_nodes.MapEntry = self.map_to_promote(state=graph, sdfg=sdfg) + map_to_promote: dace_nodes.Map = map_to_promote_entry.map + source_map_entry: dace_nodes.MapEntry = self.source_map(state=graph, sdfg=sdfg) + source_map: dace_nodes.Map = source_map_entry.map # Test the scope of the promotee. # Because of the nature of the transformation, it is not needed that the @@ -172,7 +176,7 @@ def can_be_applied( # to ensure that the symbols are the same and all. But this is guaranteed by # the nature of this transformation (single state). if self.only_inner_maps or self.only_toplevel_maps: - scopeDict: Mapping[nodes.Node, Union[nodes.Node, None]] = graph.scope_dict() + scopeDict: Mapping[dace_nodes.Node, Union[dace_nodes.Node, None]] = graph.scope_dict() if self.only_inner_maps and (scopeDict[map_to_promote_entry] is None): return False if self.only_toplevel_maps and (scopeDict[map_to_promote_entry] is not None): @@ -216,10 +220,10 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: from the source map. The order of the parameters the Map has after the promotion is unspecific. """ - map_to_promote: nodes.Map = self.map_to_promote(state=graph, sdfg=sdfg).map - source_map: nodes.Map = self.source_map(state=graph, sdfg=sdfg).map + map_to_promote: dace_nodes.Map = self.map_to_promote(state=graph, sdfg=sdfg).map + source_map: dace_nodes.Map = self.source_map(state=graph, sdfg=sdfg).map source_params: Sequence[str] = source_map.params - source_ranges: subsets.Range = source_map.range + source_ranges: dace_subsets.Range = source_map.range missing_params: Sequence[str] = self.missing_map_params( # type: ignore[assignment] # Will never be `None` map_to_promote=map_to_promote, @@ -240,13 +244,13 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: # Now update the map properties # This action will also remove the tiles - map_to_promote.range = subsets.Range(promoted_ranges) + map_to_promote.range = dace_subsets.Range(promoted_ranges) map_to_promote.params = promoted_params def missing_map_params( self, - map_to_promote: nodes.Map, - source_map: nodes.Map, + map_to_promote: dace_nodes.Map, + source_map: dace_nodes.Map, be_strict: bool = True, ) -> list[str] | None: """Returns the parameter that are missing in the map that should be promoted. @@ -270,8 +274,8 @@ def missing_map_params( if be_strict: # Check if the parameters that are already in the map to promote have # the same range as in the source map. - source_ranges: subsets.Range = source_map.range - curr_ranges: subsets.Range = map_to_promote.range + source_ranges: dace_subsets.Range = source_map.range + curr_ranges: dace_subsets.Range = map_to_promote.range curr_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(map_to_promote.params)} source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_map.params)} for param_to_check in curr_params_set: @@ -282,7 +286,7 @@ def missing_map_params( return list(source_params_set - curr_params_set) -@properties.make_properties +@dace_properties.make_properties class SerialMapPromoter(BaseMapPromoter): """Promote a map such that it can be fused serially. @@ -299,9 +303,9 @@ class SerialMapPromoter(BaseMapPromoter): """ # Pattern Matching - exit_first_map = transformation.transformation.PatternNode(nodes.MapExit) - access_node = transformation.transformation.PatternNode(nodes.AccessNode) - entry_second_map = transformation.transformation.PatternNode(nodes.MapEntry) + exit_first_map = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + entry_second_map = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: @@ -351,7 +355,7 @@ def map_to_promote( self, state: dace.SDFGState, sdfg: dace.SDFG, - ) -> nodes.MapEntry: + ) -> dace_nodes.MapEntry: if self.expr_index == 0: # The first the top map will be promoted. return state.entry_node(self.exit_first_map) @@ -364,7 +368,7 @@ def source_map( self, state: dace.SDFGState, sdfg: dace.SDFG, - ) -> nodes.MapEntry: + ) -> dace_nodes.MapEntry: """Returns the map entry that is used as source/template.""" if self.expr_index == 0: # The first the top map will be promoted, so the second map is the source. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index f51f9b2ee4..7cbf59813e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -6,19 +6,25 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Implements the seriall map fusing transformation.""" +"""Implements the serial map fusing transformation.""" import copy from typing import Any, Union import dace -from dace import dtypes, properties, subsets, symbolic, transformation -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + subsets as dace_subsets, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes -from . import map_fusion_helper +from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper -@properties.make_properties +@dace_properties.make_properties class SerialMapFusion(map_fusion_helper.MapFusionHelper): """Specialized replacement for the map fusion transformation that is provided by DaCe. @@ -48,9 +54,9 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): - This transformation modifies more nodes than it matches! """ - map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) - access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) + map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, @@ -84,10 +90,10 @@ def can_be_applied( - The decomposition exists and at least one of the intermediate sets is not empty. """ - assert isinstance(self.map_exit1, nodes.MapExit) - assert isinstance(self.map_entry2, nodes.MapEntry) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: nodes.MapEntry = self.map_entry2 + assert isinstance(self.map_exit1, dace_nodes.MapExit) + assert isinstance(self.map_entry2, dace_nodes.MapEntry) + map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) + map_entry_2: dace_nodes.MapEntry = self.map_entry2 # This essentially test the structural properties of the two Maps. if not self.can_be_fused( @@ -128,12 +134,14 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # once we start adding and removing nodes it seems that their ID changes. # Thus we have to save them here, this is a known behaviour in DaCe. assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, nodes.MapExit) - assert isinstance(self.map_entry2, nodes.MapEntry) - map_exit_1: nodes.MapExit = self.map_exit1 - map_entry_2: nodes.MapEntry = self.map_entry2 - map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) + assert isinstance(self.map_exit1, dace_nodes.MapExit) + assert isinstance(self.map_entry2, dace_nodes.MapEntry) + assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) + + map_exit_1: dace_nodes.MapExit = self.map_exit1 + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) + map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) output_partition = self.partition_first_outputs( state=graph, @@ -195,9 +203,9 @@ def handle_intermediate_set( intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], state: SDFGState, sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - map_exit_2: nodes.MapExit, + map_exit_1: dace_nodes.MapExit, + map_entry_2: dace_nodes.MapEntry, + map_exit_2: dace_nodes.MapExit, is_exclusive_set: bool, ) -> None: """This function handles the intermediate sets. @@ -238,7 +246,7 @@ def handle_intermediate_set( for out_edge in intermediate_outputs: # This is the intermediate node that, that we want to get rid of. # In shared mode we want to recreate it after the second map. - inter_node: nodes.AccessNode = out_edge.dst + inter_node: dace_nodes.AccessNode = out_edge.dst inter_name = inter_node.data inter_desc = inter_node.desc(sdfg) inter_shape = inter_desc.shape @@ -251,7 +259,7 @@ def handle_intermediate_set( if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) # Over approximation will leave us with some unneeded size one dimensions. # That are known to cause some troubles, so we will now remove them. @@ -282,7 +290,7 @@ def handle_intermediate_set( new_inter_name, dtype=inter_desc.dtype, transient=True, - storage=dtypes.StorageType.Register, + storage=dace_dtypes.StorageType.Register, find_new_name=True, ) @@ -297,7 +305,7 @@ def handle_intermediate_set( dtype=inter_desc.dtype, find_new_name=True, ) - new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) # New we will reroute the output Memlet, thus it will no longer pass # through the Map exit but through the newly created intermediate. @@ -444,7 +452,7 @@ def handle_intermediate_set( assert new_exit_memlet.data == inter_name new_exit_memlet.subset = pre_exit_edge.data.dst_subset new_exit_memlet.other_subset = ( - "0" if is_scalar else subsets.Range.from_array(inter_desc) + "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) ) new_pre_exit_conn = map_exit_2.next_connector() From 1a1a705c65f8e011e7719440fa239162e2395eed Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Aug 2024 14:59:10 +0200 Subject: [PATCH 205/235] Removed some stra `view()` call. --- .../dace/transformation_tests/test_serial_map_promoter.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py index 5c9f555582..e224a08a14 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -71,7 +71,6 @@ def test_serial_map_promotion(): assert len(map_entry_1d.map.params) == 1 assert len(map_entry_2d.map.params) == 2 - sdfg.view() # Now apply the promotion sdfg.apply_transformations( gtx_transformations.SerialMapPromoter( @@ -81,8 +80,6 @@ def test_serial_map_promotion(): validate_all=True, ) - sdfg.view() - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 2 assert len(map_entry_2d.map.params) == 2 From 3c4523a80caafffe2efda0b1d158339b5974953e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 14:34:48 +0200 Subject: [PATCH 206/235] Fixed a bug in the `SerialMapFusion` transformation. There was a bug when the consumer side of the second map was cleaned. In this memelet tree we have to remove the iteation variable, since it is no longer there. --- .../runners/dace_fieldview/transformations/map_serial_fusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index 7cbf59813e..685af2878e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -414,6 +414,7 @@ def handle_intermediate_set( consumer_edge = consumer_tree.edge assert consumer_edge.data.data == inter_name consumer_edge.data.data = new_inter_name + consumer_edge.data.replace(memlet_repl) if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.subset is not None: From cbed51a7cf88f5dfde3e0bb472383af0b4b0bd42 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 26 Aug 2024 10:52:14 +0200 Subject: [PATCH 207/235] Added the first batch of Enrique's suggestions. --- .../transformations/__init__.py | 3 +- .../transformations/auto_opt.py | 215 ++++++++++-------- .../transformations/gpu_utils.py | 78 +++---- .../transformations/k_blocking.py | 24 +- .../transformations/map_fusion_helper.py | 38 ++-- .../transformations/map_orderer.py | 8 +- .../transformations/map_promoter.py | 13 +- .../transformations/map_serial_fusion.py | 14 +- .../dace_fieldview/transformations/util.py | 56 ++--- .../dace/transformation_tests/conftest.py | 11 +- 10 files changed, 234 insertions(+), 226 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4b80710c9b..68e69f5e5d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,7 +12,7 @@ that explains the general structure and requirements on the SDFG. """ -from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_set_iteration_order, gt_simplify +from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify from .gpu_utils import ( GPUSetBlockSize, SerialMapPromoterGPU, @@ -32,7 +32,6 @@ "SerialMapFusion", "SerialMapPromoter", "SerialMapPromoterGPU", - "dace_auto_optimize", "gt_auto_optimize", "gt_gpu_transformation", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 50d003a93f..669a088104 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -13,6 +13,7 @@ import dace from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize +from dace.transformation.passes import simplify as dace_passes_simplify from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -20,37 +21,6 @@ ) -__all__ = [ - "dace_auto_optimize", - "gt_simplify", - "gt_set_iteration_order", - "gt_auto_optimize", -] - - -def dace_auto_optimize( - sdfg: dace.SDFG, - device: dace.DeviceType = dace.DeviceType.CPU, - use_gpu_storage: bool = True, - **kwargs: Any, -) -> dace.SDFG: - """This is a convenient wrapper arround DaCe's `auto_optimize` function. - - Args: - sdfg: The SDFG that should be optimized in place. - device: the device for which optimizations should be done, defaults to CPU. - use_gpu_storage: Assumes that the SDFG input is already on the GPU. - This parameter is `False` in DaCe but here is changed to `True`. - kwargs: Are forwarded to the underlying auto optimized exposed by DaCe. - """ - return dace_aoptimize.auto_optimize( - sdfg, - device=device, - use_gpu_storage=use_gpu_storage, - **kwargs, - ) - - def gt_simplify( sdfg: dace.SDFG, validate: bool = True, @@ -74,9 +44,8 @@ def gt_simplify( However, currently nothing is customized yet, and the function just calls the simplification pass directly. """ - from dace.transformation.passes.simplify import SimplifyPass - return SimplifyPass( + return dace_passes_simplify.SimplifyPass( validate=validate, validate_all=validate_all, verbose=False, @@ -106,7 +75,9 @@ def gt_set_iteration_order( return sdfg.apply_transformations_once_everywhere( gtx_transformations.MapIterationOrder( leading_dim=leading_dim, - ) + ), + validate=validate, + validate_all=validate_all, ) @@ -115,11 +86,14 @@ def gt_auto_optimize( gpu: bool, leading_dim: Optional[gtx_common.Dimension] = None, aggressive_fusion: bool = True, + max_optimization_rounds_p2: int = 100, make_persistent: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, block_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, reuse_transients: bool = False, + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -139,7 +113,7 @@ def gt_auto_optimize( the function will only add horizonal dimensions. In this phase some optimizations inside the bigger kernels themselves might be applied as well. - 3. After the function created big kernels it will apply some optimization, + 3. After the function created big kernels/maps it will apply some optimization, inside the kernels itself. For example fuse maps inside them. 4. Afterwards it will process the map ranges and iteration order. For this the function assumes that the dimension indicated by `leading_dim` is the @@ -163,6 +137,7 @@ def gt_auto_optimize( leading_dim: Leading dimension, indicates where the stride is 1. aggressive_fusion: Be more aggressive in fusion, will lead to the promotion of certain maps. + max_optimization_rounds_p2: Maximum number of optimization rounds in phase 2. make_persistent: Turn all transients to persistent lifetime, thus they are allocated over the whole lifetime of the program, even if the kernel exits. Thus the SDFG can not be called by different threads. @@ -171,6 +146,9 @@ def gt_auto_optimize( block_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. + gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. + gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` + for _all_ GPU Maps. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -197,7 +175,11 @@ def gt_auto_optimize( # to internal serial maps, such that they do not block fusion? # Phase 1: Initial Cleanup - gt_simplify(sdfg) + gt_simplify( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, @@ -209,66 +191,17 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Compute the SDFG hash to see if something has changed. - sdfg_hash = sdfg.hash_sdfg() - # Phase 2: Kernel Creation - # We will now try to reduce the number of kernels and create large Maps/kernels. - # For this we essentially use Map fusion. We do this is a loop because - # after a graph modification followed by simplify new fusing opportunities - # might arise. We use the hash of the SDFG to detect if we have reached a - # fix point. - # TODO(phimuell): Find a better upper bound for the starvation protection. - for _ in range(100): - # Use map fusion to reduce their number and to create big kernels - # TODO(phimuell): Use a cost measurement to decide if fusion should be done. - # TODO(phimuell): Add parallel fusion transformation. Should it run after - # or with the serial one? - sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), - validate=validate, - validate_all=validate_all, - ) - - # Now do some cleanup task, that may enable further fusion opportunities. - # Note for performance reasons simplify is deferred. - phase2_cleanup = [] - phase2_cleanup.append(dace_dataflow.TrivialTaskletElimination()) - - # TODO(phimuell): Should we do this all the time or only once? (probably the later) - # TODO(phimuell): Add a criteria to decide if we should promote or not. - phase2_cleanup.append( - gtx_transformations.SerialMapPromoter( - only_toplevel_maps=True, - promote_vertical=True, - promote_horizontal=False, - promote_local=False, - ) - ) - - sdfg.apply_transformations_once_everywhere( - phase2_cleanup, - validate=validate, - validate_all=validate_all, - ) - - # Use the hash to determine if the transformations did modify the SDFG. - # If not we have optimized the SDFG as much as we could, in this phase. - old_sdfg_hash = sdfg_hash - sdfg_hash = sdfg.hash_sdfg() - if old_sdfg_hash == sdfg_hash: - break - - # The SDFG was modified by the transformations above. The SDFG was - # modified. Call Simplify and try again to further optimize. - gt_simplify(sdfg) - - else: - raise RuntimeWarning("Optimization of the SDFG did not converge.") + # Try to create kernels as large as possible. + sdfg = _gt_auto_optimize_phase_2( + sdfg=sdfg, + aggressive_fusion=aggressive_fusion, + max_optimization_rounds=max_optimization_rounds_p2, + validate=validate, + validate_all=validate_all, + ) - # Phase 3: Optimizing the kernels themselves. + # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. # Currently this only applies fusion inside Maps. sdfg.apply_transformations_repeated( gtx_transformations.SerialMapFusion( @@ -307,8 +240,6 @@ def gt_auto_optimize( # This is because how it is implemented (promotion and # fusion). However, because of its current state, this # should not happen, but we have to look into it. - gpu_launch_factor: Optional[int] = kwargs.get("gpu_launch_factor", None) - gpu_launch_bounds: Optional[int] = kwargs.get("gpu_launch_bounds", None) gtx_transformations.gt_gpu_transformation( sdfg, gpu_block_size=gpu_block_size, @@ -339,3 +270,95 @@ def gt_auto_optimize( dace_aoptimize.make_transients_persistent(sdfg, device) return sdfg + + +def _gt_auto_optimize_phase_2( + sdfg: dace.SDFG, + aggressive_fusion: bool = True, + max_optimization_rounds: int = 100, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Performs the second phase of the auto optimization process. + + As noted in the doc of `gt_auto_optimize()` the function tries to reduce the + number of kernels/maps by fusing maps. This process essentially builds on + the map fusion transformations and some clean up transformations. + + It is important to note that the fusion will only affect top level maps, i.e. + nested maps are ignored. Furthermore, the function will iteratively perform + optimizations until a fix point is reached. If this does not happen within + `max_optimization_rounds` iterations an error is generated. + + Return: + The function optimizes the SDFG in place and returns it. + + Args: + sdfg: The SDFG to optimize. + aggressive_fusion: allow more aggressive fusion by promoting maps (over + computing). + max_optimization_rounds: Maximal number of optimization rounds should be + performed. + validate: Perform validation during the steps. + validate_all: Perform extensive validation. + """ + # Compute the SDFG hash to see if something has changed. + sdfg_hash = sdfg.hash_sdfg() + + # We use a loop to optimize because we are using multiple transformations + # after the other, thus new opportunities might arise in the next round. + # We use the hash of the SDFG to detect if we have reached a fix point. + for _ in range(max_optimization_rounds): + # Use map fusion to reduce their number and to create big kernels + # TODO(phimuell): Use a cost measurement to decide if fusion should be done. + # TODO(phimuell): Add parallel fusion transformation. Should it run after + # or with the serial one? + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion( + only_toplevel_maps=True, + ), + validate=validate, + validate_all=validate_all, + ) + + # Now do some cleanup task, that may enable further fusion opportunities. + # Note for performance reasons simplify is deferred. + phase2_cleanup = [] + phase2_cleanup.append(dace_dataflow.TrivialTaskletElimination()) + + # If requested perform map promotion, this will lead to more fusion. + if aggressive_fusion: + # TODO(phimuell): Should we do this all the time or only once? + # TODO(phimuell): Add a criteria to decide if we should promote or not. + # TODO(phimuell): Add parallel map promotion? + phase2_cleanup.append( + gtx_transformations.SerialMapPromoter( + only_toplevel_maps=True, + promote_vertical=True, + promote_horizontal=False, + promote_local=False, + ) + ) + + # Perform the phase 2 cleanup. + sdfg.apply_transformations_once_everywhere( + phase2_cleanup, + validate=validate, + validate_all=validate_all, + ) + + # Use the hash to determine if the transformations did modify the SDFG. + # If not we have optimized the SDFG as much as we could, in this phase. + old_sdfg_hash = sdfg_hash + sdfg_hash = sdfg.hash_sdfg() + if old_sdfg_hash == sdfg_hash: + break + + # The SDFG was modified by the transformations above. The SDFG was + # modified. Call Simplify and try again to further optimize. + gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + else: + raise RuntimeWarning("Optimization of the SDFG did not converge.") + + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index afd12a072c..b44c545499 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -13,26 +13,20 @@ import dace from dace import properties as dace_properties, transformation as dace_transformation -from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes +from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, ) -__all__ = [ - "SerialMapPromoterGPU", - "GPUSetBlockSize", - "gt_gpu_transformation", - "gt_set_gpu_blocksize", -] - - def gt_gpu_transformation( sdfg: dace.SDFG, try_removing_trivial_maps: bool = True, use_gpu_storage: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -51,9 +45,13 @@ def gt_gpu_transformation( Args: sdfg: The SDFG that should be processed. try_removing_trivial_maps: Try to get rid of trivial maps by incorporating them. - use_gpu_storage: Assume that the non global memory is already on the GPU. - gpu_block_size: Set to true when the SDFG array arguments are already allocated - on GPU global memory. This will avoid the data copy from host to GPU memory. + use_gpu_storage: Assume that the non global memory is already on the GPU. This + will avoid the data copy from host to GPU memory. + gpu_block_size: The size of a thread block on the GPU. + gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. + gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` + validate: Perform validation during the steps. + validate_all: Perform extensive validation. Notes: The function might modify the order of the iteration variables of some @@ -65,10 +63,6 @@ def gt_gpu_transformation( - Solve the fusing problem. - Currently only one block size for all maps is given, add more options. """ - - # You need guru level or above to use these arguments. - gpu_launch_factor: Optional[int] = kwargs.pop("gpu_launch_factor", None) - gpu_launch_bounds: Optional[int] = kwargs.pop("gpu_launch_bounds", None) assert ( len(kwargs) == 0 ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" @@ -130,13 +124,16 @@ def gt_set_gpu_blocksize( gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, ) -> Any: - """Set the block sizes of GPU Maps. + """Set the block size related properties of _all_ Maps. + + See `GPUSetBlockSize` for more information. Args: sdfg: The SDFG to process. - gpu_block_size: The block size to use. - gpu_launch_bounds: The launch bounds to use. - gpu_launch_factor: The launch factor to use. + gpu_block_size: The size of a thread block on the GPU. + launch_bounds: The value for the launch bound that should be used. + launch_factor: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number. """ xform = GPUSetBlockSize( block_size=gpu_block_size, @@ -152,21 +149,19 @@ def _gpu_block_parser( ) -> None: """Used by the setter of `GPUSetBlockSize.block_size`.""" org_val = val - if isinstance(val, tuple): + if isinstance(val, (tuple | list)): pass - elif isinstance(val, list): - val = tuple(val) elif isinstance(val, str): val = tuple(x.strip() for x in val.split(",")) + elif isinstance(val, int): + val = (val,) else: raise TypeError( - f"Does not know how to transform '{type(val).__name__}' into a proper GPU block size." + f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." ) - if len(val) == 1: - val = (*val, 1, 1) - elif len(val) == 2: - val = (*val, 1) - elif len(val) != 3: + if 0 < len(val) <= 3: + val = [*val, *([1] * (3 - len(val)))] + else: raise ValueError(f"Can not parse block size '{org_val}': wrong length") try: val = [int(x) for x in val] @@ -190,10 +185,17 @@ def _gpu_block_getter( class GPUSetBlockSize(dace_transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. - It is also possible to set the launch bound. + The transformation will apply to all Maps that have a GPU schedule, regardless + of their dimensionality. + + The `gpu_block_size` is either a sequence, of up to three integers or a string + of up to three numbers, separated by comma (`,`). + The first number is the size of the block in `x` direction, the second for the + `y` direction and the third for the `z` direction. Missing values will be filled + with `1`. Args: - block_size: The block size that should be used. + block_size: The size of a thread block on the GPU. launch_bounds: The value for the launch bound that should be used. launch_factor: If no `launch_bounds` was given use the number of threads in a block multiplied by this number. @@ -250,7 +252,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -274,8 +276,8 @@ def can_be_applied( def apply( self, - graph: Union[SDFGState, SDFG], - sdfg: SDFG, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, ) -> None: """Modify the map as requested.""" self.map_entry.map.gpu_block_size = self.block_size @@ -313,7 +315,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -324,8 +326,6 @@ def can_be_applied( - If the top map is a trivial map. - If a valid partition exists that can be fused at all. """ - from .map_serial_fusion import SerialMapFusion - map_exit_1: dace_nodes.MapExit = self.map_exit1 map_1: dace_nodes.Map = map_exit_1.map map_entry_2: dace_nodes.MapEntry = self.map_entry2 @@ -345,7 +345,7 @@ def can_be_applied( # Check if the partition exists, if not promotion to fusing is pointless. # TODO(phimuell): Find the proper way of doing it. - serial_fuser = SerialMapFusion(only_toplevel_maps=True) + serial_fuser = gtx_transformations.SerialMapFusion(only_toplevel_maps=True) output_partition = serial_fuser.partition_first_outputs( state=graph, sdfg=sdfg, @@ -357,7 +357,7 @@ def can_be_applied( return True - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the Map Promoting. The function essentially copies the parameters and the ranges from the diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 1e8ded1c1b..0fcfe8fe46 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -16,7 +16,7 @@ subsets as dace_subsets, transformation as dace_transformation, ) -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes +from dace.sdfg import graph as dace_graph, nodes as dace_nodes from dace.transformation import helpers as dace_helpers from gt4py.next import common as gtx_common @@ -65,9 +65,7 @@ def __init__( block_dim: Optional[Union[gtx_common.Dimension, str]] = None, ) -> None: super().__init__() - if isinstance(block_dim, str): - pass - elif isinstance(block_dim, gtx_common.Dimension): + if isinstance(block_dim, gtx_common.Dimension): block_dim = gtx_dace_fieldview_util.get_map_variable(block_dim) if block_dim is not None: self.block_dim = block_dim @@ -80,7 +78,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -120,8 +118,8 @@ def can_be_applied( def apply( self, - graph: Union[SDFGState, SDFG], - sdfg: SDFG, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, ) -> None: """Creates a blocking map. @@ -176,7 +174,7 @@ def apply( # Now we iterate over all the output edges of the outer map and rewire them. # Note that this only handles the entry of the Map. - for out_edge in list(graph.out_edges(outer_entry)): + for out_edge in graph.out_edges(outer_entry): edge_dst: dace_nodes.Node = out_edge.dst if edge_dst in dependent_nodes: @@ -245,7 +243,7 @@ def apply( # Now rewire the Memlets that leave the caching node to go through # new inner Map. - for consumer_edge in list(graph.out_edges(caching_node)): + for consumer_edge in graph.out_edges(caching_node): new_map_conn = inner_entry.next_connector() dace_helpers.redirect_edge( state=graph, @@ -267,7 +265,7 @@ def apply( # Handle the Map exits # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. - for out_edge in list(graph.in_edges(outer_exit)): + for out_edge in graph.in_edges(outer_exit): edge_conn = out_edge.dst_conn[3:] dace_helpers.redirect_edge( state=graph, @@ -292,8 +290,8 @@ def partition_map_output( self, map_entry: dace_nodes.MapEntry, block_param: str, - state: SDFGState, - sdfg: SDFG, + state: dace.SDFGState, + sdfg: dace.SDFG, ) -> tuple[set[dace_nodes.Node], set[dace_nodes.Node]] | None: """Partition the outputs of the Map. @@ -415,7 +413,7 @@ def partition_map_output( # We now make a last screening of the independent nodes. # TODO(phimuell): Make an iterative process to find the maximal set. - for independent_node in list(block_independent): + for independent_node in block_independent: if isinstance(independent_node, dace_nodes.AccessNode): if state.in_degree(independent_node) != 1: block_independent.discard(independent_node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index e8433f5cea..448b52a7c8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -19,13 +19,7 @@ subsets as dace_subsets, transformation as dace_transformation, ) -from dace.sdfg import ( - SDFG, - SDFGState, - graph as dace_graph, - nodes as dace_nodes, - validation as dace_validation, -) +from dace.sdfg import graph as dace_graph, nodes as dace_nodes, validation as dace_validation from dace.transformation import helpers as dace_helpers from gt4py.next.program_processors.runners.dace_fieldview.transformations import util @@ -60,7 +54,7 @@ class MapFusionHelper(dace_transformation.SingleStateTransformation): desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) shared_transients = dace_properties.DictProperty( - key_type=SDFG, + key_type=dace.SDFG, value_type=set[str], default=None, allow_none=True, @@ -144,8 +138,8 @@ def can_be_fused( def relocate_nodes( from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, + state: dace.SDFGState, + sdfg: dace.SDFG, ) -> None: """Move the connectors and edges from `from_node` to `to_nodes` node. @@ -163,22 +157,22 @@ def relocate_nodes( """ # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + for empty_edge in filter(lambda e: e.data.is_empty(), state.out_edges(from_node)): dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + for empty_edge in filter(lambda e: e.data.is_empty(), state.in_edges(from_node)): dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) # We now ensure that there is only one empty Memlet from the `to_node` to any other node. # Although it is allowed, we try to prevent it. empty_targets: set[dace_nodes.Node] = set() - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + for empty_edge in filter(lambda e: e.data.is_empty(), state.all_edges(to_node)): if empty_edge.dst in empty_targets: state.remove_edge(empty_edge) empty_targets.add(empty_edge.dst) # We now determine which edges we have to migrate, for this we are looking at # the incoming edges, because this allows us also to detect dynamic map ranges. - for edge_to_move in list(state.in_edges(from_node)): + for edge_to_move in state.in_edges(from_node): assert isinstance(edge_to_move.dst_conn, str) if not edge_to_move.dst_conn.startswith("IN_"): @@ -209,10 +203,10 @@ def relocate_nodes( new_conn = to_node.next_connector(old_conn) to_node.add_in_connector("IN_" + new_conn) - for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + for e in state.in_edges_by_connector(from_node, "IN_" + old_conn): dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) to_node.add_out_connector("OUT_" + new_conn) - for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + for e in state.out_edges_by_connector(from_node, "OUT_" + old_conn): dace_helpers.redirect_edge( state, e, new_src=to_node, new_src_conn="OUT_" + new_conn ) @@ -239,8 +233,8 @@ def relocate_nodes( def map_parameter_compatible( map_1: dace_nodes.Map, map_2: dace_nodes.Map, - state: Union[SDFGState, SDFG], - sdfg: SDFG, + state: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, ) -> bool: """Checks if the parameters of `map_1` are compatible with `map_2`. @@ -352,8 +346,8 @@ def is_interstate_transient( def partition_first_outputs( self, - state: SDFGState, - sdfg: SDFG, + state: dace.SDFGState, + sdfg: dace.SDFG, map_exit_1: dace_nodes.MapExit, map_entry_2: dace_nodes.MapEntry, ) -> Union[ @@ -448,8 +442,8 @@ def partition_first_outputs( # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. # TODO(phimuell): Handle this case properly. - inner_collector_edges = list( - state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) + inner_collector_edges = state.in_edges_by_connector( + intermediate_node, "IN_" + out_edge.src_conn[3:] ) if len(inner_collector_edges) > 1: return None diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index f7d447fdc6..ee6b2c43fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -10,7 +10,7 @@ import dace from dace import properties as dace_properties, transformation as dace_transformation -from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes +from dace.sdfg import nodes as dace_nodes from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util @@ -68,7 +68,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -93,8 +93,8 @@ def can_be_applied( def apply( self, - graph: Union[SDFGState, SDFG], - sdfg: SDFG, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, ) -> None: """Performs the actual parameter reordering. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 2f0d4ee261..9eea80cd55 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -14,12 +14,7 @@ subsets as dace_subsets, transformation as dace_transformation, ) -from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes - - -__all__ = [ - "SerialMapPromoter", -] +from dace.sdfg import nodes as dace_nodes @dace_properties.make_properties @@ -151,7 +146,7 @@ def __init__( def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -212,7 +207,7 @@ def can_be_applied( return True - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the actual Map promoting. Add all parameters that `self.source_map` has but `self.map_to_promote` @@ -326,7 +321,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index 685af2878e..019e03563f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -19,7 +19,7 @@ symbolic as dace_symbolic, transformation as dace_transformation, ) -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes +from dace.sdfg import graph as dace_graph, nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper @@ -78,7 +78,7 @@ def expressions(cls) -> Any: def can_be_applied( self, - graph: Union[SDFGState, SDFG], + graph: Union[dace.SDFGState, dace.SDFG], expr_index: int, sdfg: dace.SDFG, permissive: bool = False, @@ -201,8 +201,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non @staticmethod def handle_intermediate_set( intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], - state: SDFGState, - sdfg: SDFG, + state: dace.SDFGState, + sdfg: dace.SDFG, map_exit_1: dace_nodes.MapExit, map_entry_2: dace_nodes.MapEntry, map_exit_2: dace_nodes.MapExit, @@ -253,9 +253,7 @@ def handle_intermediate_set( # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) - ) + pre_exit_edges = state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] @@ -422,7 +420,7 @@ def handle_intermediate_set( # The edge that leaves the second map entry was already deleted. # We will now delete the edges that brought the data. - for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + for edge in state.in_edges_by_connector(map_entry_2, in_conn_name): assert edge.src == inter_node state.remove_edge(edge) map_entry_2.remove_in_connector(in_conn_name) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index f40749de23..a54c882842 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -23,11 +23,8 @@ def is_nested_sdfg( if isinstance(sdfg, dace_nodes.NestedSDFG): return True elif isinstance(sdfg, dace.SDFG): - if sdfg.parent_nsdfg_node is not None: - return True - return False - else: - raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") + return sdfg.parent_nsdfg_node is not None + raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") def all_nodes_between( @@ -68,23 +65,19 @@ def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: to_visit: list[dace_nodes.Node] = [begin] seen: set[dace_nodes.Node] = set() - found_end: bool = False while len(to_visit) > 0: - n: dace_nodes.Node = to_visit.pop() - if n == end: - found_end = True - continue - elif n in seen: - continue - seen.add(n) - to_visit.extend(next_nodes(n)) + node: dace_nodes.Node = to_visit.pop() + if node != end and node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) - if not found_end: + # If `end` was not found we have to return `None` to indicate this. + if end not in seen: return None - seen.discard(begin) - return seen + # `begin` and `end` are not included in the output set. + return seen - {begin, end} def find_downstream_consumers( @@ -113,14 +106,13 @@ def find_downstream_consumers( """ if isinstance(begin, dace_graph.MultiConnectorEdge): to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] - elif reverse: - to_visit = list(state.in_edges(begin)) else: - to_visit = list(state.out_edges(begin)) + to_visit = state.in_edges(begin) if reverse else state.out_edges(begin) + seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() - while len(to_visit) != 0: + while len(to_visit) > 0: curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst @@ -129,26 +121,28 @@ def find_downstream_consumers( seen.add(curr_edge) if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): - if reverse: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - else: + if not reverse: # In forward mode a Map entry could also mean the definition of a # dynamic map range. - if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( - next_node, dace_nodes.MapEntry + if isinstance(next_node, dace_nodes.MapEntry) and ( + not curr_edge.dst_conn.startswith("IN_") ): - # This edge defines a dynamic map range, which is a consumer if not only_tasklets: found.add((next_node, curr_edge)) continue target_conn = curr_edge.dst_conn[3:] new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) + else: + target_conn = curr_edge.src_conn[4:] + new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) to_visit.extend(new_edges) - del new_edges + + elif only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): + # We are only interested in Tasklets but have not found one. Thus we + # ignore the node. + pass + else: - if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): - continue found.add((next_node, curr_edge)) return found diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py index 72e76a63e2..0e85e1d9d1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py @@ -22,8 +22,15 @@ @pytest.fixture(autouse=True) -def _set_dace_settings() -> Generator[None, None, None]: - """Customizes DaCe settings during the tests.""" +def set_dace_settings() -> Generator[None, None, None]: + """Sets the common DaCe settings for the tests. + + The function will modify the following settings: + - `optimizer.match_exception` exceptions during the pattern matching stage, + especially inside `can_be_applied()` are not ignored. + - `compiler.allow_view_arguments` allow that NumPy views can be passed to + `CompiledSDFG` objects as arguments. + """ with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) dace.Config.set("compiler", "allow_view_arguments", value=True) From ca717357856eaf7045e990427f67fb66a5c4fe91 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 26 Aug 2024 11:31:08 +0200 Subject: [PATCH 208/235] Fixed a bug in the map promoter. It was a problem with the string matching rule. --- .../dace_fieldview/transformations/map_promoter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 9eea80cd55..885ba4e09a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -160,6 +160,7 @@ def can_be_applied( - If the parameter of the second map are compatible with each other. - If a dimension would be promoted that should not. """ + assert self.expr_index == expr_index map_to_promote_entry: dace_nodes.MapEntry = self.map_to_promote(state=graph, sdfg=sdfg) map_to_promote: dace_nodes.Map = map_to_promote_entry.map source_map_entry: dace_nodes.MapEntry = self.source_map(state=graph, sdfg=sdfg) @@ -191,11 +192,11 @@ def can_be_applied( if not self.promote_all: dimension_identifier: list[str] = [] if self.promote_local: - dimension_identifier.append("__gtx_localdim") + dimension_identifier.append("_gtx_localdim") if self.promote_vertical: - dimension_identifier.append("__gtx_vertical") + dimension_identifier.append("_gtx_vertical") if self.promote_horizontal: - dimension_identifier.append("__gtx_horizontal") + dimension_identifier.append("_gtx_horizontal") if not dimension_identifier: return False for missing_map_param in missing_map_parameters: @@ -276,6 +277,7 @@ def missing_map_params( for param_to_check in curr_params_set: curr_range = curr_ranges[curr_param_to_idx[param_to_check]] source_range = source_ranges[source_param_to_idx[param_to_check]] + # TODO(phimuell): Use simplify if curr_range != source_range: return None return list(source_params_set - curr_params_set) From 83a5fe4e62a2bccea77179b59b0b4898109bd9d4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 26 Aug 2024 14:38:19 +0200 Subject: [PATCH 209/235] fixup! Added the first batch of Enrique's suggestions. --- .../dace_fieldview/transformations/map_fusion_helper.py | 4 ++-- .../dace_fieldview/transformations/map_serial_fusion.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 448b52a7c8..ced6f41ff2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -442,8 +442,8 @@ def partition_first_outputs( # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. # TODO(phimuell): Handle this case properly. - inner_collector_edges = state.in_edges_by_connector( - intermediate_node, "IN_" + out_edge.src_conn[3:] + inner_collector_edges = list( + state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) ) if len(inner_collector_edges) > 1: return None diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index 019e03563f..a07b613ab1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -253,7 +253,9 @@ def handle_intermediate_set( # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] From 0ee90f5eb859269318d29428f35a9a01c7c6f594 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 26 Aug 2024 14:39:00 +0200 Subject: [PATCH 210/235] First new version of the k blocking. --- .../transformations/auto_opt.py | 3 + .../transformations/k_blocking.py | 473 +++++++++++------- .../transformation_tests/test_k_blocking.py | 159 +++++- 3 files changed, 455 insertions(+), 180 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 669a088104..2f9de2327a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -331,6 +331,7 @@ def _gt_auto_optimize_phase_2( # TODO(phimuell): Should we do this all the time or only once? # TODO(phimuell): Add a criteria to decide if we should promote or not. # TODO(phimuell): Add parallel map promotion? + print(">>>>>>>>>>>>>>>>>>>>>>>> AGGRESSIVE FUSION") phase2_cleanup.append( gtx_transformations.SerialMapPromoter( only_toplevel_maps=True, @@ -339,6 +340,8 @@ def _gt_auto_optimize_phase_2( promote_local=False, ) ) + else: + print("<<<<<<<<<<<<<<<<<<<<<<<<<<<< NO AGGRESSIVE FUSION") # Perform the phase 2 cleanup. sdfg.apply_transformations_once_everywhere( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 0fcfe8fe46..a08ae19dec 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -import functools from typing import Any, Optional, Union import dace @@ -44,6 +43,12 @@ class KBlocking(dace_transformation.SingleStateTransformation): The function will also change the name of the outer map, it will append `_blocked` to it. + + Todo: + - Allow that independent nodes does not need to have direct connection + to the map, instead it is enough that they only have connections to + independent nodes. + - Investigate if the iteration should always start at zero. """ blocking_size = dace_properties.Property( @@ -86,9 +91,9 @@ def can_be_applied( """Test if the map can be blocked. The test involves: - - Toplevel map. + - The map must be at the top level. - The map shall not be serial. - - The block dimension must be present (exact match). + - The block dimension must be present (exact string match). - The map range must have step size of 1. - The partition must exists (see `partition_map_output()`). """ @@ -169,25 +174,90 @@ def apply( outer_map.range[block_idx] = coarse_block_range outer_map.label = f"{outer_map.label}_blocked" - # Contains the independent nodes that are already relocated. + # Contains the nodes that are already have been handled. relocated_nodes: set[dace_nodes.Node] = set() - # Now we iterate over all the output edges of the outer map and rewire them. - # Note that this only handles the entry of the Map. - for out_edge in graph.out_edges(outer_entry): - edge_dst: dace_nodes.Node = out_edge.dst + # We now handle all independent nodes, this means that all of their + # _output_ edges have to go through the new inner map and the Memlets need + # modifications, because of the block parameter. + for independent_node in independent_nodes: + for out_edge in graph.out_edges(independent_node): + edge_dst: dace_nodes.Node = out_edge.dst + relocated_nodes.add(edge_dst) + + # If destination of this edge is also independent we do not need + # to handle it, because that node will also be before the new + # inner serial map. + if edge_dst in independent_nodes: + continue + + # Now split `out_edge` such that it passes through the new inner entry. + # We do not need to modify the subsets, i.e. replacing the variable + # on which we block, because the node is independent and the outgoing + # new inner map entry iterate over the blocked variable. + new_map_conn = inner_entry.next_connector() + dace_helpers.redirect_edge( + state=graph, + edge=out_edge, + new_dst=inner_entry, + new_dst_conn="IN_" + new_map_conn, + ) + # TODO(phimuell): Check if there might be a subset error. + graph.add_edge( + inner_entry, + "OUT_" + new_map_conn, + out_edge.dst, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + inner_entry.add_in_connector("IN_" + new_map_conn) + inner_entry.add_out_connector("OUT_" + new_map_conn) + + # Now we handle the dependent nodes, they differ from the independent nodes + # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. + for dependent_node in dependent_nodes: + for in_edge in graph.in_edges(dependent_node): + edge_src: dace_nodes.Node = in_edge.src + + # Since the independent nodes were already processed, and they process + # their output we have to check for this. We do this by checking if + # the source of the edge is the new inner map entry. + if edge_src is inner_entry: + assert dependent_node in relocated_nodes + continue + + # A dependent node has at least one connection to the outer map entry. + # And these are the only connections that we must handle, since other + # connections come from independent nodes, and were already handled + # or are inner nodes. + if edge_src is not outer_entry: + continue + + # If we encounter an empty Memlet we just just attach it to the + # new inner map entry. Note the partition function ensures that + # either all edges are empty or non. + if in_edge.data.is_empty(): + assert ( + edge_src is outer_entry + ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." + dace_helpers.redirect_edge(state=graph, edge=in_edge, new_src=inner_entry) + continue - if edge_dst in dependent_nodes: - # This is the simple case as we just have to rewire the edge - # and make a connection between the outer and inner map. - assert not out_edge.data.is_empty() - edge_conn: str = out_edge.src_conn[4:] + # Because of the definition of a dependent node and the processing + # order, their incoming edges either point to the outer map or + # are already handled. + if edge_src is not outer_entry: + sdfg.view() + assert ( + edge_src is outer_entry + ), f"Expected to find source '{outer_entry}' but found '{edge_src}' || {dependent_node} // {edge_src in independent_nodes}." + edge_conn: str = in_edge.src_conn[4:] # Must be before the handling of the modification below # Note that this will remove the original edge from the SDFG. dace_helpers.redirect_edge( state=graph, - edge=out_edge, + edge=in_edge, new_src=inner_entry, new_src_conn="OUT_" + edge_conn, ) @@ -211,65 +281,30 @@ def apply( "OUT_" + edge_conn, inner_entry, "IN_" + edge_conn, - copy.deepcopy(out_edge.data), + copy.deepcopy(in_edge.data), ) inner_entry.add_in_connector("IN_" + edge_conn) inner_entry.add_out_connector("OUT_" + edge_conn) - continue - - elif edge_dst in relocated_nodes: - # The node was already fully handled in the `else` clause. - continue - - else: - # Relocate the node and make the reconnection. - # Different from the dependent case we will handle all the edges - # of the node in one go. - relocated_nodes.add(edge_dst) - # In order to be useful we have to temporarily store the data the - # independent node generates - assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift - if isinstance(edge_dst, dace_nodes.AccessNode): - # The independent node is an access node, so we can use it directly. - caching_node: dace_nodes.AccessNode = edge_dst - else: - # The dependent node is not an access node. For now we will - # just use the next node, with some restriction. - # TODO(phimuell): create an access node in this case instead. - caching_node = next(iter(graph.out_edges(edge_dst))).dst - assert graph.in_degree(caching_node) == 1 - assert isinstance(caching_node, dace_nodes.AccessNode) - - # Now rewire the Memlets that leave the caching node to go through - # new inner Map. - for consumer_edge in graph.out_edges(caching_node): - new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=graph, - edge=consumer_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, - ) - graph.add_edge( - inner_entry, - "OUT_" + new_map_conn, - consumer_edge.dst, - consumer_edge.dst_conn, - copy.deepcopy(consumer_edge.data), - ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) - continue + # In certain cases it might happen that we need to create an empty + # Memlet between the outer map entry and the inner one. + if graph.in_degree(inner_entry) == 0: + graph.add_edge( + outer_entry, + None, + inner_entry, + None, + dace.Memlet(), + ) # Handle the Map exits # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. - for out_edge in graph.in_edges(outer_exit): - edge_conn = out_edge.dst_conn[3:] + for in_edge in graph.in_edges(outer_exit): + edge_conn = in_edge.dst_conn[3:] dace_helpers.redirect_edge( state=graph, - edge=out_edge, + edge=in_edge, new_dst=inner_exit, new_dst_conn="IN_" + edge_conn, ) @@ -277,8 +312,8 @@ def apply( inner_exit, "OUT_" + edge_conn, outer_exit, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), + in_edge.dst_conn, + copy.deepcopy(in_edge.data), ) inner_exit.add_in_connector("IN_" + edge_conn) inner_exit.add_out_connector("OUT_" + edge_conn) @@ -293,17 +328,33 @@ def partition_map_output( state: dace.SDFGState, sdfg: dace.SDFG, ) -> tuple[set[dace_nodes.Node], set[dace_nodes.Node]] | None: - """Partition the outputs of the Map. + """Partition the of the nodes of the Map. - The partition will only look at the direct intermediate outputs of the - Map. The outputs will be two sets, defined as: - - The independent outputs `\mathcal{I}`: - These are output nodes, whose output does not depend on the blocked + The outputs will be two sets, defined as: + - The independent nodes `\mathcal{I}`: + These are the nodes, whose output does not depend on the blocked dimension. These nodes can be relocated between the outer and inner map. - - The dependent output `\mathcal{D}`: - These are the output nodes, whose output depend on the blocked dimension. + Nodes in these set does not necessarily have a direct edge to `map_entry`. + However, the exists a path from `map_entry` to any node in this set. + - The dependent nodes `\mathcal{D}`: + These are the nodes, whose output depend on the blocked dimension. Thus they can not be relocated between the two maps, but will remain - inside the inner scope. + inside the inner scope. All these nodes have at least one edge to `map_entry`. + + The function uses an iterative method to compute the set of independent nodes. + In each iteration the function will look classify all nodes that have an + incoming edge originating either at `map_entry` or from a node that was + already classified as independent. This is repeated until no new independent + nodes are found. This means that independent nodes does not necessarily have + a direct connection to `map_entry`. + + The dependent nodes on the other side always have a direct edge to `map_entry`. + As they are the set of nodes that are adjacent to `map_entry` but are not + independent. + + For the sake of arguments, all nodes, except the map entry and exit nodes, + that are inside the scope and are not classified as dependent or independent + are known as "inner nodes". In case the function fails to compute the partition `None` is returned. @@ -314,122 +365,186 @@ def partition_map_output( sdfg: The SDFG in which we operate on. Note: + - The function will only inspect the direct children of the map entry. - Currently this function only considers the input Memlets and the - `used_symbol` properties of a Tasklet. - - Furthermore only the first level is inspected. + `used_symbol` properties of a Tasklet to determine if a Tasklet is dependent. """ - block_independent: set[dace_nodes.Node] = set() # `\mathcal{I}` - block_dependent: set[dace_nodes.Node] = set() # `\mathcal{D}` - - # Find all nodes that are adjacent to the map entry. - nodes_to_partition: set[dace_nodes.Node] = {edge.dst for edge in state.out_edges(map_entry)} - - # Now we examine every node and assign them to one of the sets. - # Note that this is only tentative and we will later inspect the - # outputs of the independent node and reevaluate their classification. - for node in nodes_to_partition: - # Filter out all nodes that we can not (yet) handle. - if not isinstance(node, (dace_nodes.Tasklet, dace_nodes.AccessNode)): - return None + independent_nodes: set[dace_nodes.Node] = set() # `\mathcal{I}` + + while True: + # Find all the nodes that we have to classify in this iteration. + # - All nodes adjacent to `map_entry` + # - All nodes adjacent to independent nodes. + nodes_to_classify: set[dace_nodes.Node] = { + edge.dst for edge in state.out_edges(map_entry) + } + for independent_node in independent_nodes: + nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) + nodes_to_classify.difference_update(independent_nodes) + + # Now classify each node + found_new_independent_node = False + for node_to_classify in nodes_to_classify: + class_res = self.classify_node( + node_to_classify=node_to_classify, + map_entry=map_entry, + block_param=block_param, + independent_nodes=independent_nodes, + state=state, + sdfg=sdfg, + ) - # Check if we have a strange Tasklet or if it uses the symbol inside it. - if isinstance(node, dace_nodes.Tasklet): - if node.side_effects: + # Check if the partition exists. + if class_res is None: return None - if block_param in node.free_symbols: - block_dependent.add(node) - continue + if class_res is True: + found_new_independent_node = True - # An independent node can (for now) only have one output. - # TODO(phimuell): Lift this restriction. - if state.out_degree(node) != 1: - block_dependent.add(node) - continue + # If we found a new independent node then we have to continue. + if not found_new_independent_node: + break - # Now we have to understand how the node generates its data. - # For this we have to look at all the edges that feed information to it. - edges: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges(node)) + # After the independent set is computed compute the set of dependent nodes + # as the set of all nodes adjacent to `map_entry` that are not dependent. + dependent_nodes: set[dace_nodes.Node] = { + edge.dst for edge in state.out_edges(map_entry) if edge.dst not in independent_nodes + } - # If all edges are empty, i.e. they are only needed to keep the - # node inside the scope, consider it as independent. However, they have - # to be associated to the outer map. - if all(edge.data.is_empty() for edge in edges): - if not all(edge.src is map_entry for edge in edges): - return None - block_independent.add(node) - continue + return (independent_nodes, dependent_nodes) - # Currently we do not allow that a node has a mix of empty and non - # empty Memlets, it is all or nothing. - if any(edge.data.is_empty() for edge in edges): - return None + def classify_node( + self, + node_to_classify: dace_nodes.Node, + map_entry: dace_nodes.MapEntry, + block_param: str, + independent_nodes: set[dace_nodes.Node], + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool | None: + """Internal function for classifying a single node. + + The general rule to classify if a node is independent are: + - The node must be a Tasklet or an AccessNode, in all other cases the + partition does not exist. + - `free_symbols` of the nodes shall not contain the `block_param`. + - All incoming _empty_ edges must be connected to the map entry. + - A node either has only empty Memlets or none of them. + - Incoming Memlets does not depend on the `block_param`. + - All incoming edges must start either at `map_entry` or at dependent nodes. + - All output Memlets are non empty. + + Returns: + The function returns `True` if `node_to_classify` is considered independent. + In this case the function will add the node to `independent_nodes`. + If the function returns `False` the node was classified as a dependent node. + The function will return `None` if the node can not be classified, in this + case the partition does not exist. - # If the node gets information from other nodes than the map entry - # we classify it as a dependent node. But there can be situations where - # the node could still be independent, for example if it is connected - # to a independent node, then it could be independent itself. - # TODO(phimuell): Consider independent node as "equal" to the map. - if any(edge.src is not map_entry for edge in edges): - block_dependent.add(node) - continue + Args: + node_to_classify: The node that should be classified. + map_entry: The entry of the map that should be partitioned. + block_param: The iteration parameter that should be blocked. + independent_nodes: The set of nodes that was already classified as + independent, in case the node is classified as independent the set is + updated. + state: The state containing the map. + sdfg: The SDFG that is processed. + """ - # Now we have to look at the edges individually. - # If this loop ends normally, i.e. it goes into its `else` - # clause then we classify the node as independent. - for edge in edges: - memlet: dace.Memlet = edge.data - src_subset: dace_subsets.Subset | None = memlet.src_subset - dst_subset: dace_subsets.Subset | None = memlet.dst_subset - edge_desc: dace.data.Data = sdfg.arrays[memlet.data] - edge_desc_size = functools.reduce(lambda a, b: a * b, edge_desc.shape) - - if memlet.is_empty(): - # Empty Memlets do not impose any restrictions. - continue - if memlet.num_elements() == edge_desc_size: - # The whole source array is consumed, which is not a problem. - continue + # We are only able to handle certain kind of nodes, so screening them. + if isinstance(node_to_classify, dace_nodes.Tasklet): + if node_to_classify.side_effects: + # TODO(phimuell): Think of handling it. + return None + elif isinstance(node_to_classify, dace_nodes.AccessNode): + # AccessNodes need to have some special properties. + node_desc: dace.data.Data = node_to_classify.desc(sdfg) - # Now we have to look at the source and destination set of the Memlet. - subsets_to_inspect: list[dace_subsets.Subset] = [] - if dst_subset is not None: - subsets_to_inspect.append(dst_subset) - if src_subset is not None: - subsets_to_inspect.append(src_subset) - - # If a subset needs the block variable then the node is not - # independent from the block variable. - if any(block_param in subset.free_symbols for subset in subsets_to_inspect): - break + if isinstance(node_desc, dace.data.View): + # Views are forbidden. + return None + if node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: + # The access node has to life fully within the scope. + return None + else: + # Any other node type we can not handle, so the partition can not exist. + # TODO(phimuell): Try to handle certain kind of library nodes. + return None + + # Now we have to understand how the node generates its data. For this we have + # to look at all the incoming edges. + in_edges: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges(node_to_classify) + ) + + # In a first phase we will only look if the partition exists or not. + # We will therefore not check if the node is independent or not, since + # for these classification to make sense the partition has to exist in the + # first place. + + # Either all incoming edges of a node are empty or none of them. If it has + # empty edges, they are only allowed to come from the map entry. + found_empty_edges, found_nonempty_edges = False, False + for in_edge in in_edges: + if in_edge.data.is_empty(): + found_empty_edges = True + if in_edge.src is not map_entry: + # TODO(phimuell): Lift this restriction. + return None else: - # The loop ended normally, thus we did not found anything that made us - # _think_ that the node is _not_ an independent node. We will later - # also inspect the output, which might reclassify the node - block_independent.add(node) - - # If the node is not independent then it must be dependent, my dear Watson. - if node not in block_independent: - block_dependent.add(node) - - # We now make a last screening of the independent nodes. - # TODO(phimuell): Make an iterative process to find the maximal set. - for independent_node in block_independent: - if isinstance(independent_node, dace_nodes.AccessNode): - if state.in_degree(independent_node) != 1: - block_independent.discard(independent_node) - block_dependent.add(independent_node) + found_nonempty_edges = True + + # Test if we found a mixture of empty and nonempty edges. + if found_empty_edges and found_nonempty_edges: + return None + assert ( + found_empty_edges or found_nonempty_edges + ), f"Node '{node_to_classify}' inside '{map_entry}' without an input connection." + + # Requiring that all output Memlets are non empty implies, because we are + # inside a scope, that there exists an output. + if any(out_edge.data.is_empty() for out_edge in state.out_edges(node_to_classify)): + return None + + # Now we have ensured that the partition exists, thus we will now evaluate + # if the node is independent or dependent. + + # Test if the body of the Tasklet depends on the block variable. + if ( + isinstance(node_to_classify, dace_nodes.Tasklet) + and block_param in node_to_classify.free_symbols + ): + return False + + # Now we have to look at incoming edges individually. + # We will inspect the subset of the Memlet to see if they depend on the + # block variable. If this loop ends normally, then we classify the node + # as independent and the node is added to `independent_nodes`. + for in_edge in in_edges: + memlet: dace.Memlet = in_edge.data + src_subset: dace_subsets.Subset | None = memlet.src_subset + dst_subset: dace_subsets.Subset | None = memlet.dst_subset + + if memlet.is_empty(): # Empty Memlets do not impose any restrictions. continue - for out_edge in state.out_edges(independent_node): - if ( - (not isinstance(out_edge.dst, dace_nodes.AccessNode)) - or (state.in_degree(out_edge.dst) != 1) - or (out_edge.dst.desc(sdfg).lifetime != dace.dtypes.AllocationLifetime.Scope) - ): - block_independent.discard(independent_node) - block_dependent.add(independent_node) - break - - assert not block_dependent.intersection(block_independent) - assert (len(block_independent) + len(block_dependent)) == len(nodes_to_partition) - - return (block_independent, block_dependent) + + # Now we have to look at the source and destination set of the Memlet. + subsets_to_inspect: list[dace_subsets.Subset] = [] + if dst_subset is not None: + subsets_to_inspect.append(dst_subset) + if src_subset is not None: + subsets_to_inspect.append(src_subset) + + # If a subset needs the block variable then the node is not independent + # but dependent. + if any(block_param in subset.free_symbols for subset in subsets_to_inspect): + return False + + # The edge must either originate from `map_entry` or from an independent + # node if not it is dependent. + if not (in_edge.src is map_entry or in_edge.src in independent_nodes): + return False + + # Loop ended normally, thus we classify the node as independent. + independent_nodes.add(node_to_classify) + return True diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py index 91d76ebd39..d2e064386d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py @@ -11,7 +11,7 @@ import copy import numpy as np -from dace.sdfg import nodes as dace_nodes +from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, @@ -43,6 +43,93 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np return sdfg, lambda a, b: a + b.reshape((-1, 1)) +def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np.ndarray]]: + """Generates an SDFG that has chained Tasklets that are independent. + + The bottom Tasklet is the only dependent Tasklet. + """ + sdfg = dace.SDFG("only_dependent") + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + sdfg.add_array("b", ("N",), dace.float64, transient=False) + sdfg.add_array("c", ("N", "M"), dace.float64, transient=False) + sdfg.add_scalar("tmp1", dtype=dace.float64, transient=True) + sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) + a, b, c, tmp1, tmp2 = (state.add_access(name) for name in ["a", "b", "c", "tmp1", "tmp2"]) + + # First independent Tasklet. + task1 = state.add_tasklet( + "task1_indepenent", + inputs={ + "__in0", # <- `b[i]` + }, + outputs={ + "__out0", # <- `tmp1` + }, + code="__out0 = __in0 + 3.0", + ) + + # This is the second independent Tasklet. + task2 = state.add_tasklet( + "task2_indepenent", + inputs={ + "__in0", # <- `tmp1` + "__in1", # <- `b[i]` + }, + outputs={ + "__out0", # <- `tmp2` + }, + code="__out0 = __in0 + __in1", + ) + + # This is the third Tasklet, which is dependent. + task3 = state.add_tasklet( + "task3_dependent", + inputs={ + "__in0", # <- `tmp2` + "__in1", # <- `a[i, j]` + }, + outputs={ + "__out0", # <- `c[i, j]`. + }, + code="__out0 = __in0 + __in1", + ) + + # Now create the map + mentry, mexit = state.add_map( + "map", + ndrange={"i": "0:N", "j": "0:M"}, + ) + + # Now assemble everything. + state.add_edge(mentry, "OUT_b", task1, "__in0", dace.Memlet("b[i]")) + state.add_edge(task1, "__out0", tmp1, None, dace.Memlet("tmp1[0]")) + + state.add_edge(tmp1, None, task2, "__in0", dace.Memlet("tmp1[0]")) + state.add_edge(mentry, "OUT_b", task2, "__in1", dace.Memlet("b[i]")) + state.add_edge(task2, "__out0", tmp2, None, dace.Memlet("tmp2[0]")) + + state.add_edge(tmp2, None, task3, "__in0", dace.Memlet("tmp2[0]")) + state.add_edge(mentry, "OUT_a", task3, "__in1", dace.Memlet("a[i, j]")) + state.add_edge(task3, "__out0", mexit, "IN_c", dace.Memlet("c[i, j]")) + + state.add_edge(a, None, mentry, "IN_a", sdfg.make_array_memlet("a")) + state.add_edge(b, None, mentry, "IN_b", sdfg.make_array_memlet("b")) + state.add_edge(mexit, "OUT_c", c, None, sdfg.make_array_memlet("c")) + for name in ["a", "b"]: + mentry.add_in_connector("IN_" + name) + mentry.add_out_connector("OUT_" + name) + mexit.add_in_connector("IN_c") + mexit.add_out_connector("OUT_c") + + dace_propagation.propagate_states(sdfg) + sdfg.validate() + + return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3)) + + def test_only_dependent(): """Just applying the transformation to the SDFG. @@ -142,3 +229,73 @@ def test_intermediate_access_node(): c[:] = 0 sdfg(a=a, b=b, c=c, N=N, M=M) assert np.allclose(ref, c) + + +def test_chained_access() -> None: + """Test if chained access works.""" + sdfg, reff = _get_chained_sdfg() + state: dace.SDFGState = next(iter(sdfg.states())) + + N, M = 100, 10 + a = np.random.rand(N, M) + b = np.random.rand(N) + c = np.zeros_like(a) + ref = reff(a, b) + + # Before the optimization + sdfg(a=a, b=b, c=c, M=M, N=N) + assert np.allclose(c, ref) + c[:] = 0 + + # Apply the transformation. + ret = sdfg.apply_transformations_repeated( + gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + validate=True, + validate_all=True, + ) + assert ret == 1, f"Expected that the transformation was applied 1 time, but it was {ret}." + + # Now run the SDFG to see if it is still the same + sdfg(a=a, b=b, c=c, M=M, N=N) + assert np.allclose(c, ref) + + # Now look for the outer map. + outer_map = None + for node in state.nodes(): + if not isinstance(node, dace_nodes.MapEntry): + continue + if state.scope_dict()[node] is None: + assert ( + outer_map is None + ), f"Found multiple outer maps, first '{outer_map}', second '{node}'." + outer_map = node + assert outer_map is not None, "Could not found the outer map." + assert len(outer_map.map.params) == 2 + + # Now inspect the SDFG if the transformation was applied correctly. + first_level_tasklets: list[dace_nodes.Tasklet] = [] + inner_map: list[dace_nodes.MapEntry] = [] + + for edge in state.out_edges(outer_map): + node: dace_nodes.Node = edge.dst + if isinstance(node, dace_nodes.Tasklet): + first_level_tasklets.append(node) + elif isinstance(node, dace_nodes.MapEntry): + inner_map.append(node) + else: + assert False, f"Found unexpected node '{type(node).__name__}'." + + # Test what we found + assert len(first_level_tasklets) == 2 + assert len(set(first_level_tasklets)) == 2 + assert len(inner_map) == 1 + assert state.scope_dict()[inner_map[0]] is outer_map + + # Now we look inside the inner map + # There we expect to find one Tasklet. + inner_scope = state.scope_subgraph(next(iter(inner_map)), False, False) + assert inner_scope.number_of_nodes() == 1 + inner_tasklet = next(iter(inner_scope.nodes())) + + assert isinstance(inner_tasklet, dace_nodes.Tasklet) + assert inner_tasklet not in first_level_tasklets From 5d875fdce75c1ba1011ee29e10330411b0a81270 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 27 Aug 2024 09:02:14 +0200 Subject: [PATCH 211/235] Further, refactored the KBlock transformation. It is now more separated and should be more logical. However, it still has some large functions. --- .../transformations/auto_opt.py | 11 +- .../transformations/k_blocking.py | 310 +++++++++++------- .../transformation_tests/test_k_blocking.py | 6 +- 3 files changed, 190 insertions(+), 137 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 2f9de2327a..a3b0e96fba 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -89,7 +89,7 @@ def gt_auto_optimize( max_optimization_rounds_p2: int = 100, make_persistent: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, - block_dim: Optional[gtx_common.Dimension] = None, + blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, @@ -143,7 +143,7 @@ def gt_auto_optimize( Thus the SDFG can not be called by different threads. gpu_block_size: The thread block size for maps in GPU mode, currently only one for all. - block_dim: On which dimension blocking should be applied. + blocking_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. @@ -224,11 +224,11 @@ def gt_auto_optimize( ) # Phase 5: Apply blocking - if block_dim is not None: + if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( gtx_transformations.KBlocking( blocking_size=blocking_size, - block_dim=block_dim, + blocking_parameter=blocking_dim, ), validate=validate, validate_all=validate_all, @@ -331,7 +331,6 @@ def _gt_auto_optimize_phase_2( # TODO(phimuell): Should we do this all the time or only once? # TODO(phimuell): Add a criteria to decide if we should promote or not. # TODO(phimuell): Add parallel map promotion? - print(">>>>>>>>>>>>>>>>>>>>>>>> AGGRESSIVE FUSION") phase2_cleanup.append( gtx_transformations.SerialMapPromoter( only_toplevel_maps=True, @@ -340,8 +339,6 @@ def _gt_auto_optimize_phase_2( promote_local=False, ) ) - else: - print("<<<<<<<<<<<<<<<<<<<<<<<<<<<< NO AGGRESSIVE FUSION") # Perform the phase 2 cleanup. sdfg.apply_transformations_once_everywhere( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index a08ae19dec..1e0ea319c7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -27,28 +27,25 @@ class KBlocking(dace_transformation.SingleStateTransformation): """Applies k-Blocking with separation on a Map. This transformation takes a multidimensional Map and performs blocking on a - dimension, that is commonly called "k", but identified with `block_dim`. - - All dimensions except `k` are unaffected by this transformation. In the outer - Map will be replace the `k` range, currently `k = 0:N`, with - `__coarse_k = 0:N:B`, where `N` is the original size of the range and `B` - is the block size, passed as `blocking_size`. The transformation also handles the - case if `N % B != 0`. - The transformation will then create an inner sequential map with - `k = __coarse_k:(__coarse_k + B)`. - - However, before the split the transformation examines all adjacent nodes of - the original Map. If a node does not depend on `k`, then the node will be - put between the two maps, thus its content will only be computed once. - - The function will also change the name of the outer map, it will append - `_blocked` to it. + single dimension, that is commonly called "k". All dimensions except `k` are + unaffected by this transformation. In the outer Map will be replace the `k` + range, currently `k = 0:N`, with `__coarse_k = 0:N:B`, where `N` is the + original size of the range and `B` is the blocking size. The transformation + will then create an inner sequential map with `k = __coarse_k:(__coarse_k + B)`. + + What makes this transformation different from simple blocking, is that + the inner map will not just be inserted right after the outer Map. + Instead the transformation will first identify all nodes that does not depend + on the blocking parameter and relocate them between the outer and inner map. + Thus these operations will only be performed once, per inner loop. + + Args: + blocking_size: The size of the block, denoted as `B` above. + blocking_parameter: On which parameter should we block. Todo: - - Allow that independent nodes does not need to have direct connection - to the map, instead it is enough that they only have connections to - independent nodes. - - Investigate if the iteration should always start at zero. + - Modify the inner map such that it always starts at zero. + - Allow more than one parameter on which we block. """ blocking_size = dace_properties.Property( @@ -56,30 +53,30 @@ class KBlocking(dace_transformation.SingleStateTransformation): allow_none=True, desc="Size of the inner k Block.", ) - block_dim = dace_properties.Property( + blocking_parameter = dace_properties.Property( dtype=str, allow_none=True, - desc="Which dimension should be blocked (must be an exact match).", + desc="Name of the iteration variable on which to block (must be an exact match).", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, blocking_size: Optional[int] = None, - block_dim: Optional[Union[gtx_common.Dimension, str]] = None, + blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, ) -> None: super().__init__() - if isinstance(block_dim, gtx_common.Dimension): - block_dim = gtx_dace_fieldview_util.get_map_variable(block_dim) - if block_dim is not None: - self.block_dim = block_dim + if isinstance(blocking_parameter, gtx_common.Dimension): + blocking_parameter = gtx_dace_fieldview_util.get_map_variable(blocking_parameter) + if blocking_parameter is not None: + self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size @classmethod def expressions(cls) -> Any: - return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + return [dace.sdfg.utils.node_path_graph(cls.outer_entry)] def can_be_applied( self, @@ -92,31 +89,28 @@ def can_be_applied( The test involves: - The map must be at the top level. - - The map shall not be serial. - The block dimension must be present (exact string match). - The map range must have step size of 1. - The partition must exists (see `partition_map_output()`). """ - if self.block_dim is None: + if self.blocking_parameter is None: raise ValueError("The blocking dimension was not specified.") elif self.blocking_size is None: raise ValueError("The blocking size was not specified.") - map_entry: dace_nodes.MapEntry = self.map_entry - map_params: list[str] = map_entry.map.params - map_range: dace_subsets.Range = map_entry.map.range - block_var: str = self.block_dim + outer_entry: dace_nodes.MapEntry = self.outer_entry + map_params: list[str] = outer_entry.map.params + map_range: dace_subsets.Range = outer_entry.map.range + block_var: str = self.blocking_parameter scope = graph.scope_dict() - if scope[map_entry] is not None: + if scope[outer_entry] is not None: return False - if block_var not in map_entry.map.params: - return False - if map_entry.map.schedule == dace.dtypes.ScheduleType.Sequential: + if block_var not in outer_entry.map.params: return False if map_range[map_params.index(block_var)][2] != 1: return False - if self.partition_map_output(map_entry, block_var, graph, sdfg) is None: + if self.partition_map_output(graph, sdfg) is None: return False return True @@ -130,49 +124,53 @@ def apply( Performs the operation described in the doc string. """ - outer_entry: dace_nodes.MapEntry = self.map_entry - outer_exit: dace_nodes.MapExit = graph.exit_node(outer_entry) - outer_map: dace_nodes.Map = outer_entry.map - map_range: dace_subsets.Range = outer_entry.map.range - map_params: list[str] = outer_entry.map.params - - # This is the name of the iterator we coarsen - block_var: str = self.block_dim - block_idx = map_params.index(block_var) - - # This is the name of the iterator that we use in the outer map for the - # blocked dimension - coarse_block_var = "__coarse_" + block_var # Now compute the partitions of the nodes. - independent_nodes, dependent_nodes = self.partition_map_output( # type: ignore[misc] # Guaranteed to be not `None`. - outer_entry, block_var, graph, sdfg + independent_nodes, dependent_nodes = self.partition_map_output(graph, sdfg) # type: ignore[misc] # Guaranteed to be not `None`. + + # Modify the outer map and create the inner map. + (outer_entry, outer_exit), (inner_entry, inner_exit) = self._prepare_inner_outer_maps(graph) + + # Reconnect the edges + self._rewire_map_scope( + outer_entry=outer_entry, + outer_exit=outer_exit, + inner_entry=inner_entry, + inner_exit=inner_exit, + independent_nodes=independent_nodes, + dependent_nodes=dependent_nodes, + state=graph, + sdfg=sdfg, ) - # Generate the sequential inner map - rng_start = map_range[block_idx][0] - rng_stop = map_range[block_idx][1] - inner_label = f"inner_{outer_map.label}" - inner_range = { - block_var: dace_subsets.Range.from_string( - f"({coarse_block_var} * {self.blocking_size} + {rng_start}):min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" - ) - } - inner_entry, inner_exit = graph.add_map( - name=inner_label, - ndrange=inner_range, - schedule=dace.dtypes.ScheduleType.Sequential, - ) + @staticmethod + def _rewire_map_scope( + outer_entry: dace_nodes.MapEntry, + outer_exit: dace_nodes.MapExit, + inner_entry: dace_nodes.MapEntry, + inner_exit: dace_nodes.MapExit, + independent_nodes: set[dace_nodes.Node], + dependent_nodes: set[dace_nodes.Node], + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> None: + """Rewire the edges inside the scope defined by the outer map. - # TODO(phimuell): Investigate if we want to prevent unrolling here + The function assumes that the outer and inner map were obtained by a call + to `_prepare_inner_outer_maps()`. The function will now rewire the connections of these + nodes such that the dependent nodes are inside the scope of the inner map, + while the independent nodes remain outside. - # Now we modify the properties of the outer map. - coarse_block_range = dace_subsets.Range.from_string( - f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" - ).ranges[0] - outer_map.params[block_idx] = coarse_block_var - outer_map.range[block_idx] = coarse_block_range - outer_map.label = f"{outer_map.label}_blocked" + Args: + outer_entry: The entry node of the outer map. + outer_exit: The exit node of the outer map. + inner_entry: The entry node of the inner map. + inner_exit: The exit node of the inner map. + independent_nodes: The set of independent nodes. + dependent_nodes: The set of dependent nodes. + state: The state of the map. + sdfg: The SDFG we operate on. + """ # Contains the nodes that are already have been handled. relocated_nodes: set[dace_nodes.Node] = set() @@ -181,7 +179,7 @@ def apply( # _output_ edges have to go through the new inner map and the Memlets need # modifications, because of the block parameter. for independent_node in independent_nodes: - for out_edge in graph.out_edges(independent_node): + for out_edge in state.out_edges(independent_node): edge_dst: dace_nodes.Node = out_edge.dst relocated_nodes.add(edge_dst) @@ -197,13 +195,13 @@ def apply( # new inner map entry iterate over the blocked variable. new_map_conn = inner_entry.next_connector() dace_helpers.redirect_edge( - state=graph, + state=state, edge=out_edge, new_dst=inner_entry, new_dst_conn="IN_" + new_map_conn, ) # TODO(phimuell): Check if there might be a subset error. - graph.add_edge( + state.add_edge( inner_entry, "OUT_" + new_map_conn, out_edge.dst, @@ -216,7 +214,7 @@ def apply( # Now we handle the dependent nodes, they differ from the independent nodes # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. for dependent_node in dependent_nodes: - for in_edge in graph.in_edges(dependent_node): + for in_edge in state.in_edges(dependent_node): edge_src: dace_nodes.Node = in_edge.src # Since the independent nodes were already processed, and they process @@ -240,23 +238,21 @@ def apply( assert ( edge_src is outer_entry ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." - dace_helpers.redirect_edge(state=graph, edge=in_edge, new_src=inner_entry) + dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) continue # Because of the definition of a dependent node and the processing # order, their incoming edges either point to the outer map or # are already handled. - if edge_src is not outer_entry: - sdfg.view() assert ( edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}' || {dependent_node} // {edge_src in independent_nodes}." + ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." edge_conn: str = in_edge.src_conn[4:] # Must be before the handling of the modification below # Note that this will remove the original edge from the SDFG. dace_helpers.redirect_edge( - state=graph, + state=state, edge=in_edge, new_src=inner_entry, new_src_conn="OUT_" + edge_conn, @@ -267,7 +263,7 @@ def apply( # We have found this edge multiple times already. # To ensure that there is no error, we will create a new # Memlet that reads the whole array. - piping_edge = next(graph.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) + piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) data_name = piping_edge.data.data piping_edge.data = dace.Memlet.from_array( data_name, sdfg.arrays[data_name], piping_edge.data.wcr @@ -276,7 +272,7 @@ def apply( else: # This is the first time we found this connection. # so we just create the edge. - graph.add_edge( + state.add_edge( outer_entry, "OUT_" + edge_conn, inner_entry, @@ -288,8 +284,8 @@ def apply( # In certain cases it might happen that we need to create an empty # Memlet between the outer map entry and the inner one. - if graph.in_degree(inner_entry) == 0: - graph.add_edge( + if state.in_degree(inner_entry) == 0: + state.add_edge( outer_entry, None, inner_entry, @@ -300,15 +296,15 @@ def apply( # Handle the Map exits # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. - for in_edge in graph.in_edges(outer_exit): + for in_edge in state.in_edges(outer_exit): edge_conn = in_edge.dst_conn[3:] dace_helpers.redirect_edge( - state=graph, + state=state, edge=in_edge, new_dst=inner_exit, new_dst_conn="IN_" + edge_conn, ) - graph.add_edge( + state.add_edge( inner_exit, "OUT_" + edge_conn, outer_exit, @@ -319,12 +315,73 @@ def apply( inner_exit.add_out_connector("OUT_" + edge_conn) # TODO(phimuell): Use a less expensive method. - dace.sdfg.propagation.propagate_memlets_state(sdfg, graph) + dace.sdfg.propagation.propagate_memlets_state(sdfg, state) + + def _prepare_inner_outer_maps( + self, + state: dace.SDFGState, + ) -> tuple[ + tuple[dace_nodes.MapEntry, dace_nodes.MapExit], + tuple[dace_nodes.MapEntry, dace_nodes.MapExit], + ]: + """Prepare the maps for the blocking. + + The function modifies the outer map, `self.outer_entry`, by replacing the + blocking parameter, `self.blocking_parameter`, with a coarsened version + of it. In addition the function will then create the inner map, that + iterates over the blocking parameter, and these bounds are determined + by the coarsened blocking parameter of the outer map. + + Args: + state: The state on which we operate. + + Return: + The function returns a tuple of length two, the first element is the + modified outer map and the second element is the newly created + inner map. Each element consist of a pair containing the map entry + and map exit nodes of the corresponding maps. + """ + outer_entry: dace_nodes.MapEntry = self.outer_entry + outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) + outer_map: dace_nodes.Map = outer_entry.map + outer_range: dace_subsets.Range = outer_entry.map.range + outer_params: list[str] = outer_entry.map.params + blocking_parameter_dim = outer_params.index(self.blocking_parameter) + + # This is the name of the iterator that we use in the outer map for the + # blocked dimension + coarse_block_var = "__coarse_" + self.blocking_parameter + + # Generate the sequential inner map + rng_start = outer_range[blocking_parameter_dim][0] + rng_stop = outer_range[blocking_parameter_dim][1] + inner_label = f"inner_{outer_map.label}" + inner_range = { + self.blocking_parameter: dace_subsets.Range.from_string( + f"({coarse_block_var} * {self.blocking_size} + {rng_start})" + + ":" + + f"min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" + ) + } + inner_entry, inner_exit = state.add_map( + name=inner_label, + ndrange=inner_range, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + # TODO(phimuell): Investigate if we want to prevent unrolling here + + # Now we modify the properties of the outer map. + coarse_block_range = dace_subsets.Range.from_string( + f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" + ).ranges[0] + outer_map.params[blocking_parameter_dim] = coarse_block_var + outer_map.range[blocking_parameter_dim] = coarse_block_range + + return ((outer_entry, outer_exit), (inner_entry, inner_exit)) def partition_map_output( self, - map_entry: dace_nodes.MapEntry, - block_param: str, state: dace.SDFGState, sdfg: dace.SDFG, ) -> tuple[set[dace_nodes.Node], set[dace_nodes.Node]] | None: @@ -334,22 +391,22 @@ def partition_map_output( - The independent nodes `\mathcal{I}`: These are the nodes, whose output does not depend on the blocked dimension. These nodes can be relocated between the outer and inner map. - Nodes in these set does not necessarily have a direct edge to `map_entry`. - However, the exists a path from `map_entry` to any node in this set. + Nodes in these set does not necessarily have a direct edge to map entry. + However, the exists a path from `outer_entry` to any node in this set. - The dependent nodes `\mathcal{D}`: These are the nodes, whose output depend on the blocked dimension. Thus they can not be relocated between the two maps, but will remain - inside the inner scope. All these nodes have at least one edge to `map_entry`. + inside the inner scope. All these nodes have at least one edge to map entry. The function uses an iterative method to compute the set of independent nodes. In each iteration the function will look classify all nodes that have an - incoming edge originating either at `map_entry` or from a node that was + incoming edge originating either at outer_entry or from a node that was already classified as independent. This is repeated until no new independent nodes are found. This means that independent nodes does not necessarily have - a direct connection to `map_entry`. + a direct connection to map entry. - The dependent nodes on the other side always have a direct edge to `map_entry`. - As they are the set of nodes that are adjacent to `map_entry` but are not + The dependent nodes on the other side always have a direct edge to outer_entry. + As they are the set of nodes that are adjacent to outer_entry but are not independent. For the sake of arguments, all nodes, except the map entry and exit nodes, @@ -359,8 +416,6 @@ def partition_map_output( In case the function fails to compute the partition `None` is returned. Args: - map_entry: The map entry node. - block_param: The Map variable that should be blocked. state: The state on which we operate. sdfg: The SDFG in which we operate on. @@ -369,14 +424,16 @@ def partition_map_output( - Currently this function only considers the input Memlets and the `used_symbol` properties of a Tasklet to determine if a Tasklet is dependent. """ + outer_entry: dace_nodes.MapEntry = self.outer_entry + blocking_parameter: str = self.blocking_parameter independent_nodes: set[dace_nodes.Node] = set() # `\mathcal{I}` while True: # Find all the nodes that we have to classify in this iteration. - # - All nodes adjacent to `map_entry` + # - All nodes adjacent to `outer_entry` # - All nodes adjacent to independent nodes. nodes_to_classify: set[dace_nodes.Node] = { - edge.dst for edge in state.out_edges(map_entry) + edge.dst for edge in state.out_edges(outer_entry) } for independent_node in independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) @@ -387,8 +444,8 @@ def partition_map_output( for node_to_classify in nodes_to_classify: class_res = self.classify_node( node_to_classify=node_to_classify, - map_entry=map_entry, - block_param=block_param, + outer_entry=outer_entry, + blocking_parameter=blocking_parameter, independent_nodes=independent_nodes, state=state, sdfg=sdfg, @@ -405,18 +462,18 @@ def partition_map_output( break # After the independent set is computed compute the set of dependent nodes - # as the set of all nodes adjacent to `map_entry` that are not dependent. + # as the set of all nodes adjacent to `outer_entry` that are not dependent. dependent_nodes: set[dace_nodes.Node] = { - edge.dst for edge in state.out_edges(map_entry) if edge.dst not in independent_nodes + edge.dst for edge in state.out_edges(outer_entry) if edge.dst not in independent_nodes } return (independent_nodes, dependent_nodes) + @staticmethod def classify_node( - self, node_to_classify: dace_nodes.Node, - map_entry: dace_nodes.MapEntry, - block_param: str, + outer_entry: dace_nodes.MapEntry, + blocking_parameter: str, independent_nodes: set[dace_nodes.Node], state: dace.SDFGState, sdfg: dace.SDFG, @@ -426,11 +483,11 @@ def classify_node( The general rule to classify if a node is independent are: - The node must be a Tasklet or an AccessNode, in all other cases the partition does not exist. - - `free_symbols` of the nodes shall not contain the `block_param`. + - `free_symbols` of the nodes shall not contain the `blocking_parameter`. - All incoming _empty_ edges must be connected to the map entry. - A node either has only empty Memlets or none of them. - - Incoming Memlets does not depend on the `block_param`. - - All incoming edges must start either at `map_entry` or at dependent nodes. + - Incoming Memlets does not depend on the `blocking_parameter`. + - All incoming edges must start either at `outer_entry` or at dependent nodes. - All output Memlets are non empty. Returns: @@ -442,11 +499,10 @@ def classify_node( Args: node_to_classify: The node that should be classified. - map_entry: The entry of the map that should be partitioned. - block_param: The iteration parameter that should be blocked. + outer_entry: The entry of the map that should be partitioned. + blocking_parameter: The iteration parameter that should be blocked. independent_nodes: The set of nodes that was already classified as - independent, in case the node is classified as independent the set is - updated. + independent, in which case it is added to `independent_nodes`. state: The state containing the map. sdfg: The SDFG that is processed. """ @@ -488,7 +544,7 @@ def classify_node( for in_edge in in_edges: if in_edge.data.is_empty(): found_empty_edges = True - if in_edge.src is not map_entry: + if in_edge.src is not outer_entry: # TODO(phimuell): Lift this restriction. return None else: @@ -499,7 +555,7 @@ def classify_node( return None assert ( found_empty_edges or found_nonempty_edges - ), f"Node '{node_to_classify}' inside '{map_entry}' without an input connection." + ), f"Node '{node_to_classify}' inside '{outer_entry}' without an input connection." # Requiring that all output Memlets are non empty implies, because we are # inside a scope, that there exists an output. @@ -512,7 +568,7 @@ def classify_node( # Test if the body of the Tasklet depends on the block variable. if ( isinstance(node_to_classify, dace_nodes.Tasklet) - and block_param in node_to_classify.free_symbols + and blocking_parameter in node_to_classify.free_symbols ): return False @@ -537,12 +593,12 @@ def classify_node( # If a subset needs the block variable then the node is not independent # but dependent. - if any(block_param in subset.free_symbols for subset in subsets_to_inspect): + if any(blocking_parameter in subset.free_symbols for subset in subsets_to_inspect): return False - # The edge must either originate from `map_entry` or from an independent + # The edge must either originate from `outer_entry` or from an independent # node if not it is dependent. - if not (in_edge.src is map_entry or in_edge.src in independent_nodes): + if not (in_edge.src is outer_entry or in_edge.src in independent_nodes): return False # Loop ended normally, thus we classify the node as independent. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py index d2e064386d..b1b7b812d6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py @@ -147,7 +147,7 @@ def test_only_dependent(): # Apply the transformation sdfg.apply_transformations_repeated( - gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + gtx_transformations.KBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) @@ -211,7 +211,7 @@ def test_intermediate_access_node(): # Apply the transformation. sdfg.apply_transformations_repeated( - gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + gtx_transformations.KBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) @@ -249,7 +249,7 @@ def test_chained_access() -> None: # Apply the transformation. ret = sdfg.apply_transformations_repeated( - gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + gtx_transformations.KBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) From 57ae4eaa6defd513280d341807b0b5601be2683c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 27 Aug 2024 11:46:29 +0200 Subject: [PATCH 212/235] Added an ADRF for the DaCe parts of the toolchain. --- ...Canonical_SDFG_in_GT4Py_Transformations.md | 130 ++++++++++++++++++ docs/development/ADRs/Index.md | 3 +- 2 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md new file mode 100644 index 0000000000..6b321a452e --- /dev/null +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -0,0 +1,130 @@ +--- +tags: [backend, dace, optimization] +--- + +# Canonical Form of an SDFG in GT4Py (Especially for Optimizations) + +- **Status**: valid +- **Authors**: Philip Müller (@philip-paul-mueller) +- **Created**: 2024-08-27 + +In the context of the implementation of the new DaCe fieldview we decided about a particular form of the SDFG. +Their main intent is to reduce the complexity of the GT4Py specific transformations. + +## Context + +The canonical form that is outlined in this document was mainly designed from the perspective of the optimization pipeline. +Thus it emphasizes a form that is can be handled in a simple and efficient way by a transformation. +In the pipeline we distinguishes between + +- Intrastate optimization: The optimization of the data flow within states. +- Interstate optimization: The optimization between states, this are transformations that are _intended_ to _reduce_ the number of states. + +The current (GT4Py) pipeline mainly focus on intrastate optimization and relays on DaCe, especially its simplify pass, for interstate optimizations. + +## Decision + +The canonical form is defined by several rules that affect different aspects of an SDFG and what a transformation can assume. +This allows to simplify the implementation of certain transformations. + +#### General Aspects + +The following rules, especially affects transformations and how they operate: + +1. Intrastate transformation and interstate transformations must run separately and can not be mixed in the same (DaCe) pipeline. + + - [Rational]: As a consequence the number of "interstate transients" (transients that are used in multiple states) remains constant during intrastate transformations. + - [Note 1]: It is allowed to run them after one another, as long as they are strictly separated. + - [Note 2]: It is allowed that _intrastate_ transformation act in a way to allow state fusion by later intrastate transformations. + - [Note 3]: The DaCe simplification pass violates this rule, for that reason this pass must always be called on its own, see also rule 2. + +2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed, the only valid way to call simplify is to call the `gt_simplify()` function provided by GT4Py. + - [Rational]: It was observed that some sub passes in simplify have a negative impact and that additional passes might be needed in the future. + By only using a single function later modifications to simplify are easy. + - [Note]: One issue is that the remove redundant array transformation is not able to handle all cases. + +#### Global Memory + +The only restriction we impose on global memory is: + +3. The same global memory is allowed to be used as input and output at the same time, iff the output depends _elementwise_ on the input. + - [Rational 1]: Allows to remove double buffering, that DaCe may not remove, see also rule 2. + - [Rational 2]: This formulation allows to write expressions such as `a += 1`, with only memory for `a`. + Phrased more technically using global memory for input and output is allowed iff the two computations `tmp = computation(global_memory); global_memory = tmp;` and `global_memory = computation(global_memory);` are equivalent. + - [Note]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both. + +#### State Machine + +For the SDFG state machine we assume that: + +4. An interstate edge can only access scalars, i.e. use them in their assignment or condition expressions, but not arrays, even if they have shape `(1,)`. + + - [Rational]: If an array is also used in interstate edges it became very tedious to verify if the array could be removed or not. + - [Note]: Running simplify might actually result in the violation of this rule, see note of rule 9. + +5. The state graph does not contain any cycles, i.e. the implementation of a for/while loop using states is not allowed, the new loop construct or serial maps must be used in that case. + - [Rational]: This is a simplification that makes it much simpler to define "later in the computation" means as we will never have a cycle. + - [Note]: Currently the code generator does not support the `LoopRegion` construct and it is transformed to a state machine. + +#### Transients + +The rules we impose on transients are a bit more complicated, however, while sounding restrictive, they are very permissive. +It is important that these rules only have to be met after after simplify was called once on the SDFG: + +6. Downstream of a write access, i.e. in all states that follows the state the access node is located in, there are no other access nodes that are used to write to the same array. + + - [Rational 1]: This rule together with rule 7 and 8 essentially boils down to ensure that the assignment in the SDFG follows SSA style, while allowing for expressions such as: + + ```python + if cond: + a = true_branch() + else: + a = false_branch() + ``` + + (**NOTE:** This could also be done with references, however, they are strongly discouraged.) + + - [Rational 2]: This still allows reductions with WCR as they write to the same access node and loops, whose body modifies a transient that outlives the loop body, as they use the same access node. + +7. It is _recommended_ that a write access node should only have one incoming edge. + + - [Rational]: This case is handled poorly by some DaCe transformations, thus we should avoid them as much as possible. + +8. No two access nodes in a state can refer to the same array. + + - [Rational]: Together with rule 5 this guarantees SSA style. + - [Note]: An SDFG can still be constructed using different access node for the same underlying data; simplify will combine them. + +9. Every access node that reads from an array (having an outgoing edge) that was not written to in the same state must be a source node. + + - [Rational]: Together with rule 1, 4, 5, 6, 7 and 8 this simplifies the check if a transient can be safely removed or if it is used somewhere else. + These rules guarantee that the number of "interstate transients" remains constant and these set is given by the _set of source nodes and all access nodes that have an outgoing degree larger than one_. + - [Note]: To prevent some issues caused by the violation of rule 4 by simplify, this set is extended with the transient sink nodes and all scalars. + Excess interstate transients, that will be kept alive that way, will be removed by later calls to simplify. + +10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should be _preferable_ `dace.dtypes.StorageType.Register`. + - [Rational 1]: Makes optimizations operating inside a maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside. + - [Rational 2]: The storage type avoids the need to dynamically allocate memory inside a kernel. + +#### Maps + +For maps we assume the following: + +11. The names of map variables (iteration variable) follow the following pattern. + + - 11.1: All map variables iterating over the same dimension (disregarding the actual range), have the same deterministic name, that includes the `gtx.Dimension.value` string. + - 11.2: The name of horizontal dimensions (`kind` attribute) always end in `__gtx_horizontal`. + - 11.3: The name of vertical dimensions (`kind` attribute) always end in `__gtx_vertical`. + - 11.4: The name of local dimensions always ends in `__gtx_localdim`. + - 11.5: No transformation is allowed to modify the name of an iteration variable that follows rules 11.2, 11.3 or 11.4. + - [Rational]: Without this rule it is very hard to tell which map variable does what, this way we can transmit information from GT4Py to DaCe, see also rule 12. + +12. Two map ranges, i.e. the pair map/iteration variable and range, can only be fused if they have the same name _and_ cover the same range. + - [Rational 1]: Because of rule 11 we will only fuse maps that actually makes sense to fuse. + - [Rational 2]: This allows to fuse maps without performing a renaming on the map variables. + - [Note]: This rule might be dropped in the future. + +## Consequences + +The rules outlined above impose a certain form of an SDFG. +Most of these rules are designed to ensure that the SDFG follows SSA style and to simplify transformations, especially making validation checks simple, while imposing a minimal number of restrictions. diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/Index.md index 072e6dc2ea..ccf2e744f4 100644 --- a/docs/development/ADRs/Index.md +++ b/docs/development/ADRs/Index.md @@ -38,7 +38,7 @@ _None_ ### Transformations -_None_ +- [0018 - Canonical Form of an SDFG in GT4Py (Especially for Optimizations)](0018-Canonical_SDFG_in_GT4Py_Transformations.md) ### Backends and Code Generation @@ -47,6 +47,7 @@ _None_ - [0008 - Mapping Domain to Cpp Backend](0008-Mapping_Domain_to_Cpp-Backend.md) - [0016 - Multiple Backends and Build Systems](0016-Multiple-Backends-and-Build-Systems.md) - [0017 - Toolchain Configuration](0017-Toolchain-Configuration.md) +- [0018 - Canonical Form of an SDFG in GT4Py (Especially for Optimizations)](0018-Canonical_SDFG_in_GT4Py_Transformations.md) ### Python Integration From 4d2e941336204cd41a0f7cfe7ba0032692cd4b9f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 27 Aug 2024 14:11:28 +0200 Subject: [PATCH 213/235] Made a note in to the map fusion files that we will delete them as soon as the transformations that ship with DaCe are updated. --- .../dace_fieldview/transformations/map_fusion_helper.py | 7 ++++++- .../dace_fieldview/transformations/map_serial_fusion.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index ced6f41ff2..632b3efe84 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -6,7 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Implements Helper functionaliyies for map fusion""" +"""Implements helper functions for the map fusion transformations. + +Note: + After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements + a better map fusion transformation is merged, this file will be deleted. +""" import functools import itertools diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index a07b613ab1..bca5aa2268 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -6,7 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Implements the serial map fusing transformation.""" +"""Implements the serial map fusing transformation. + +Note: + After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements + a better map fusion transformation is merged, this file will be deleted. +""" import copy from typing import Any, Union From 243bc8edafbb5a6c5b4cefb4958fa593db2d4e96 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 27 Aug 2024 14:12:26 +0200 Subject: [PATCH 214/235] Removed all reference to the HackMD file and changed them with references to the ADR. Only the map fusion contains such references as we will delete anyway. --- .../runners/dace_fieldview/transformations/__init__.py | 4 ++-- .../runners/dace_fieldview/transformations/auto_opt.py | 5 +++++ .../transformations/map_fusion_helper.py | 10 +++++----- .../dace_fieldview/transformations/map_orderer.py | 4 +++- .../dace_fieldview/transformations/map_promoter.py | 5 +++-- 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 68e69f5e5d..6106d0c4b2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -8,8 +8,8 @@ """Transformation and optimization pipeline for the DaCe backend in GT4Py. -Please also see [this HackMD document](https://hackmd.io/@gridtools/rklwk4OIR#Requirements-on-SDFG) -that explains the general structure and requirements on the SDFG. +Please also see [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) +that explains the general structure and requirements on the SDFGs. """ from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index a3b0e96fba..e773ccedeb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -100,6 +100,11 @@ def gt_auto_optimize( ) -> dace.SDFG: """Performs GT4Py specific optimizations on the SDFG in place. + This function expects that the input SDFG follows the principles that are + outlined in [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md). + It is important to notice, that if `reuse_transients` is active then the + optimized SDFG no longer conforms to these rules. + The auto optimization works in different phases, that focuses each on different aspects of the SDFG. The initial SDFG is assumed to have a very large number of rather simple Maps. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 632b3efe84..4b8231eaa6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -34,7 +34,8 @@ class MapFusionHelper(dace_transformation.SingleStateTransformation): """Contains common part of the fusion for parallel and serial Map fusion. - The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + The transformation assumes that the SDFG obeys the principals outlined in + [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md). The main advantage of this structure is, that it is rather easy to determine if a transient is used anywhere else. This check, performed by `is_interstate_transient()`. It is further speeded up by cashing some computation, @@ -294,12 +295,11 @@ def is_interstate_transient( transient: The transient that should be checked. sdfg: The SDFG containing the array. state: If given the state the node is located in. - - Note: - This function build upon the structure of the SDFG that is outlined - in the HackMD document. """ + # The following builds upon the HACK MD document and not on ADR0018. + # Therefore the numbers are slightly different, but both documents + # essentially describes the same SDFG. # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) # the set of such transients is partially given by all source access dace_nodes. # Because of rule 3 we also include all scalars in this set, as an over diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index ee6b2c43fa..4b34dd6adc 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -34,7 +34,9 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): is supposed to have stride 1. Note: - The transformation does follow the rules outlines [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + The transformation does follow the rules outlines in + [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) + especially rule 11, regarding the names. Todo: - Extend that different dimensions can be specified to be leading diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 885ba4e09a..fc7f06f630 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -33,8 +33,9 @@ class BaseMapPromoter(dace_transformation.SingleStateTransformation): In order to properly work, the parameters of "source map" must be a strict superset of the ones of "map to promote". Furthermore, this transformation - builds upon the structure defined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). - Thus it only checks the name of the parameters. + builds upon the structure defined [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) + especially rule 11, regarding the names. Thus it only checks the name of + the parameters to determine if it should perform the promotion or not. To influence what to promote the user must implement the `map_to_promote()` and `source_map()` function. They have to return the map entry node. From a74a54dda5ab029b12789c7bbedfa2d38c0e8ab7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 27 Aug 2024 16:24:04 +0200 Subject: [PATCH 215/235] Updated the map promotion. --- .../transformations/__init__.py | 7 +- .../transformations/gpu_utils.py | 100 ++++++++++------- .../transformations/map_fusion_helper.py | 2 + .../transformations/map_promoter.py | 106 ++++++++++++++---- 4 files changed, 150 insertions(+), 65 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 6106d0c4b2..b698781b1a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -13,12 +13,7 @@ """ from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify -from .gpu_utils import ( - GPUSetBlockSize, - SerialMapPromoterGPU, - gt_gpu_transformation, - gt_set_gpu_blocksize, -) +from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize from .k_blocking import KBlocking from .map_orderer import MapIterationOrder from .map_promoter import SerialMapPromoter diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index b44c545499..667ae3ccf0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -94,7 +94,7 @@ def gt_gpu_transformation( # fuse everything together it can find. # TODO(phimuell): Fix the issue described above. sdfg.apply_transformations_once_everywhere( - gtx_transformations.SerialMapPromoterGPU(), + TrivialGPUMapPromoter(), validate=False, validate_all=False, ) @@ -286,8 +286,8 @@ def apply( @dace_properties.make_properties -class SerialMapPromoterGPU(dace_transformation.SingleStateTransformation): - """Serial Map promoter for empty Maps in case of trivial Maps. +class TrivialGPUMapPromoter(dace_transformation.SingleStateTransformation): + """Serial Map promoter for empty GPU maps. In CPU mode a Tasklet can be outside of a map, however, this is not possible in GPU mode. For this reason DaCe wraps such Tasklets in a @@ -296,22 +296,27 @@ class SerialMapPromoterGPU(dace_transformation.SingleStateTransformation): that they can be fused with downstream maps. Note: - This transformation must be run after the GPU Transformation. - - Todo: + - This transformation should not be run on its own, instead it + is run within the context of `gt_gpu_transformation()`. + - This transformation must be run after the GPU Transformation. + - Currently the transformation does not do the fusion on its own. + Instead map fusion must be run afterwards. - The transformation assumes that the upper Map is a trivial Tasklet. Which should be the majority of all cases. - - Combine this transformation such that it can do serial fusion on its own. """ # Pattern Matching - map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + trivial_map_exit = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + second_map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + return [ + dace.sdfg.utils.node_path_graph( + cls.trivial_map_exit, cls.access_node, cls.second_map_entry + ) + ] def can_be_applied( self, @@ -322,37 +327,54 @@ def can_be_applied( ) -> bool: """Tests if the promotion is possible. - The function tests: - - If the top map is a trivial map. - - If a valid partition exists that can be fused at all. + The tests includes: + - Schedule of the maps. + - If the map is trivial. + - If the trivial map was not used to define a symbol. + - Intermediate access node can only have in and out degree of 1. + - The trivial map exit can only have one output. """ - map_exit_1: dace_nodes.MapExit = self.map_exit1 - map_1: dace_nodes.Map = map_exit_1.map - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - - # Check if the first map is trivial. - if len(map_1.params) != 1: - return False - if map_1.range.num_elements() != 1: + trivial_map_exit: dace_nodes.MapExit = self.trivial_map(state=graph, sdfg=sdfg) + trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) + trivial_map: dace_nodes.Map = trivial_map_entry.map + second_map: dace_nodes.Map = self.second_map_entry.map + access_node: dace_nodes.AccessNode = self.access_node + + # The kind of maps we are interested only have one parameter. + if len(trivial_map.params) != 1: return False # Check if it is a GPU map - if map_1.schedule not in [ - dace.dtypes.ScheduleType.GPU_Device, - dace.dtypes.ScheduleType.GPU_Default, - ]: + for map_to_check in [trivial_map, second_map]: + if map_to_check.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + + # Check if the map is trivial. + for rng in trivial_map.range.ranges: + if rng[0] != rng[1]: + return False + + # Now we have to ensure that the symbol is not used inside the scope of the + # map, if it is, then the symbol is just there to define a symbol. + scope_view = graph.scope_subgraph( + trivial_map_entry, + include_entry=False, + include_exit=False, + ) + if any(map_param in scope_view.free_symbols for map_param in trivial_map.params): return False - # Check if the partition exists, if not promotion to fusing is pointless. - # TODO(phimuell): Find the proper way of doing it. - serial_fuser = gtx_transformations.SerialMapFusion(only_toplevel_maps=True) - output_partition = serial_fuser.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - if output_partition is None: + # ensuring that the trivial map exit and the intermediate node have degree + # one is a cheap way to ensure that the map can be merged into the + # second map. + if graph.in_degree(access_node) != 1: + return False + if graph.out_degree(access_node) != 1: + return False + if graph.out_degree(trivial_map_exit) != 1: return False return True @@ -363,8 +385,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non The function essentially copies the parameters and the ranges from the bottom map to the top one. """ - map_1: dace_nodes.Map = self.map_exit1.map - map_2: dace_nodes.Map = self.map_entry2.map + trivial_map: dace_nodes.Map = self.trivial_map_exit.map + second_map: dace_nodes.Map = self.second_map_entry.map - map_1.params = copy.deepcopy(map_2.params) - map_1.range = copy.deepcopy(map_2.range) + trivial_map.params = copy.deepcopy(second_map.params) + trivial_map.range = copy.deepcopy(second_map.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index 4b8231eaa6..ec33e7ea63 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -133,6 +133,8 @@ def can_be_fused( return False # We will now check if there exists a "remapping" that we can use. + # NOTE: The serial map promoter depends on the fact that this is the + # last check. if not self.map_parameter_compatible( map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg ): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index fc7f06f630..f687e08cb6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy from typing import Any, Mapping, Optional, Sequence, Union import dace @@ -16,6 +17,10 @@ ) from dace.sdfg import nodes as dace_nodes +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + @dace_properties.make_properties class BaseMapPromoter(dace_transformation.SingleStateTransformation): @@ -34,8 +39,10 @@ class BaseMapPromoter(dace_transformation.SingleStateTransformation): In order to properly work, the parameters of "source map" must be a strict superset of the ones of "map to promote". Furthermore, this transformation builds upon the structure defined [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) - especially rule 11, regarding the names. Thus it only checks the name of - the parameters to determine if it should perform the promotion or not. + especially rule 11, regarding the names. Thus the function uses the names + of the map parameters to determine what a local, horizontal, vertical or + unknown dimension is. It also uses rule 12, therefore it will not perform + a renaming and iteration variables must be a match. To influence what to promote the user must implement the `map_to_promote()` and `source_map()` function. They have to return the map entry node. @@ -43,9 +50,12 @@ class BaseMapPromoter(dace_transformation.SingleStateTransformation): Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. - promote_vertical: If `True` promote vertical dimensions; `True` by default. - promote_local: If `True` promote local dimensions; `True` by default. - promote_horizontal: If `True` promote horizontal dimensions; `False` by default. + promote_vertical: If `True` promote vertical dimensions, i.e. add them + to the map to promote; `True` by default. + promote_local: If `True` promote local dimensions, i.e. add them to the + map to promote; `True` by default. + promote_horizontal: If `True` promote horizontal dimensions, i.e. add + them to the map to promote; `False` by default. promote_all: Do not impose any restriction on what to promote. The only reasonable value is `True` or `None`. @@ -173,13 +183,13 @@ def can_be_applied( # to ensure that the symbols are the same and all. But this is guaranteed by # the nature of this transformation (single state). if self.only_inner_maps or self.only_toplevel_maps: - scopeDict: Mapping[dace_nodes.Node, Union[dace_nodes.Node, None]] = graph.scope_dict() - if self.only_inner_maps and (scopeDict[map_to_promote_entry] is None): + scope_dict: Mapping[dace_nodes.Node, Union[dace_nodes.Node, None]] = graph.scope_dict() + if self.only_inner_maps and (scope_dict[map_to_promote_entry] is None): return False - if self.only_toplevel_maps and (scopeDict[map_to_promote_entry] is not None): + if self.only_toplevel_maps and (scope_dict[map_to_promote_entry] is not None): return False - # Test if the map ranges are compatible with each other. + # Test if the map ranges and parameter are compatible with each other missing_map_parameters: list[str] | None = self.missing_map_params( map_to_promote=map_to_promote, source_map=source_map, @@ -201,6 +211,9 @@ def can_be_applied( if not dimension_identifier: return False for missing_map_param in missing_map_parameters: + # Check if all missing parameter match a specified pattern. Note + # unknown iteration parameter, such as `__hansi_meier` will be + # rejected and can not be promoted. if not any( missing_map_param.endswith(dim_identifier) for dim_identifier in dimension_identifier @@ -254,6 +267,7 @@ def missing_map_params( The returned sequence is empty if they are already have the same parameters. The function will return `None` is promoting is not possible. + By setting `be_strict` to `False` the function will only check the names. Args: map_to_promote: The map that should be promoted. @@ -330,25 +344,77 @@ def can_be_applied( permissive: bool = False, ) -> bool: """Tests if the Maps really can be fused.""" - from .map_serial_fusion import SerialMapFusion + # Test if the promotion could be done. if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False - # Check if the partition exists, if not promotion to fusing is pointless. - # TODO(phimuell): Find the proper way of doing it. - serial_fuser = SerialMapFusion(only_toplevel_maps=True) - output_partition = serial_fuser.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=self.exit_first_map, - map_entry_2=self.entry_second_map, - ) - if output_partition is None: + # Test if after the promotion the maps could be fused. + if not self._test_if_promoted_maps_can_be_fused(graph, sdfg): return False return True + def _test_if_promoted_maps_can_be_fused( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """This function checks if the promoted maps can be fused by map fusion. + + This function assumes that `self.can_be_applied()` returned `True`. + + Args: + state: The state in which we operate. + sdfg: The SDFG we process. + + Note: + The current implementation uses a very hacky way to test this. + + Todo: + Find a better way of doing it. + """ + first_map_exit: dace_nodes.MapExit = self.exit_first_map + access_node: dace_nodes.AccessNode = self.access_node + second_map_entry: dace_nodes.MapEntry = self.entry_second_map + + map_to_promote: dace_nodes.MapEntry = self.map_to_promote(state=state, sdfg=sdfg).map + + # Since we force a promotion of the map we have to store the old parameters + # of the map such that we can later restore them. + org_map_to_promote_params = copy.deepcopy(map_to_promote.params) + org_map_to_promote_ranges = copy.deepcopy(map_to_promote.range) + + try: + # This will lead to a promotion of the map, this is needed that + # Map fusion can actually inspect them. + self.apply(graph=state, sdfg=sdfg) + + # Now create the map fusion object that we can then use to check if + # the fusion is possible or not. + serial_fuser = gtx_transformations.SerialMapFusion( + only_inner_maps=self.only_inner_maps, + only_toplevel_maps=self.only_toplevel_maps, + ) + candidate = { + type(serial_fuser).map_exit1: first_map_exit, + type(serial_fuser).access_node: access_node, + type(serial_fuser).map_entry2: second_map_entry, + } + state_id = sdfg.node_id(state) + serial_fuser.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + + # Now use the serial fuser to see if fusion would succeed + if not serial_fuser.can_be_applied(state, 0, sdfg): + return False + + finally: + # Restore the parameters of the map that we promoted before. + map_to_promote.params = org_map_to_promote_params + map_to_promote.range = org_map_to_promote_ranges + + return True + def map_to_promote( self, state: dace.SDFGState, From b7400a63ddbf2fcbbf21657a1d23299dacd4f08b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 07:49:50 +0200 Subject: [PATCH 216/235] Fixed a small typo in the `TrivialGPUMapPromoter`. --- .../runners/dace_fieldview/transformations/gpu_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 667ae3ccf0..06a3e8690d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -334,9 +334,9 @@ def can_be_applied( - Intermediate access node can only have in and out degree of 1. - The trivial map exit can only have one output. """ - trivial_map_exit: dace_nodes.MapExit = self.trivial_map(state=graph, sdfg=sdfg) + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + trivial_map: dace_nodes.Map = trivial_map_exit.map trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) - trivial_map: dace_nodes.Map = trivial_map_entry.map second_map: dace_nodes.Map = self.second_map_entry.map access_node: dace_nodes.AccessNode = self.access_node From 36a638628d20c8793020aff02e2abc7b6669b577 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 07:50:33 +0200 Subject: [PATCH 217/235] Added tests for the `TrivialGPUMapPromoter`. --- .../transformation_tests/test_gpu_utils.py | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py new file mode 100644 index 0000000000..c9cfd90ad6 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py @@ -0,0 +1,147 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Callable +import dace +import copy +import numpy as np + +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview.transformations import ( + gpu_utils as gtx_dace_fieldview_gpu_utils, +) +from . import util + + +def _get_trivial_gpu_promotable( + tasklet_code: str, +) -> tuple[dace.SDFG, dace_nodes.MapEntry, dace_nodes.MapEntry]: + """Returns an SDFG that is suitable to test the `TrivialGPUMapPromoter` promoter. + + The first map is a trivial map (`Map[__trival_gpu_it=0]`) containing a Tasklet, + that does not have an output, but writes a scalar value into `tmp` (output + connector `__out`), the body of this Tasklet can be controlled through the + `tasklet_code` argument. + The second map (`Map[__i0=0:N]`) contains a Tasklet that computes the sum of its + two inputs, the first input is the scalar value inside `tmp` and the second one + is `a[__i0]`, the result is stored in `b[__i0]`. + + Returns: + A tuple, the first element is the SDFG, the second element is the map entry + of the trivial map and the last element is the map entry of the second map. + + Args: + tasklet_code: The body of the Tasklet inside the trivial map. + """ + sdfg = dace.SDFG("TrivailGPUPromotable") + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + + storage_array = dace.dtypes.StorageType.GPU_Global + storage_scalar = dace.dtypes.StorageType.Register + schedule = dace.dtypes.ScheduleType.GPU_Device + + sdfg.add_scalar("tmp", dace.float64, transient=True) + sdfg.add_array("a", shape=("N",), dtype=dace.float64, transient=False, storage=storage_array) + sdfg.add_array("b", shape=("N",), dtype=dace.float64, transient=False, storage=storage_array) + a, b, tmp = (state.add_access(name) for name in ["a", "b", "tmp"]) + + _, trivial_map_entry, _ = state.add_mapped_tasklet( + "trivail_top_tasklet", + map_ranges={"__trivial_gpu_it": "0"}, + inputs={}, + code=tasklet_code, + outputs={"__out": dace.Memlet("tmp[0]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + schedule=schedule, + ) + _, second_map_entry, _ = state.add_mapped_tasklet( + "non_trivial_tasklet", + map_ranges={"__i0": "0:N"}, + inputs={ + "__in0": dace.Memlet("a[__i0]"), + "__in1": dace.Memlet("tmp[0]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("b[__i0]")}, + input_nodes={"a": a, "tmp": tmp}, + output_nodes={"b": b}, + external_edges=True, + schedule=schedule, + ) + return sdfg, trivial_map_entry, second_map_entry + + +def test_trivial_gpu_map_promoter(): + """Tests if the GPU map promoter works. + + By using a body such as `__out = 3.0`, the transformation will apply. + """ + sdfg, trivial_map_entry, second_map_entry = _get_trivial_gpu_promotable("__out = 3.0") + org_second_map_params = list(second_map_entry.map.params) + org_second_map_ranges = copy.deepcopy(second_map_entry.map.range) + + nb_runs = sdfg.apply_transformations_once_everywhere( + gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + validate=True, + validate_all=True, + ) + assert ( + nb_runs == 1 + ), f"Expected that 'TrivialGPUMapPromoter' applies once but it applied {nb_runs}." + trivial_map_params = trivial_map_entry.map.params + trivial_map_ranges = trivial_map_ranges.map.range + second_map_params = second_map_entry.map.params + second_map_ranges = second_map_entry.map.range + + assert ( + second_map_params == org_second_map_params + ), "The transformation modified the parameter of the second map." + assert all( + org_rng == rng for org_rng, rng in zip(org_second_map_ranges, second_map_ranges) + ), "The transformation modified the range of the second map." + assert all( + t_rng == s_rng for t_rng, s_rng in zip(trivial_map_ranges, second_map_ranges, strict=True) + ), "Expected that the ranges are the same; trivial '{trivial_map_ranges}'; second '{second_map_ranges}'." + assert ( + trivial_map_params == second_map_params + ), f"Expected the trivial map to have parameters '{second_map_params}', but it had '{trivial_map_params}'." + assert sdfg.is_valid() + + +def test_trivial_gpu_map_promoter(): + """Test if the GPU promoter does not fuse a special trivial map. + + By using a body such as `__out = __trivial_gpu_it` inside the + Tasklet's body, the map parameter is now used, and thus can not be fused. + """ + sdfg, trivial_map_entry, second_map_entry = _get_trivial_gpu_promotable( + "__out = __trivial_gpu_it" + ) + org_trivial_map_params = list(trivial_map_entry.map.params) + org_second_map_params = list(second_map_entry.map.params) + + nb_runs = sdfg.apply_transformations_once_everywhere( + gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + validate=True, + validate_all=True, + ) + assert ( + nb_runs == 0 + ), f"Expected that 'TrivialGPUMapPromoter' does not apply but it applied {nb_runs}." + trivial_map_params = trivial_map_entry.map.params + second_map_params = second_map_entry.map.params + assert ( + trivial_map_params == org_trivial_map_params + ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." + assert ( + second_map_params == org_second_map_params + ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." + assert sdfg.is_valid() From d6cde5c12b4e79817e5ea99915514f8ebac70692 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 08:00:50 +0200 Subject: [PATCH 218/235] Updated the map promotion implementation. It now simply copies the range and parameters and no longer augments them. --- .../transformations/map_promoter.py | 40 ++++++------------- .../test_serial_map_promoter.py | 8 ++++ 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index f687e08cb6..19818fd3d1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -from typing import Any, Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Optional, Union import dace from dace import ( @@ -47,6 +47,9 @@ class BaseMapPromoter(dace_transformation.SingleStateTransformation): To influence what to promote the user must implement the `map_to_promote()` and `source_map()` function. They have to return the map entry node. + The order of the parameter the map to promote has is unspecific, while the + source map is not modified. + Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. @@ -225,37 +228,18 @@ def can_be_applied( def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the actual Map promoting. - Add all parameters that `self.source_map` has but `self.map_to_promote` - lacks to `self.map_to_promote` the range of these new dimensions is taken - from the source map. - The order of the parameters the Map has after the promotion is unspecific. + After this call the map to promote will have the same map parameters + and ranges as the source map has. The function assumes that `can_be_applied()` + returned `True`. """ map_to_promote: dace_nodes.Map = self.map_to_promote(state=graph, sdfg=sdfg).map source_map: dace_nodes.Map = self.source_map(state=graph, sdfg=sdfg).map - source_params: Sequence[str] = source_map.params - source_ranges: dace_subsets.Range = source_map.range - - missing_params: Sequence[str] = self.missing_map_params( # type: ignore[assignment] # Will never be `None` - map_to_promote=map_to_promote, - source_map=source_map, - be_strict=False, - ) - - # Maps the map parameter of the source map to its index, i.e. which map - # parameter it is. - map_source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_params)} - - promoted_params = list(map_to_promote.params) - promoted_ranges = list(map_to_promote.range.ranges) - - for missing_param in missing_params: - promoted_params.append(missing_param) - promoted_ranges.append(source_ranges[map_source_param_to_idx[missing_param]]) - # Now update the map properties - # This action will also remove the tiles - map_to_promote.range = dace_subsets.Range(promoted_ranges) - map_to_promote.params = promoted_params + # The simplest implementation is just to copy the important parts. + # Note that we only copy the ranges and parameters all other stuff in the + # associated Map object is not modified. + map_to_promote.params = copy.deepcopy(source_map.params) + map_to_promote.range = copy.deepcopy(source_map.range) def missing_map_params( self, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py index e224a08a14..5dbd2edf7d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -69,7 +69,9 @@ def test_serial_map_promotion(): assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 1 + assert len(map_entry_1d.map.range) == 1 assert len(map_entry_2d.map.params) == 2 + assert len(map_entry_2d.map.range) == 2 # Now apply the promotion sdfg.apply_transformations( @@ -82,5 +84,11 @@ def test_serial_map_promotion(): assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 2 + assert len(map_entry_1d.map.range) == 2 assert len(map_entry_2d.map.params) == 2 + assert len(map_entry_2d.map.range) == 2 assert set(map_entry_1d.map.params) == set(map_entry_2d.map.params) + assert all( + rng_1d == rng_2d + for rng_1d, rng_2d in zip(map_entry_1d.map.range.ranges, map_entry_2d.map.range.ranges) + ) From 32d3883ca8b852ba56bc6fc623975b4af7ab8e81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:11:50 +0200 Subject: [PATCH 219/235] Update docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- .../ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md index 6b321a452e..9aae1114b1 100644 --- a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -36,7 +36,7 @@ The following rules, especially affects transformations and how they operate: - [Rational]: As a consequence the number of "interstate transients" (transients that are used in multiple states) remains constant during intrastate transformations. - [Note 1]: It is allowed to run them after one another, as long as they are strictly separated. - [Note 2]: It is allowed that _intrastate_ transformation act in a way to allow state fusion by later intrastate transformations. - - [Note 3]: The DaCe simplification pass violates this rule, for that reason this pass must always be called on its own, see also rule 2. + - [Note 3]: The DaCe simplification pass violates this rule; for that reason, this pass must always be called on its own. See also rule 2. 2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed, the only valid way to call simplify is to call the `gt_simplify()` function provided by GT4Py. - [Rational]: It was observed that some sub passes in simplify have a negative impact and that additional passes might be needed in the future. From 201c8e2b8db0d5872bbc74404726d072cb329368 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 15:27:06 +0200 Subject: [PATCH 220/235] Corrected the ADR. --- ...Canonical_SDFG_in_GT4Py_Transformations.md | 78 +++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md index 9aae1114b1..679bfba00f 100644 --- a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -13,44 +13,44 @@ Their main intent is to reduce the complexity of the GT4Py specific transformati ## Context -The canonical form that is outlined in this document was mainly designed from the perspective of the optimization pipeline. -Thus it emphasizes a form that is can be handled in a simple and efficient way by a transformation. -In the pipeline we distinguishes between +The canonical is outlined in this document was mainly designed from the perspective of the optimization pipeline. +Thus it emphasizes a form that can be handled in a simple and efficient way by a transformation. +In the pipeline we distinguish between: -- Intrastate optimization: The optimization of the data flow within states. -- Interstate optimization: The optimization between states, this are transformations that are _intended_ to _reduce_ the number of states. +- Intrastate optimization: optimization of the data flow within states. +- Interstate optimization: optimization between states, these are transformations that are _intended_ to _reduce_ the number of states. The current (GT4Py) pipeline mainly focus on intrastate optimization and relays on DaCe, especially its simplify pass, for interstate optimizations. ## Decision The canonical form is defined by several rules that affect different aspects of an SDFG and what a transformation can assume. -This allows to simplify the implementation of certain transformations. +This allows simplifying the implementation of certain transformations. #### General Aspects -The following rules, especially affects transformations and how they operate: +The following rules especially affects transformations and how they operate: 1. Intrastate transformation and interstate transformations must run separately and can not be mixed in the same (DaCe) pipeline. - - [Rational]: As a consequence the number of "interstate transients" (transients that are used in multiple states) remains constant during intrastate transformations. - - [Note 1]: It is allowed to run them after one another, as long as they are strictly separated. - - [Note 2]: It is allowed that _intrastate_ transformation act in a way to allow state fusion by later intrastate transformations. - - [Note 3]: The DaCe simplification pass violates this rule; for that reason, this pass must always be called on its own. See also rule 2. + - [Rationale]: As a consequence the number of "interstate transients" (transients that are used in multiple states) remains constant during intrastate transformations. + - [Note 1]: It is allowed to run them one after another, as long as they are strictly separated. + - [Note 2]: It is allowed for an _intrastate_ transformation to act in a way that allows state fusion by later intrastate transformations. + - [Note 3]: The DaCe simplification pass violates this rule, for that reason this pass must always be called on its own, see also rule 2. -2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed, the only valid way to call simplify is to call the `gt_simplify()` function provided by GT4Py. - - [Rational]: It was observed that some sub passes in simplify have a negative impact and that additional passes might be needed in the future. - By only using a single function later modifications to simplify are easy. +2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py. + - [Rationale]: It was observed that some sub-passes in _simplify()_ have a negative impact and that additional passes might be needed in the future. + By using a single function later modifications to _simplify()_ are easy. - [Note]: One issue is that the remove redundant array transformation is not able to handle all cases. #### Global Memory The only restriction we impose on global memory is: -3. The same global memory is allowed to be used as input and output at the same time, iff the output depends _elementwise_ on the input. - - [Rational 1]: Allows to remove double buffering, that DaCe may not remove, see also rule 2. - - [Rational 2]: This formulation allows to write expressions such as `a += 1`, with only memory for `a`. - Phrased more technically using global memory for input and output is allowed iff the two computations `tmp = computation(global_memory); global_memory = tmp;` and `global_memory = computation(global_memory);` are equivalent. +3. The same global memory is allowed to be used as input and output at the same time, if and only if the output depends _elementwise_ on the input. + - [Rationale 1]: This allows the removal of double buffering, that DaCe may not remove. See also rule 2. + - [Rationale 2]: This formulation allows writing expressions such as `a += 1`, with only memory for `a`. + Phrased more technically, using global memory for input and output is allowed if and only if the two computations `tmp = computation(global_memory); global_memory = tmp;` and `global_memory = computation(global_memory);` are equivalent. - [Note]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both. #### State Machine @@ -59,21 +59,21 @@ For the SDFG state machine we assume that: 4. An interstate edge can only access scalars, i.e. use them in their assignment or condition expressions, but not arrays, even if they have shape `(1,)`. - - [Rational]: If an array is also used in interstate edges it became very tedious to verify if the array could be removed or not. - - [Note]: Running simplify might actually result in the violation of this rule, see note of rule 9. + - [Rationale]: If an array is also used in interstate edges it became very tedious to verify if the array could be removed or not. + - [Note]: Running _simplify()_ might actually result in the violation of this rule, see note of rule 9. 5. The state graph does not contain any cycles, i.e. the implementation of a for/while loop using states is not allowed, the new loop construct or serial maps must be used in that case. - - [Rational]: This is a simplification that makes it much simpler to define "later in the computation" means as we will never have a cycle. + - [Rationale]: This is a simplification that makes it much simpler to define what "later in the computation" means, as we will never have a cycle. - [Note]: Currently the code generator does not support the `LoopRegion` construct and it is transformed to a state machine. #### Transients The rules we impose on transients are a bit more complicated, however, while sounding restrictive, they are very permissive. -It is important that these rules only have to be met after after simplify was called once on the SDFG: +It is important to note that these rules only have to be met after after _simplify()_ was called once on the SDFG: -6. Downstream of a write access, i.e. in all states that follows the state the access node is located in, there are no other access nodes that are used to write to the same array. +6. Downstream of a write access, i.e., in all states that follow the state where the access node is located, there are no other access nodes that are used to write to the same array. - - [Rational 1]: This rule together with rule 7 and 8 essentially boils down to ensure that the assignment in the SDFG follows SSA style, while allowing for expressions such as: + - [Rationale 1]: This rule, together with rule 7 and 8, essentially ensures that the assignment in the SDFG follows SSA style, while allowing for expressions such as: ```python if cond: @@ -84,44 +84,44 @@ It is important that these rules only have to be met after after simplify was ca (**NOTE:** This could also be done with references, however, they are strongly discouraged.) - - [Rational 2]: This still allows reductions with WCR as they write to the same access node and loops, whose body modifies a transient that outlives the loop body, as they use the same access node. + - [Rationale 2]: This still allows reductions with WCR as they write to the same access node and loops, whose body modifies a transient that outlives the loop body, as they use the same access node. 7. It is _recommended_ that a write access node should only have one incoming edge. - - [Rational]: This case is handled poorly by some DaCe transformations, thus we should avoid them as much as possible. + - [Rationale]: This case is handled poorly by some DaCe transformations, thus we should avoid them as much as possible. 8. No two access nodes in a state can refer to the same array. - - [Rational]: Together with rule 5 this guarantees SSA style. - - [Note]: An SDFG can still be constructed using different access node for the same underlying data; simplify will combine them. + - [Rationale]: Together with rule 5 this guarantees SSA style. + - [Note]: An SDFG can still be constructed using different access node for the same underlying data; _simplify()_ will combine them. 9. Every access node that reads from an array (having an outgoing edge) that was not written to in the same state must be a source node. - - [Rational]: Together with rule 1, 4, 5, 6, 7 and 8 this simplifies the check if a transient can be safely removed or if it is used somewhere else. - These rules guarantee that the number of "interstate transients" remains constant and these set is given by the _set of source nodes and all access nodes that have an outgoing degree larger than one_. - - [Note]: To prevent some issues caused by the violation of rule 4 by simplify, this set is extended with the transient sink nodes and all scalars. - Excess interstate transients, that will be kept alive that way, will be removed by later calls to simplify. + - [Rationale]: Together with rule 1, 4, 5, 6, 7 and 8 this simplifies checking if a transient can be safely removed or if it is used somewhere else. + These rules guarantee that the number of "interstate transients" remains constant and this set is given by the _set of source nodes and all access nodes that have an outgoing degree larger than one_. + - [Note]: To prevent some issues caused by the violation of rule 4 by _simplify()_, this set is extended with the transient sink nodes and all scalars. + Excess interstate transients, that will be kept alive that way, will be removed by later calls to _simplify()_. -10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should be _preferable_ `dace.dtypes.StorageType.Register`. - - [Rational 1]: Makes optimizations operating inside a maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside. - - [Rational 2]: The storage type avoids the need to dynamically allocate memory inside a kernel. +10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should _preferably_ be `dace.dtypes.StorageType.Register`. + - [Rationale 1]: This makes optimizations operating inside maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside. + - [Rationale 2]: The storage type avoids the need to dynamically allocate memory inside a kernel. #### Maps For maps we assume the following: -11. The names of map variables (iteration variable) follow the following pattern. +11. The names of map variables (iteration variables) follow the following pattern. - 11.1: All map variables iterating over the same dimension (disregarding the actual range), have the same deterministic name, that includes the `gtx.Dimension.value` string. - 11.2: The name of horizontal dimensions (`kind` attribute) always end in `__gtx_horizontal`. - 11.3: The name of vertical dimensions (`kind` attribute) always end in `__gtx_vertical`. - 11.4: The name of local dimensions always ends in `__gtx_localdim`. - 11.5: No transformation is allowed to modify the name of an iteration variable that follows rules 11.2, 11.3 or 11.4. - - [Rational]: Without this rule it is very hard to tell which map variable does what, this way we can transmit information from GT4Py to DaCe, see also rule 12. + - [Rationale]: Without this rule it is very hard to tell which map variable does what, this way we can transmit information from GT4Py to DaCe, see also rule 12. 12. Two map ranges, i.e. the pair map/iteration variable and range, can only be fused if they have the same name _and_ cover the same range. - - [Rational 1]: Because of rule 11 we will only fuse maps that actually makes sense to fuse. - - [Rational 2]: This allows to fuse maps without performing a renaming on the map variables. + - [Rationale 1]: Because of rule 11, we will only fuse maps that actually makes sense to fuse. + - [Rationale 2]: This allows to fusing maps without performing a renaming on the map variables. - [Note]: This rule might be dropped in the future. ## Consequences From 210a8d9b5c203b366e0057877c1ee10941d9a1cb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 15:41:18 +0200 Subject: [PATCH 221/235] Second appling. --- .../transformations/auto_opt.py | 21 +++++++++---------- .../transformations/k_blocking.py | 4 ++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index e773ccedeb..0960ede470 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -111,13 +111,8 @@ def gt_auto_optimize( 1. Some general simplification transformations, beyond classical simplify, are applied to the SDFG. - 2. In this phase the function tries to reduce the number of maps. This - process mostly relies on the map fusion transformation. If - `aggressive_fusion` is set the function will also promote certain Maps, to - make them fusable. For this it will add dummy dimensions. However, currently - the function will only add horizonal dimensions. - In this phase some optimizations inside the bigger kernels themselves might - be applied as well. + 2. Tries to create larger kernels by fusing smaller ones, see + `gt_auto_fuse_top_level_maps()` for more details. 3. After the function created big kernels/maps it will apply some optimization, inside the kernels itself. For example fuse maps inside them. 4. Afterwards it will process the map ranges and iteration order. For this @@ -140,8 +135,8 @@ def gt_auto_optimize( sdfg: The SDFG that should be optimized in place. gpu: Optimize for GPU or CPU. leading_dim: Leading dimension, indicates where the stride is 1. - aggressive_fusion: Be more aggressive in fusion, will lead to the promotion - of certain maps. + aggressive_fusion: Be more aggressive during phase 2, will lead to the promotion + of certain maps (over computation) but will lead to larger kernels. max_optimization_rounds_p2: Maximum number of optimization rounds in phase 2. make_persistent: Turn all transients to persistent lifetime, thus they are allocated over the whole lifetime of the program, even if the kernel exits. @@ -198,7 +193,7 @@ def gt_auto_optimize( # Phase 2: Kernel Creation # Try to create kernels as large as possible. - sdfg = _gt_auto_optimize_phase_2( + sdfg = gt_auto_fuse_top_level_maps( sdfg=sdfg, aggressive_fusion=aggressive_fusion, max_optimization_rounds=max_optimization_rounds_p2, @@ -277,7 +272,7 @@ def gt_auto_optimize( return sdfg -def _gt_auto_optimize_phase_2( +def gt_auto_fuse_top_level_maps( sdfg: dace.SDFG, aggressive_fusion: bool = True, max_optimization_rounds: int = 100, @@ -306,6 +301,10 @@ def _gt_auto_optimize_phase_2( performed. validate: Perform validation during the steps. validate_all: Perform extensive validation. + + Note: + Calling this function directly is most likely an error. Instead you should + call `gt_auto_optimize()` directly. """ # Compute the SDFG hash to see if something has changed. sdfg_hash = sdfg.hash_sdfg() diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 1e0ea319c7..ce67cb504a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -143,8 +143,8 @@ def apply( sdfg=sdfg, ) - @staticmethod def _rewire_map_scope( + self, outer_entry: dace_nodes.MapEntry, outer_exit: dace_nodes.MapExit, inner_entry: dace_nodes.MapEntry, @@ -469,8 +469,8 @@ def partition_map_output( return (independent_nodes, dependent_nodes) - @staticmethod def classify_node( + self, node_to_classify: dace_nodes.Node, outer_entry: dace_nodes.MapEntry, blocking_parameter: str, From 017fc9f1147dea4619ca0b07c5fc9d8d85988ad6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 15:53:09 +0200 Subject: [PATCH 222/235] Renamed the `KBlocking` to `LoopBlocking`. --- .../transformations/__init__.py | 4 ++-- .../transformations/auto_opt.py | 7 +++---- .../{k_blocking.py => loop_blocking.py} | 19 ++++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/transformations/{k_blocking.py => loop_blocking.py} (97%) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index b698781b1a..53fa1eee05 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -14,7 +14,7 @@ from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize -from .k_blocking import KBlocking +from .loop_blocking import LoopBlocking from .map_orderer import MapIterationOrder from .map_promoter import SerialMapPromoter from .map_serial_fusion import SerialMapFusion @@ -22,7 +22,7 @@ __all__ = [ "GPUSetBlockSize", - "KBlocking", + "LoopBlocking", "MapIterationOrder", "SerialMapFusion", "SerialMapPromoter", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 0960ede470..e19dccc67f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -118,9 +118,8 @@ def gt_auto_optimize( 4. Afterwards it will process the map ranges and iteration order. For this the function assumes that the dimension indicated by `leading_dim` is the one with stride one. - 5. If requested the function will now apply blocking, on the dimension indicated - by `leading_dim`. (The reason that it is not done in the kernel optimization - phase is a restriction dictated by the implementation.) + 5. If requested the function will now apply loop blocking, on the dimension + indicated by `leading_dim`. 6. If requested the SDFG will be transformed to GPU. For this the `gt_gpu_transformation()` function is used, that might apply several other optimizations. @@ -226,7 +225,7 @@ def gt_auto_optimize( # Phase 5: Apply blocking if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( - gtx_transformations.KBlocking( + gtx_transformations.LoopBlocking( blocking_size=blocking_size, blocking_parameter=blocking_dim, ), diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py similarity index 97% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index ce67cb504a..b9fbae74e4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -23,20 +23,20 @@ @dace_properties.make_properties -class KBlocking(dace_transformation.SingleStateTransformation): - """Applies k-Blocking with separation on a Map. +class LoopBlocking(dace_transformation.SingleStateTransformation): + """Applies loop blocking, also known as k-blocking, on a Map. This transformation takes a multidimensional Map and performs blocking on a - single dimension, that is commonly called "k". All dimensions except `k` are - unaffected by this transformation. In the outer Map will be replace the `k` - range, currently `k = 0:N`, with `__coarse_k = 0:N:B`, where `N` is the + single dimension, the loop variable is called `I` here. All dimensions except + `I` are unaffected by this transformation. In the outer Map will be replace the + `I` range, currently `I = 0:N`, with `__coarse_I = 0:N:B`, where `N` is the original size of the range and `B` is the blocking size. The transformation - will then create an inner sequential map with `k = __coarse_k:(__coarse_k + B)`. + will then create an inner sequential map with `I = __coarse_I:(__coarse_I + B)`. What makes this transformation different from simple blocking, is that the inner map will not just be inserted right after the outer Map. Instead the transformation will first identify all nodes that does not depend - on the blocking parameter and relocate them between the outer and inner map. + on the blocking parameter `I` and relocate them between the outer and inner map. Thus these operations will only be performed once, per inner loop. Args: @@ -51,12 +51,13 @@ class KBlocking(dace_transformation.SingleStateTransformation): blocking_size = dace_properties.Property( dtype=int, allow_none=True, - desc="Size of the inner k Block.", + desc="Size of the inner blocks; 'B' in the above description.", ) blocking_parameter = dace_properties.Property( dtype=str, allow_none=True, - desc="Name of the iteration variable on which to block (must be an exact match).", + desc="Name of the iteration variable on which to block (must be an exact match);" + " 'I' in the above description.", ) outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) From 20da8582770eb2ecf80b31148833fc62430226e6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 16:15:32 +0200 Subject: [PATCH 223/235] Made some smaller modification. --- .../runners/dace_fieldview/transformations/util.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index a54c882842..b6e02fd9a8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -56,9 +56,11 @@ def all_nodes_between( """ def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: - if reverse: - return (edge.src for edge in graph.in_edges(node)) - return (edge.dst for edge in graph.out_edges(node)) + return ( + (edge.src for edge in graph.in_edges(node)) + if reverse + else (edge.dst for edge in graph.out_edges(node)) + ) if reverse: begin, end = end, begin From 8c31694ddce21eb4caa60d77aa7b1a367f2e6a62 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 16:17:29 +0200 Subject: [PATCH 224/235] Added the comment Enrique mentioned. --- .../runners_tests/dace/transformation_tests/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py index 0e85e1d9d1..9eb387f5c5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py @@ -30,6 +30,10 @@ def set_dace_settings() -> Generator[None, None, None]: especially inside `can_be_applied()` are not ignored. - `compiler.allow_view_arguments` allow that NumPy views can be passed to `CompiledSDFG` objects as arguments. + + Note: + This fixture will be automatically used by all tests inside this folder and + its subfolders due to the autouse option. """ with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) From c8ecd2511babdfd61f2272e960786c9ab95ac1b4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 16:37:10 +0200 Subject: [PATCH 225/235] Removed the auto use fixture, it is now imported explicitly. --- .../runners_tests/dace/transformation_tests/conftest.py | 8 ++------ .../dace/transformation_tests/test_gpu_utils.py | 3 +++ .../dace/transformation_tests/test_k_blocking.py | 3 +++ .../dace/transformation_tests/test_map_fusion.py | 2 ++ .../dace/transformation_tests/test_serial_map_promoter.py | 4 ++++ 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py index 9eb387f5c5..5d5fb9923a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py @@ -21,7 +21,7 @@ ) -@pytest.fixture(autouse=True) +@pytest.fixture() def set_dace_settings() -> Generator[None, None, None]: """Sets the common DaCe settings for the tests. @@ -30,12 +30,8 @@ def set_dace_settings() -> Generator[None, None, None]: especially inside `can_be_applied()` are not ignored. - `compiler.allow_view_arguments` allow that NumPy views can be passed to `CompiledSDFG` objects as arguments. - - Note: - This fixture will be automatically used by all tests inside this folder and - its subfolders due to the autouse option. """ with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=True) + dace.Config.set("optimizer", "match_exception", value=False) dace.Config.set("compiler", "allow_view_arguments", value=True) yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py index c9cfd90ad6..8b49a430d4 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py @@ -10,6 +10,7 @@ import dace import copy import numpy as np +import pytest from dace.sdfg import nodes as dace_nodes @@ -18,6 +19,8 @@ ) from . import util +pytestmark = pytest.mark.usefixtures("set_dace_settings") + def _get_trivial_gpu_promotable( tasklet_code: str, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py index b1b7b812d6..6c2cc37261 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py @@ -10,6 +10,7 @@ import dace import copy import numpy as np +import pytest from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation @@ -17,6 +18,8 @@ transformations as gtx_transformations, ) +pytestmark = pytest.mark.usefixtures("set_dace_settings") + def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np.ndarray]]: """Creates a simple SDFG. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py index 8d8a108765..1c43e00efd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py @@ -21,6 +21,8 @@ ) from . import util +pytestmark = pytest.mark.usefixtures("set_dace_settings") + def _make_serial_sdfg_1( N: str | int, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py index 5dbd2edf7d..c483866a35 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -10,6 +10,7 @@ import dace import copy import numpy as np +import pytest from dace.sdfg import nodes as dace_nodes @@ -19,6 +20,9 @@ from . import util +pytestmark = pytest.mark.usefixtures("set_dace_settings") + + def test_serial_map_promotion(): """Tests the serial Map promotion transformation.""" N = 10 From 7aed88f5ce46f13a0b1750d070f324882d4c923a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 16:45:15 +0200 Subject: [PATCH 226/235] Forgot to rename the `KBlocking` also in the tests. --- .../{test_k_blocking.py => test_loop_blocking.py} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/{test_k_blocking.py => test_loop_blocking.py} (97%) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py similarity index 97% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py index 6c2cc37261..812225d273 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py @@ -150,7 +150,7 @@ def test_only_dependent(): # Apply the transformation sdfg.apply_transformations_repeated( - gtx_transformations.KBlocking(blocking_size=10, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) @@ -214,7 +214,7 @@ def test_intermediate_access_node(): # Apply the transformation. sdfg.apply_transformations_repeated( - gtx_transformations.KBlocking(blocking_size=10, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) @@ -252,7 +252,7 @@ def test_chained_access() -> None: # Apply the transformation. ret = sdfg.apply_transformations_repeated( - gtx_transformations.KBlocking(blocking_size=10, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) From 2a8494ade59259772cfc0dbdfdf3a5f4a0438a90 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 10:03:05 +0200 Subject: [PATCH 227/235] Further modifications. --- .../transformations/loop_blocking.py | 16 ++++++---------- .../dace_fieldview/transformations/util.py | 8 ++------ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index b9fbae74e4..dda2b802e9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -426,7 +426,6 @@ def partition_map_output( `used_symbol` properties of a Tasklet to determine if a Tasklet is dependent. """ outer_entry: dace_nodes.MapEntry = self.outer_entry - blocking_parameter: str = self.blocking_parameter independent_nodes: set[dace_nodes.Node] = set() # `\mathcal{I}` while True: @@ -443,10 +442,9 @@ def partition_map_output( # Now classify each node found_new_independent_node = False for node_to_classify in nodes_to_classify: - class_res = self.classify_node( + class_res = self._classify_node( node_to_classify=node_to_classify, outer_entry=outer_entry, - blocking_parameter=blocking_parameter, independent_nodes=independent_nodes, state=state, sdfg=sdfg, @@ -470,11 +468,10 @@ def partition_map_output( return (independent_nodes, dependent_nodes) - def classify_node( + def _classify_node( self, node_to_classify: dace_nodes.Node, outer_entry: dace_nodes.MapEntry, - blocking_parameter: str, independent_nodes: set[dace_nodes.Node], state: dace.SDFGState, sdfg: dace.SDFG, @@ -484,10 +481,10 @@ def classify_node( The general rule to classify if a node is independent are: - The node must be a Tasklet or an AccessNode, in all other cases the partition does not exist. - - `free_symbols` of the nodes shall not contain the `blocking_parameter`. + - `free_symbols` of the nodes shall not contain the blocking parameter. - All incoming _empty_ edges must be connected to the map entry. - A node either has only empty Memlets or none of them. - - Incoming Memlets does not depend on the `blocking_parameter`. + - Incoming Memlets does not depend on the blocking parameter. - All incoming edges must start either at `outer_entry` or at dependent nodes. - All output Memlets are non empty. @@ -501,7 +498,6 @@ def classify_node( Args: node_to_classify: The node that should be classified. outer_entry: The entry of the map that should be partitioned. - blocking_parameter: The iteration parameter that should be blocked. independent_nodes: The set of nodes that was already classified as independent, in which case it is added to `independent_nodes`. state: The state containing the map. @@ -569,7 +565,7 @@ def classify_node( # Test if the body of the Tasklet depends on the block variable. if ( isinstance(node_to_classify, dace_nodes.Tasklet) - and blocking_parameter in node_to_classify.free_symbols + and self.blocking_parameter in node_to_classify.free_symbols ): return False @@ -594,7 +590,7 @@ def classify_node( # If a subset needs the block variable then the node is not independent # but dependent. - if any(blocking_parameter in subset.free_symbols for subset in subsets_to_inspect): + if any(self.blocking_parameter in subset.free_symbols for subset in subsets_to_inspect): return False # The edge must either originate from `outer_entry` or from an independent diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index b6e02fd9a8..29bae7bbe0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -139,12 +139,8 @@ def find_downstream_consumers( new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) to_visit.extend(new_edges) - elif only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): - # We are only interested in Tasklets but have not found one. Thus we - # ignore the node. - pass - - else: + elif isinstance(next_node, dace_nodes.Tasklet) or not only_tasklets: + # We have found a consumer. found.add((next_node, curr_edge)) return found From 2cfbe20fef71f31bd2e0bb77c89bfbb3fe2c6c50 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 11:52:34 +0200 Subject: [PATCH 228/235] Applied the last comments. --- .../transformations/loop_blocking.py | 363 +++++++++--------- 1 file changed, 181 insertions(+), 182 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index dda2b802e9..497a365f75 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -133,7 +133,7 @@ def apply( (outer_entry, outer_exit), (inner_entry, inner_exit) = self._prepare_inner_outer_maps(graph) # Reconnect the edges - self._rewire_map_scope( + _rewire_map_scope( outer_entry=outer_entry, outer_exit=outer_exit, inner_entry=inner_entry, @@ -144,180 +144,6 @@ def apply( sdfg=sdfg, ) - def _rewire_map_scope( - self, - outer_entry: dace_nodes.MapEntry, - outer_exit: dace_nodes.MapExit, - inner_entry: dace_nodes.MapEntry, - inner_exit: dace_nodes.MapExit, - independent_nodes: set[dace_nodes.Node], - dependent_nodes: set[dace_nodes.Node], - state: dace.SDFGState, - sdfg: dace.SDFG, - ) -> None: - """Rewire the edges inside the scope defined by the outer map. - - The function assumes that the outer and inner map were obtained by a call - to `_prepare_inner_outer_maps()`. The function will now rewire the connections of these - nodes such that the dependent nodes are inside the scope of the inner map, - while the independent nodes remain outside. - - Args: - outer_entry: The entry node of the outer map. - outer_exit: The exit node of the outer map. - inner_entry: The entry node of the inner map. - inner_exit: The exit node of the inner map. - independent_nodes: The set of independent nodes. - dependent_nodes: The set of dependent nodes. - state: The state of the map. - sdfg: The SDFG we operate on. - """ - - # Contains the nodes that are already have been handled. - relocated_nodes: set[dace_nodes.Node] = set() - - # We now handle all independent nodes, this means that all of their - # _output_ edges have to go through the new inner map and the Memlets need - # modifications, because of the block parameter. - for independent_node in independent_nodes: - for out_edge in state.out_edges(independent_node): - edge_dst: dace_nodes.Node = out_edge.dst - relocated_nodes.add(edge_dst) - - # If destination of this edge is also independent we do not need - # to handle it, because that node will also be before the new - # inner serial map. - if edge_dst in independent_nodes: - continue - - # Now split `out_edge` such that it passes through the new inner entry. - # We do not need to modify the subsets, i.e. replacing the variable - # on which we block, because the node is independent and the outgoing - # new inner map entry iterate over the blocked variable. - new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=state, - edge=out_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, - ) - # TODO(phimuell): Check if there might be a subset error. - state.add_edge( - inner_entry, - "OUT_" + new_map_conn, - out_edge.dst, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) - - # Now we handle the dependent nodes, they differ from the independent nodes - # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. - for dependent_node in dependent_nodes: - for in_edge in state.in_edges(dependent_node): - edge_src: dace_nodes.Node = in_edge.src - - # Since the independent nodes were already processed, and they process - # their output we have to check for this. We do this by checking if - # the source of the edge is the new inner map entry. - if edge_src is inner_entry: - assert dependent_node in relocated_nodes - continue - - # A dependent node has at least one connection to the outer map entry. - # And these are the only connections that we must handle, since other - # connections come from independent nodes, and were already handled - # or are inner nodes. - if edge_src is not outer_entry: - continue - - # If we encounter an empty Memlet we just just attach it to the - # new inner map entry. Note the partition function ensures that - # either all edges are empty or non. - if in_edge.data.is_empty(): - assert ( - edge_src is outer_entry - ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." - dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) - continue - - # Because of the definition of a dependent node and the processing - # order, their incoming edges either point to the outer map or - # are already handled. - assert ( - edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." - edge_conn: str = in_edge.src_conn[4:] - - # Must be before the handling of the modification below - # Note that this will remove the original edge from the SDFG. - dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_src=inner_entry, - new_src_conn="OUT_" + edge_conn, - ) - - # In a valid SDFG only one edge can go into an input connector of a Map. - if "IN_" + edge_conn in inner_entry.in_connectors: - # We have found this edge multiple times already. - # To ensure that there is no error, we will create a new - # Memlet that reads the whole array. - piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) - data_name = piping_edge.data.data - piping_edge.data = dace.Memlet.from_array( - data_name, sdfg.arrays[data_name], piping_edge.data.wcr - ) - - else: - # This is the first time we found this connection. - # so we just create the edge. - state.add_edge( - outer_entry, - "OUT_" + edge_conn, - inner_entry, - "IN_" + edge_conn, - copy.deepcopy(in_edge.data), - ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) - - # In certain cases it might happen that we need to create an empty - # Memlet between the outer map entry and the inner one. - if state.in_degree(inner_entry) == 0: - state.add_edge( - outer_entry, - None, - inner_entry, - None, - dace.Memlet(), - ) - - # Handle the Map exits - # This is simple reconnecting, there would be possibilities for improvements - # but we do not use them for now. - for in_edge in state.in_edges(outer_exit): - edge_conn = in_edge.dst_conn[3:] - dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_dst=inner_exit, - new_dst_conn="IN_" + edge_conn, - ) - state.add_edge( - inner_exit, - "OUT_" + edge_conn, - outer_exit, - in_edge.dst_conn, - copy.deepcopy(in_edge.data), - ) - inner_exit.add_in_connector("IN_" + edge_conn) - inner_exit.add_out_connector("OUT_" + edge_conn) - - # TODO(phimuell): Use a less expensive method. - dace.sdfg.propagation.propagate_memlets_state(sdfg, state) - def _prepare_inner_outer_maps( self, state: dace.SDFGState, @@ -425,7 +251,6 @@ def partition_map_output( - Currently this function only considers the input Memlets and the `used_symbol` properties of a Tasklet to determine if a Tasklet is dependent. """ - outer_entry: dace_nodes.MapEntry = self.outer_entry independent_nodes: set[dace_nodes.Node] = set() # `\mathcal{I}` while True: @@ -433,7 +258,7 @@ def partition_map_output( # - All nodes adjacent to `outer_entry` # - All nodes adjacent to independent nodes. nodes_to_classify: set[dace_nodes.Node] = { - edge.dst for edge in state.out_edges(outer_entry) + edge.dst for edge in state.out_edges(self.outer_entry) } for independent_node in independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) @@ -444,7 +269,6 @@ def partition_map_output( for node_to_classify in nodes_to_classify: class_res = self._classify_node( node_to_classify=node_to_classify, - outer_entry=outer_entry, independent_nodes=independent_nodes, state=state, sdfg=sdfg, @@ -463,7 +287,9 @@ def partition_map_output( # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. dependent_nodes: set[dace_nodes.Node] = { - edge.dst for edge in state.out_edges(outer_entry) if edge.dst not in independent_nodes + edge.dst + for edge in state.out_edges(self.outer_entry) + if edge.dst not in independent_nodes } return (independent_nodes, dependent_nodes) @@ -471,7 +297,6 @@ def partition_map_output( def _classify_node( self, node_to_classify: dace_nodes.Node, - outer_entry: dace_nodes.MapEntry, independent_nodes: set[dace_nodes.Node], state: dace.SDFGState, sdfg: dace.SDFG, @@ -485,7 +310,7 @@ def _classify_node( - All incoming _empty_ edges must be connected to the map entry. - A node either has only empty Memlets or none of them. - Incoming Memlets does not depend on the blocking parameter. - - All incoming edges must start either at `outer_entry` or at dependent nodes. + - All incoming edges must start either at `self.outer_entry` or at dependent nodes. - All output Memlets are non empty. Returns: @@ -497,12 +322,12 @@ def _classify_node( Args: node_to_classify: The node that should be classified. - outer_entry: The entry of the map that should be partitioned. independent_nodes: The set of nodes that was already classified as independent, in which case it is added to `independent_nodes`. state: The state containing the map. sdfg: The SDFG that is processed. """ + outer_entry: dace_nodes.MapEntry = self.outer_entry # for caching. # We are only able to handle certain kind of nodes, so screening them. if isinstance(node_to_classify, dace_nodes.Tasklet): @@ -601,3 +426,177 @@ def _classify_node( # Loop ended normally, thus we classify the node as independent. independent_nodes.add(node_to_classify) return True + + +def _rewire_map_scope( + outer_entry: dace_nodes.MapEntry, + outer_exit: dace_nodes.MapExit, + inner_entry: dace_nodes.MapEntry, + inner_exit: dace_nodes.MapExit, + independent_nodes: set[dace_nodes.Node], + dependent_nodes: set[dace_nodes.Node], + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None: + """Rewire the edges inside the scope defined by the outer map. + + The function assumes that the outer and inner map were obtained by a call + to `_prepare_inner_outer_maps()`. The function will now rewire the connections of these + nodes such that the dependent nodes are inside the scope of the inner map, + while the independent nodes remain outside. + + Args: + outer_entry: The entry node of the outer map. + outer_exit: The exit node of the outer map. + inner_entry: The entry node of the inner map. + inner_exit: The exit node of the inner map. + independent_nodes: The set of independent nodes. + dependent_nodes: The set of dependent nodes. + state: The state of the map. + sdfg: The SDFG we operate on. + """ + + # Contains the nodes that are already have been handled. + relocated_nodes: set[dace_nodes.Node] = set() + + # We now handle all independent nodes, this means that all of their + # _output_ edges have to go through the new inner map and the Memlets need + # modifications, because of the block parameter. + for independent_node in independent_nodes: + for out_edge in state.out_edges(independent_node): + edge_dst: dace_nodes.Node = out_edge.dst + relocated_nodes.add(edge_dst) + + # If destination of this edge is also independent we do not need + # to handle it, because that node will also be before the new + # inner serial map. + if edge_dst in independent_nodes: + continue + + # Now split `out_edge` such that it passes through the new inner entry. + # We do not need to modify the subsets, i.e. replacing the variable + # on which we block, because the node is independent and the outgoing + # new inner map entry iterate over the blocked variable. + new_map_conn = inner_entry.next_connector() + dace_helpers.redirect_edge( + state=state, + edge=out_edge, + new_dst=inner_entry, + new_dst_conn="IN_" + new_map_conn, + ) + # TODO(phimuell): Check if there might be a subset error. + state.add_edge( + inner_entry, + "OUT_" + new_map_conn, + out_edge.dst, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + inner_entry.add_in_connector("IN_" + new_map_conn) + inner_entry.add_out_connector("OUT_" + new_map_conn) + + # Now we handle the dependent nodes, they differ from the independent nodes + # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. + for dependent_node in dependent_nodes: + for in_edge in state.in_edges(dependent_node): + edge_src: dace_nodes.Node = in_edge.src + + # Since the independent nodes were already processed, and they process + # their output we have to check for this. We do this by checking if + # the source of the edge is the new inner map entry. + if edge_src is inner_entry: + assert dependent_node in relocated_nodes + continue + + # A dependent node has at least one connection to the outer map entry. + # And these are the only connections that we must handle, since other + # connections come from independent nodes, and were already handled + # or are inner nodes. + if edge_src is not outer_entry: + continue + + # If we encounter an empty Memlet we just just attach it to the + # new inner map entry. Note the partition function ensures that + # either all edges are empty or non. + if in_edge.data.is_empty(): + assert ( + edge_src is outer_entry + ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." + dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) + continue + + # Because of the definition of a dependent node and the processing + # order, their incoming edges either point to the outer map or + # are already handled. + assert ( + edge_src is outer_entry + ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." + edge_conn: str = in_edge.src_conn[4:] + + # Must be before the handling of the modification below + # Note that this will remove the original edge from the SDFG. + dace_helpers.redirect_edge( + state=state, + edge=in_edge, + new_src=inner_entry, + new_src_conn="OUT_" + edge_conn, + ) + + # In a valid SDFG only one edge can go into an input connector of a Map. + if "IN_" + edge_conn in inner_entry.in_connectors: + # We have found this edge multiple times already. + # To ensure that there is no error, we will create a new + # Memlet that reads the whole array. + piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) + data_name = piping_edge.data.data + piping_edge.data = dace.Memlet.from_array( + data_name, sdfg.arrays[data_name], piping_edge.data.wcr + ) + + else: + # This is the first time we found this connection. + # so we just create the edge. + state.add_edge( + outer_entry, + "OUT_" + edge_conn, + inner_entry, + "IN_" + edge_conn, + copy.deepcopy(in_edge.data), + ) + inner_entry.add_in_connector("IN_" + edge_conn) + inner_entry.add_out_connector("OUT_" + edge_conn) + + # In certain cases it might happen that we need to create an empty + # Memlet between the outer map entry and the inner one. + if state.in_degree(inner_entry) == 0: + state.add_edge( + outer_entry, + None, + inner_entry, + None, + dace.Memlet(), + ) + + # Handle the Map exits + # This is simple reconnecting, there would be possibilities for improvements + # but we do not use them for now. + for in_edge in state.in_edges(outer_exit): + edge_conn = in_edge.dst_conn[3:] + dace_helpers.redirect_edge( + state=state, + edge=in_edge, + new_dst=inner_exit, + new_dst_conn="IN_" + edge_conn, + ) + state.add_edge( + inner_exit, + "OUT_" + edge_conn, + outer_exit, + in_edge.dst_conn, + copy.deepcopy(in_edge.data), + ) + inner_exit.add_in_connector("IN_" + edge_conn) + inner_exit.add_out_connector("OUT_" + edge_conn) + + # TODO(phimuell): Use a less expensive method. + dace.sdfg.propagation.propagate_memlets_state(sdfg, state) From 3e7d09fa4f4506c522e76bccc1c86b60f1694ed3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 14:36:08 +0200 Subject: [PATCH 229/235] Updated Edoardo's comments. --- .../ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md | 8 ++++---- .../runners/dace_fieldview/gtir_to_sdfg.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md index 679bfba00f..cdefe0022c 100644 --- a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -69,7 +69,7 @@ For the SDFG state machine we assume that: #### Transients The rules we impose on transients are a bit more complicated, however, while sounding restrictive, they are very permissive. -It is important to note that these rules only have to be met after after _simplify()_ was called once on the SDFG: +It is important to note that these rules only have to be met after _simplify()_ was called once on the SDFG: 6. Downstream of a write access, i.e., in all states that follow the state where the access node is located, there are no other access nodes that are used to write to the same array. @@ -102,7 +102,7 @@ It is important to note that these rules only have to be met after after _simpli - [Note]: To prevent some issues caused by the violation of rule 4 by _simplify()_, this set is extended with the transient sink nodes and all scalars. Excess interstate transients, that will be kept alive that way, will be removed by later calls to _simplify()_. -10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should _preferably_ be `dace.dtypes.StorageType.Register`. +10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should either be `dace.dtypes.StorageType.Default` or _preferably_ `dace.dtypes.StorageType.Register`. - [Rationale 1]: This makes optimizations operating inside maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside. - [Rationale 2]: The storage type avoids the need to dynamically allocate memory inside a kernel. @@ -112,7 +112,7 @@ For maps we assume the following: 11. The names of map variables (iteration variables) follow the following pattern. - - 11.1: All map variables iterating over the same dimension (disregarding the actual range), have the same deterministic name, that includes the `gtx.Dimension.value` string. + - 11.1: All map variables iterating over the same dimension (disregarding the actual range) have the same deterministic name, that includes the `gtx.Dimension.value` string. - 11.2: The name of horizontal dimensions (`kind` attribute) always end in `__gtx_horizontal`. - 11.3: The name of vertical dimensions (`kind` attribute) always end in `__gtx_vertical`. - 11.4: The name of local dimensions always ends in `__gtx_localdim`. @@ -121,7 +121,7 @@ For maps we assume the following: 12. Two map ranges, i.e. the pair map/iteration variable and range, can only be fused if they have the same name _and_ cover the same range. - [Rationale 1]: Because of rule 11, we will only fuse maps that actually makes sense to fuse. - - [Rationale 2]: This allows to fusing maps without performing a renaming on the map variables. + - [Rationale 2]: This allows fusing maps without renaming the map variables. - [Note]: This rule might be dropped in the future. ## Consequences diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 75a5aa07e3..583cdf9a8b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -30,6 +30,7 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_builtin_translators, gtir_to_tasklet, + transformations as gtx_transformations, utility as dace_fieldview_util, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -531,5 +532,5 @@ def build_sdfg_from_gtir( # we can remove unnecesssary data connectors (not done by dace simplify pass) sdfg.apply_transformations_repeated(dace_dataflow.PruneConnectors) - sdfg.simplify() + gtx_transformations.gt_simplify(sdfg) return sdfg From 7cb5e3561830ff93a0266e0f17a993c81873c439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 4 Sep 2024 14:39:18 +0200 Subject: [PATCH 230/235] Update docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- .../ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md index cdefe0022c..18b9c1f878 100644 --- a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -13,7 +13,7 @@ Their main intent is to reduce the complexity of the GT4Py specific transformati ## Context -The canonical is outlined in this document was mainly designed from the perspective of the optimization pipeline. +The canonical form outlined in this document was mainly designed from the perspective of the optimization pipeline. Thus it emphasizes a form that can be handled in a simple and efficient way by a transformation. In the pipeline we distinguish between: From 04b652ed88cecacaac2425c0827b26351d1982aa Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 18:33:46 +0200 Subject: [PATCH 231/235] Refactored the loop blocking transformation. --- .../transformations/loop_blocking.py | 380 +++++++++--------- 1 file changed, 198 insertions(+), 182 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index 497a365f75..e0ac3ccb90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -59,6 +59,22 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) + independent_nodes = dace_properties.Property( + dtype=set, + allow_none=True, + default=None, + optional=True, + optional_condition=lambda _: False, + desc="Set of nodes that are independent of the blocking parameter.", + ) + dependent_nodes = dace_properties.Property( + dtype=set, + allow_none=True, + default=None, + optional=True, + optional_condition=lambda _: False, + desc="Set of nodes that are dependent on the blocking parameter.", + ) outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) @@ -111,7 +127,7 @@ def can_be_applied( return False if map_range[map_params.index(block_var)][2] != 1: return False - if self.partition_map_output(graph, sdfg) is None: + if not self.partition_map_output(graph, sdfg): return False return True @@ -127,23 +143,25 @@ def apply( """ # Now compute the partitions of the nodes. - independent_nodes, dependent_nodes = self.partition_map_output(graph, sdfg) # type: ignore[misc] # Guaranteed to be not `None`. + self.partition_map_output(graph, sdfg) # type: ignore[misc] # Guaranteed to be not `None`. # Modify the outer map and create the inner map. (outer_entry, outer_exit), (inner_entry, inner_exit) = self._prepare_inner_outer_maps(graph) # Reconnect the edges - _rewire_map_scope( + self._rewire_map_scope( outer_entry=outer_entry, outer_exit=outer_exit, inner_entry=inner_entry, inner_exit=inner_exit, - independent_nodes=independent_nodes, - dependent_nodes=dependent_nodes, state=graph, sdfg=sdfg, ) + # Clear the old partitions + self.independent_nodes = None + self.dependent_nodes = None + def _prepare_inner_outer_maps( self, state: dace.SDFGState, @@ -211,10 +229,10 @@ def partition_map_output( self, state: dace.SDFGState, sdfg: dace.SDFG, - ) -> tuple[set[dace_nodes.Node], set[dace_nodes.Node]] | None: - """Partition the of the nodes of the Map. + ) -> bool: + """Computes the partition the of the nodes of the Map. - The outputs will be two sets, defined as: + The function divides the nodes into two sets, defined as: - The independent nodes `\mathcal{I}`: These are the nodes, whose output does not depend on the blocked dimension. These nodes can be relocated between the outer and inner map. @@ -240,7 +258,9 @@ def partition_map_output( that are inside the scope and are not classified as dependent or independent are known as "inner nodes". - In case the function fails to compute the partition `None` is returned. + If the function is able to compute the partition `True` is returned and the + member variables are updated. If the partition does not exists the function + will return `False` and the respective member variables will be `None`. Args: state: The state on which we operate. @@ -251,7 +271,10 @@ def partition_map_output( - Currently this function only considers the input Memlets and the `used_symbol` properties of a Tasklet to determine if a Tasklet is dependent. """ - independent_nodes: set[dace_nodes.Node] = set() # `\mathcal{I}` + + # Clear the previous partition. + self.independent_nodes = set() + self.dependent_nodes = None while True: # Find all the nodes that we have to classify in this iteration. @@ -260,23 +283,23 @@ def partition_map_output( nodes_to_classify: set[dace_nodes.Node] = { edge.dst for edge in state.out_edges(self.outer_entry) } - for independent_node in independent_nodes: + for independent_node in self.independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) - nodes_to_classify.difference_update(independent_nodes) + nodes_to_classify.difference_update(self.independent_nodes) # Now classify each node found_new_independent_node = False for node_to_classify in nodes_to_classify: class_res = self._classify_node( node_to_classify=node_to_classify, - independent_nodes=independent_nodes, state=state, sdfg=sdfg, ) # Check if the partition exists. if class_res is None: - return None + self.independent_nodes = None + return False if class_res is True: found_new_independent_node = True @@ -286,18 +309,17 @@ def partition_map_output( # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. - dependent_nodes: set[dace_nodes.Node] = { + self.dependent_nodes = { edge.dst for edge in state.out_edges(self.outer_entry) - if edge.dst not in independent_nodes + if edge.dst not in self.independent_nodes } - return (independent_nodes, dependent_nodes) + return True def _classify_node( self, node_to_classify: dace_nodes.Node, - independent_nodes: set[dace_nodes.Node], state: dace.SDFGState, sdfg: dace.SDFG, ) -> bool | None: @@ -315,15 +337,13 @@ def _classify_node( Returns: The function returns `True` if `node_to_classify` is considered independent. - In this case the function will add the node to `independent_nodes`. + In this case the function will add the node to `self.independent_nodes`. If the function returns `False` the node was classified as a dependent node. The function will return `None` if the node can not be classified, in this case the partition does not exist. Args: node_to_classify: The node that should be classified. - independent_nodes: The set of nodes that was already classified as - independent, in which case it is added to `independent_nodes`. state: The state containing the map. sdfg: The SDFG that is processed. """ @@ -420,183 +440,179 @@ def _classify_node( # The edge must either originate from `outer_entry` or from an independent # node if not it is dependent. - if not (in_edge.src is outer_entry or in_edge.src in independent_nodes): + if not (in_edge.src is outer_entry or in_edge.src in self.independent_nodes): return False # Loop ended normally, thus we classify the node as independent. - independent_nodes.add(node_to_classify) + self.independent_nodes.add(node_to_classify) return True + def _rewire_map_scope( + self, + outer_entry: dace_nodes.MapEntry, + outer_exit: dace_nodes.MapExit, + inner_entry: dace_nodes.MapEntry, + inner_exit: dace_nodes.MapExit, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> None: + """Rewire the edges inside the scope defined by the outer map. -def _rewire_map_scope( - outer_entry: dace_nodes.MapEntry, - outer_exit: dace_nodes.MapExit, - inner_entry: dace_nodes.MapEntry, - inner_exit: dace_nodes.MapExit, - independent_nodes: set[dace_nodes.Node], - dependent_nodes: set[dace_nodes.Node], - state: dace.SDFGState, - sdfg: dace.SDFG, -) -> None: - """Rewire the edges inside the scope defined by the outer map. + The function assumes that the outer and inner map were obtained by a call + to `_prepare_inner_outer_maps()`. The function will now rewire the connections of these + nodes such that the dependent nodes are inside the scope of the inner map, + while the independent nodes remain outside. - The function assumes that the outer and inner map were obtained by a call - to `_prepare_inner_outer_maps()`. The function will now rewire the connections of these - nodes such that the dependent nodes are inside the scope of the inner map, - while the independent nodes remain outside. + Args: + outer_entry: The entry node of the outer map. + outer_exit: The exit node of the outer map. + inner_entry: The entry node of the inner map. + inner_exit: The exit node of the inner map. + state: The state of the map. + sdfg: The SDFG we operate on. + """ - Args: - outer_entry: The entry node of the outer map. - outer_exit: The exit node of the outer map. - inner_entry: The entry node of the inner map. - inner_exit: The exit node of the inner map. - independent_nodes: The set of independent nodes. - dependent_nodes: The set of dependent nodes. - state: The state of the map. - sdfg: The SDFG we operate on. - """ + # Contains the nodes that are already have been handled. + relocated_nodes: set[dace_nodes.Node] = set() + + # We now handle all independent nodes, this means that all of their + # _output_ edges have to go through the new inner map and the Memlets need + # modifications, because of the block parameter. + for independent_node in self.independent_nodes: + for out_edge in state.out_edges(independent_node): + edge_dst: dace_nodes.Node = out_edge.dst + relocated_nodes.add(edge_dst) + + # If destination of this edge is also independent we do not need + # to handle it, because that node will also be before the new + # inner serial map. + if edge_dst in self.independent_nodes: + continue + + # Now split `out_edge` such that it passes through the new inner entry. + # We do not need to modify the subsets, i.e. replacing the variable + # on which we block, because the node is independent and the outgoing + # new inner map entry iterate over the blocked variable. + new_map_conn = inner_entry.next_connector() + dace_helpers.redirect_edge( + state=state, + edge=out_edge, + new_dst=inner_entry, + new_dst_conn="IN_" + new_map_conn, + ) + # TODO(phimuell): Check if there might be a subset error. + state.add_edge( + inner_entry, + "OUT_" + new_map_conn, + out_edge.dst, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + inner_entry.add_in_connector("IN_" + new_map_conn) + inner_entry.add_out_connector("OUT_" + new_map_conn) + + # Now we handle the dependent nodes, they differ from the independent nodes + # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. + for dependent_node in self.dependent_nodes: + for in_edge in state.in_edges(dependent_node): + edge_src: dace_nodes.Node = in_edge.src + + # Since the independent nodes were already processed, and they process + # their output we have to check for this. We do this by checking if + # the source of the edge is the new inner map entry. + if edge_src is inner_entry: + assert dependent_node in relocated_nodes + continue + + # A dependent node has at least one connection to the outer map entry. + # And these are the only connections that we must handle, since other + # connections come from independent nodes, and were already handled + # or are inner nodes. + if edge_src is not outer_entry: + continue + + # If we encounter an empty Memlet we just just attach it to the + # new inner map entry. Note the partition function ensures that + # either all edges are empty or non. + if in_edge.data.is_empty(): + assert ( + edge_src is outer_entry + ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." + dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) + continue + + # Because of the definition of a dependent node and the processing + # order, their incoming edges either point to the outer map or + # are already handled. + assert ( + edge_src is outer_entry + ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." + edge_conn: str = in_edge.src_conn[4:] - # Contains the nodes that are already have been handled. - relocated_nodes: set[dace_nodes.Node] = set() - - # We now handle all independent nodes, this means that all of their - # _output_ edges have to go through the new inner map and the Memlets need - # modifications, because of the block parameter. - for independent_node in independent_nodes: - for out_edge in state.out_edges(independent_node): - edge_dst: dace_nodes.Node = out_edge.dst - relocated_nodes.add(edge_dst) - - # If destination of this edge is also independent we do not need - # to handle it, because that node will also be before the new - # inner serial map. - if edge_dst in independent_nodes: - continue + # Must be before the handling of the modification below + # Note that this will remove the original edge from the SDFG. + dace_helpers.redirect_edge( + state=state, + edge=in_edge, + new_src=inner_entry, + new_src_conn="OUT_" + edge_conn, + ) - # Now split `out_edge` such that it passes through the new inner entry. - # We do not need to modify the subsets, i.e. replacing the variable - # on which we block, because the node is independent and the outgoing - # new inner map entry iterate over the blocked variable. - new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=state, - edge=out_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, - ) - # TODO(phimuell): Check if there might be a subset error. + # In a valid SDFG only one edge can go into an input connector of a Map. + if "IN_" + edge_conn in inner_entry.in_connectors: + # We have found this edge multiple times already. + # To ensure that there is no error, we will create a new + # Memlet that reads the whole array. + piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) + data_name = piping_edge.data.data + piping_edge.data = dace.Memlet.from_array( + data_name, sdfg.arrays[data_name], piping_edge.data.wcr + ) + + else: + # This is the first time we found this connection. + # so we just create the edge. + state.add_edge( + outer_entry, + "OUT_" + edge_conn, + inner_entry, + "IN_" + edge_conn, + copy.deepcopy(in_edge.data), + ) + inner_entry.add_in_connector("IN_" + edge_conn) + inner_entry.add_out_connector("OUT_" + edge_conn) + + # In certain cases it might happen that we need to create an empty + # Memlet between the outer map entry and the inner one. + if state.in_degree(inner_entry) == 0: state.add_edge( + outer_entry, + None, inner_entry, - "OUT_" + new_map_conn, - out_edge.dst, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), + None, + dace.Memlet(), ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) - - # Now we handle the dependent nodes, they differ from the independent nodes - # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. - for dependent_node in dependent_nodes: - for in_edge in state.in_edges(dependent_node): - edge_src: dace_nodes.Node = in_edge.src - - # Since the independent nodes were already processed, and they process - # their output we have to check for this. We do this by checking if - # the source of the edge is the new inner map entry. - if edge_src is inner_entry: - assert dependent_node in relocated_nodes - continue - - # A dependent node has at least one connection to the outer map entry. - # And these are the only connections that we must handle, since other - # connections come from independent nodes, and were already handled - # or are inner nodes. - if edge_src is not outer_entry: - continue - - # If we encounter an empty Memlet we just just attach it to the - # new inner map entry. Note the partition function ensures that - # either all edges are empty or non. - if in_edge.data.is_empty(): - assert ( - edge_src is outer_entry - ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." - dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) - continue - - # Because of the definition of a dependent node and the processing - # order, their incoming edges either point to the outer map or - # are already handled. - assert ( - edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." - edge_conn: str = in_edge.src_conn[4:] - # Must be before the handling of the modification below - # Note that this will remove the original edge from the SDFG. + # Handle the Map exits + # This is simple reconnecting, there would be possibilities for improvements + # but we do not use them for now. + for in_edge in state.in_edges(outer_exit): + edge_conn = in_edge.dst_conn[3:] dace_helpers.redirect_edge( state=state, edge=in_edge, - new_src=inner_entry, - new_src_conn="OUT_" + edge_conn, + new_dst=inner_exit, + new_dst_conn="IN_" + edge_conn, ) + state.add_edge( + inner_exit, + "OUT_" + edge_conn, + outer_exit, + in_edge.dst_conn, + copy.deepcopy(in_edge.data), + ) + inner_exit.add_in_connector("IN_" + edge_conn) + inner_exit.add_out_connector("OUT_" + edge_conn) - # In a valid SDFG only one edge can go into an input connector of a Map. - if "IN_" + edge_conn in inner_entry.in_connectors: - # We have found this edge multiple times already. - # To ensure that there is no error, we will create a new - # Memlet that reads the whole array. - piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) - data_name = piping_edge.data.data - piping_edge.data = dace.Memlet.from_array( - data_name, sdfg.arrays[data_name], piping_edge.data.wcr - ) - - else: - # This is the first time we found this connection. - # so we just create the edge. - state.add_edge( - outer_entry, - "OUT_" + edge_conn, - inner_entry, - "IN_" + edge_conn, - copy.deepcopy(in_edge.data), - ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) - - # In certain cases it might happen that we need to create an empty - # Memlet between the outer map entry and the inner one. - if state.in_degree(inner_entry) == 0: - state.add_edge( - outer_entry, - None, - inner_entry, - None, - dace.Memlet(), - ) - - # Handle the Map exits - # This is simple reconnecting, there would be possibilities for improvements - # but we do not use them for now. - for in_edge in state.in_edges(outer_exit): - edge_conn = in_edge.dst_conn[3:] - dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_dst=inner_exit, - new_dst_conn="IN_" + edge_conn, - ) - state.add_edge( - inner_exit, - "OUT_" + edge_conn, - outer_exit, - in_edge.dst_conn, - copy.deepcopy(in_edge.data), - ) - inner_exit.add_in_connector("IN_" + edge_conn) - inner_exit.add_out_connector("OUT_" + edge_conn) - - # TODO(phimuell): Use a less expensive method. - dace.sdfg.propagation.propagate_memlets_state(sdfg, state) + # TODO(phimuell): Use a less expensive method. + dace.sdfg.propagation.propagate_memlets_state(sdfg, state) From e4df5ae29addb70971035ccb7269795a773912c9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 07:26:41 +0200 Subject: [PATCH 232/235] Fixed an merge issue with master. --- .../program_processors/runners/dace_fieldview/gtir_to_sdfg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 311afb66a9..99388336e9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -535,9 +535,5 @@ def build_sdfg_from_gtir( sdfg = sdfg_genenerator.visit(program) assert isinstance(sdfg, dace.SDFG) - # nested-SDFGs for let-lambda may contain unused symbols, in which case - # we can remove unnecesssary data connectors (not done by dace simplify pass) - sdfg.apply_transformations_repeated(dace_dataflow.PruneConnectors) - gtx_transformations.gt_simplify(sdfg) return sdfg From 7265ecc3eea1b70f00a55a438b178a41f58e7fe1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 07:26:51 +0200 Subject: [PATCH 233/235] Fixed an issue related to the refactoring yesterday evening. --- .../runners/dace_fieldview/transformations/loop_blocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index e0ac3ccb90..7acd997a0d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -143,7 +143,7 @@ def apply( """ # Now compute the partitions of the nodes. - self.partition_map_output(graph, sdfg) # type: ignore[misc] # Guaranteed to be not `None`. + self.partition_map_output(graph, sdfg) # Modify the outer map and create the inner map. (outer_entry, outer_exit), (inner_entry, inner_exit) = self._prepare_inner_outer_maps(graph) From 5ab199d70335a65fa1177aaec082a79c892eb66f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 10:49:56 +0200 Subject: [PATCH 234/235] Something is fishy. The error is the DaCe loading code. Let's use only general names to see what is happening. --- .../dace/transformation_tests/test_gpu_utils.py | 3 ++- .../dace/transformation_tests/test_loop_blocking.py | 5 +++-- .../dace/transformation_tests/test_map_fusion.py | 7 ++++--- .../dace/transformation_tests/test_serial_map_promoter.py | 3 ++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py index 8b49a430d4..1524af8151 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py @@ -11,6 +11,7 @@ import copy import numpy as np import pytest +import time from dace.sdfg import nodes as dace_nodes @@ -42,7 +43,7 @@ def _get_trivial_gpu_promotable( Args: tasklet_code: The body of the Tasklet inside the trivial map. """ - sdfg = dace.SDFG("TrivailGPUPromotable") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py index 812225d273..c472cca9d3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py @@ -11,6 +11,7 @@ import copy import numpy as np import pytest +import time from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation @@ -28,7 +29,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np can be taken out. This is because how it is constructed. However, applying some simplistic transformations this can be done. """ - sdfg = dace.SDFG("only_dependent") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -51,7 +52,7 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n The bottom Tasklet is the only dependent Tasklet. """ - sdfg = dace.SDFG("only_dependent") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py index 1c43e00efd..c0019cc48c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py @@ -12,6 +12,7 @@ import dace import copy import numpy as np +import time from dace.sdfg import nodes as dace_nodes from dace.transformation import dataflow as dace_dataflow @@ -38,7 +39,7 @@ def _make_serial_sdfg_1( N: The size of the arrays. """ shape = (N, N) - sdfg = dace.SDFG("serial_1_sdfg") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "tmp"]: @@ -93,7 +94,7 @@ def _make_serial_sdfg_2( N: The size of the arrays. """ shape = (N, N) - sdfg = dace.SDFG("serial_2_sdfg") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "c", "tmp_1", "tmp_2"]: @@ -165,7 +166,7 @@ def _make_serial_sdfg_3( input_shape = (N_input,) output_shape = (N_output,) - sdfg = dace.SDFG("serial_3_sdfg") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state(is_start_block=True) for name, shape in [ diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py index c483866a35..3727588c1c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -11,6 +11,7 @@ import copy import numpy as np import pytest +import time from dace.sdfg import nodes as dace_nodes @@ -28,7 +29,7 @@ def test_serial_map_promotion(): N = 10 shape_1d = (N,) shape_2d = (N, N) - sdfg = dace.SDFG("serial_promotable") + sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") state = sdfg.add_state(is_start_block=True) # 1D Arrays From a0866a7e66b937ebb96061d279f9b8371060c865 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 11:09:20 +0200 Subject: [PATCH 235/235] Switched to UUID from time. Also added a helper function. --- .../dace/transformation_tests/test_gpu_utils.py | 3 +-- .../dace/transformation_tests/test_loop_blocking.py | 6 +++--- .../dace/transformation_tests/test_map_fusion.py | 7 +++---- .../transformation_tests/test_serial_map_promoter.py | 3 +-- .../runners_tests/dace/transformation_tests/util.py | 10 ++++++---- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py index 1524af8151..d51c242977 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_gpu_utils.py @@ -11,7 +11,6 @@ import copy import numpy as np import pytest -import time from dace.sdfg import nodes as dace_nodes @@ -43,7 +42,7 @@ def _get_trivial_gpu_promotable( Args: tasklet_code: The body of the Tasklet inside the trivial map. """ - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py index c472cca9d3..b4fde413a1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_loop_blocking.py @@ -11,13 +11,13 @@ import copy import numpy as np import pytest -import time from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, ) +from . import util pytestmark = pytest.mark.usefixtures("set_dace_settings") @@ -29,7 +29,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np can be taken out. This is because how it is constructed. However, applying some simplistic transformations this can be done. """ - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -52,7 +52,7 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n The bottom Tasklet is the only dependent Tasklet. """ - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("chained_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py index c0019cc48c..54319b27dd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion.py @@ -12,7 +12,6 @@ import dace import copy import numpy as np -import time from dace.sdfg import nodes as dace_nodes from dace.transformation import dataflow as dace_dataflow @@ -39,7 +38,7 @@ def _make_serial_sdfg_1( N: The size of the arrays. """ shape = (N, N) - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("serial_sdfg1")) state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "tmp"]: @@ -94,7 +93,7 @@ def _make_serial_sdfg_2( N: The size of the arrays. """ shape = (N, N) - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("serial_sdfg2")) state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "c", "tmp_1", "tmp_2"]: @@ -166,7 +165,7 @@ def _make_serial_sdfg_3( input_shape = (N_input,) output_shape = (N_output,) - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("serial_sdfg3")) state = sdfg.add_state(is_start_block=True) for name, shape in [ diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py index 3727588c1c..85846ef6c8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -11,7 +11,6 @@ import copy import numpy as np import pytest -import time from dace.sdfg import nodes as dace_nodes @@ -29,7 +28,7 @@ def test_serial_map_promotion(): N = 10 shape_1d = (N,) shape_2d = (N, N) - sdfg = dace.SDFG(f"test_sdfg__{int(time.time() * 1000)}") + sdfg = dace.SDFG(util.unique_name("serial_promotable_sdfg")) state = sdfg.add_state(is_start_block=True) # 1D Arrays diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py index 739582d5d9..0caabcc7be 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py @@ -7,15 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause from typing import Union, Literal, overload +import uuid import dace from dace.sdfg import nodes as dace_nodes from dace.transformation import dataflow as dace_dataflow -__all__ = [ - "_count_nodes", -] - @overload def _count_nodes( @@ -57,3 +54,8 @@ def _count_nodes( if return_nodes: return found_nodes return len(found_nodes) + + +def unique_name(name: str) -> str: + """Adds a unique string to `name`.""" + return f"{name}_{str(uuid.uuid1()).replace('-', '_')}"