diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 0c9f68a57f..c0063e82d5 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -173,14 +173,13 @@ class SimpleTemporaryExtractionHeuristics: closure: ir.StencilClosure - @functools.cached_property - def closure_shifts( - self, - ) -> dict[int, set[tuple[ir.OffsetLiteral, ...]]]: - return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) # type: ignore[return-value] # TODO fix weird `apply` overloads + def __post_init__(self) -> None: + trace_shifts.trace_stencil( + self.closure.stencil, num_args=len(self.closure.inputs), save_to_annex=True + ) def __call__(self, expr: ir.Expr) -> bool: - shifts = self.closure_shifts[id(expr)] + shifts = expr.annex.recorded_shifts if len(shifts) > 1: return True return False @@ -564,8 +563,9 @@ def update_domains( closures.append(closure) - local_shifts = trace_shifts.TraceShifts.apply(closure) - for param, shift_chains in local_shifts.items(): + local_shifts = trace_shifts.trace_stencil(closure.stencil, num_args=len(closure.inputs)) + for param_sym, shift_chains in zip(closure.inputs, local_shifts): + param = param_sym.id assert isinstance(param, str) consumed_domains: list[SymbolicDomain] = ( [SymbolicDomain.from_expr(domains[param])] if param in domains else [] diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 2d465eb3b0..e05c58e157 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -11,8 +11,8 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union -from gt4py.next.iterator.transforms.trace_shifts import TraceShifts def _merge_domains( @@ -28,19 +28,6 @@ def _merge_domains( return new_domains -# FIXME[#1582](tehrengruber): Use new TraceShift API when #1592 is merged. -def trace_shifts( - stencil: itir.Expr, input_ids: list[str], domain: itir.Expr -) -> dict[str, set[tuple[itir.OffsetLiteral, ...]]]: - node = itir.StencilClosure( - stencil=stencil, - inputs=[im.ref(id_) for id_ in input_ids], - output=im.ref("__dummy"), - domain=domain, - ) - return TraceShifts.apply(node, inputs_only=True) # type: ignore[return-value] # ensured by inputs_only=True - - def extract_shifts_and_translate_domains( stencil: itir.Expr, input_ids: list[str], @@ -48,11 +35,9 @@ def extract_shifts_and_translate_domains( offset_provider: Dict[str, Dimension], accessed_domains: Dict[str, SymbolicDomain], ): - shifts_results = trace_shifts(stencil, input_ids, SymbolicDomain.as_expr(target_domain)) - - for in_field_id in input_ids: - shifts_list = shifts_results[in_field_id] + shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) + for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ SymbolicDomain.translate(target_domain, shift, offset_provider) for shift in shifts_list ] diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index cdea1a7a48..5def306bbc 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -14,9 +14,9 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda from gt4py.next.iterator.transforms.inline_lifts import InlineLifts -from gt4py.next.iterator.transforms.trace_shifts import TraceShifts, copy_recorded_shifts def is_center_derefed_only(node: itir.Node) -> bool: @@ -58,7 +58,7 @@ def apply(cls, node: itir.FencilDefinition, uids: Optional[eve_utils.UIDGenerato def visit_StencilClosure(self, node: itir.StencilClosure, **kwargs): # TODO(tehrengruber): move the analysis out of this pass and just make it a requirement # such that we don't need to run in multiple times if multiple passes use it. - TraceShifts.apply(node, save_to_annex=True) + trace_shifts.trace_stencil(node.stencil, num_args=len(node.inputs), save_to_annex=True) return self.generic_visit(node, **kwargs) def visit_FunCall(self, node: itir.FunCall, **kwargs): @@ -74,7 +74,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): eligible_params[i] = True bound_arg_name = self.uids.sequential_id(prefix="_icdlv") capture_lift = im.promote_to_const_iterator(bound_arg_name) - copy_recorded_shifts(from_=param, to=capture_lift) + trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) new_args.append(capture_lift) # since we deref an applied lift here we can (but don't need to) immediately # inline diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index d181940b1d..68346b6622 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -9,11 +9,12 @@ import dataclasses import sys from collections.abc import Callable -from typing import Any, Final, Iterable, Literal +from typing import Any, Final, Iterable, Literal, Optional from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -76,7 +77,7 @@ def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]) # for performance reasons (`isinstance` is slow otherwise) we don't use abc here -class IteratorTracer: +class Tracer: def deref(self): raise NotImplementedError() @@ -85,13 +86,13 @@ def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): @dataclasses.dataclass(frozen=True) -class IteratorArgTracer(IteratorTracer): +class ArgTracer(Tracer): arg: ir.Expr | ir.Sym shift_recorder: ShiftRecorder | ForwardingShiftRecorder offsets: tuple[ir.OffsetLiteral, ...] = () def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): - return IteratorArgTracer( + return ArgTracer( arg=self.arg, shift_recorder=self.shift_recorder, offsets=self.offsets + tuple(offsets) ) @@ -103,8 +104,8 @@ def deref(self): # This class is only needed because we currently allow conditionals on iterators. Since this is # not supported in the C++ backend it can likely be removed again in the future. @dataclasses.dataclass(frozen=True) -class CombinedTracer(IteratorTracer): - its: tuple[IteratorTracer, ...] +class CombinedTracer(Tracer): + its: tuple[Tracer, ...] def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): return CombinedTracer(tuple(_shift(*offsets)(it) for it in self.its)) @@ -142,16 +143,16 @@ def _shift(*offsets): ) def apply(arg): - assert isinstance(arg, IteratorTracer) + assert isinstance(arg, Tracer) return arg.shift(offsets) return apply @dataclasses.dataclass(frozen=True) -class AppliedLift(IteratorTracer): +class AppliedLift(Tracer): stencil: Callable - its: tuple[IteratorTracer, ...] + its: tuple[Tracer, ...] def shift(self, offsets): return AppliedLift(self.stencil, tuple(_shift(*offsets)(it) for it in self.its)) @@ -162,7 +163,7 @@ def deref(self): def _lift(f): def apply(*its): - if not all(isinstance(it, IteratorTracer) for it in its): + if not all(isinstance(it, Tracer) for it in its): raise AssertionError("All arguments must be iterators.") return AppliedLift(f, its) @@ -189,20 +190,18 @@ def apply(*args): def _primitive_constituents( - val: Literal[Sentinel.VALUE] | IteratorTracer | tuple, -) -> Iterable[Literal[Sentinel.VALUE] | IteratorTracer]: - if val is Sentinel.VALUE or isinstance(val, IteratorTracer): + val: Literal[Sentinel.VALUE] | Tracer | tuple, +) -> Iterable[Literal[Sentinel.VALUE] | Tracer]: + if val is Sentinel.VALUE or isinstance(val, Tracer): yield val elif isinstance(val, tuple): for el in val: if isinstance(el, tuple): yield from _primitive_constituents(el) - elif el is Sentinel.VALUE or isinstance(el, IteratorTracer): + elif el is Sentinel.VALUE or isinstance(el, Tracer): yield el else: - raise AssertionError( - "Expected a `Sentinel.VALUE`, `IteratorTracer` or tuple thereof." - ) + raise AssertionError("Expected a `Sentinel.VALUE`, `Tracer` or tuple thereof.") else: raise ValueError() @@ -225,9 +224,7 @@ def _if(cond: Literal[Sentinel.VALUE], true_branch, false_branch): result.append(_if(Sentinel.VALUE, el_true_branch, el_false_branch)) return tuple(result) - is_iterator_arg = tuple( - isinstance(arg, IteratorTracer) for arg in (cond, true_branch, false_branch) - ) + is_iterator_arg = tuple(isinstance(arg, Tracer) for arg in (cond, true_branch, false_branch)) if is_iterator_arg == (False, True, True): return CombinedTracer((true_branch, false_branch)) assert is_iterator_arg == (False, False, False) and all( @@ -247,7 +244,15 @@ def _tuple_get(index, tuple_val): return Sentinel.VALUE +def _as_fieldop(stencil, domain=None): + def applied_as_fieldop(*args): + return stencil(*args) + + return applied_as_fieldop + + _START_CTX: Final = { + "as_fieldop": _as_fieldop, "deref": _deref, "can_deref": _can_deref, "shift": _shift, @@ -291,11 +296,11 @@ def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any: def visit(self, node, **kwargs): result = super().visit(node, **kwargs) - if isinstance(result, IteratorTracer): + if isinstance(result, Tracer): assert isinstance(node, (ir.Sym, ir.Expr)) self.shift_recorder.register_node(node) - result = IteratorArgTracer( + result = ArgTracer( arg=node, shift_recorder=ForwardingShiftRecorder(result, self.shift_recorder) ) return result @@ -304,10 +309,10 @@ def visit_Lambda(self, node: ir.Lambda, *, ctx: dict[str, Any]) -> Callable: def fun(*args): new_args = [] for param, arg in zip(node.params, args, strict=True): - if isinstance(arg, IteratorTracer): + if isinstance(arg, Tracer): self.shift_recorder.register_node(param) new_args.append( - IteratorArgTracer( + ArgTracer( arg=param, shift_recorder=ForwardingShiftRecorder(arg, self.shift_recorder), ) @@ -321,46 +326,49 @@ def fun(*args): return fun - def visit_StencilClosure(self, node: ir.StencilClosure): - tracers = [] - for inp in node.inputs: - self.shift_recorder.register_node(inp) - tracers.append(IteratorArgTracer(arg=inp, shift_recorder=self.shift_recorder)) - - result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) - assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) - return node - @classmethod - def apply( - cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False - ) -> ( - dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]] + def trace_stencil( + cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False ): + # If we get a lambda we can deduce the number of arguments. + if isinstance(stencil, ir.Lambda): + assert num_args is None or num_args == len(stencil.params) + num_args = len(stencil.params) + if not isinstance(num_args, int): + raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") + assert isinstance(num_args, int) + + args = [im.ref(f"__arg{i}") for i in range(num_args)] + old_recursionlimit = sys.getrecursionlimit() sys.setrecursionlimit(100000000) instance = cls() - instance.visit(node) + + # initialize shift recorder & context with all built-ins and the iterator argument tracers + ctx: dict[str, Any] = {**_START_CTX} + for arg in args: + instance.shift_recorder.register_node(arg) + ctx[arg.id] = ArgTracer(arg=arg, shift_recorder=instance.shift_recorder) + + # actually trace stencil + instance.visit(im.call(stencil)(*args), ctx=ctx) sys.setrecursionlimit(old_recursionlimit) recorded_shifts = instance.shift_recorder.recorded_shifts + param_shifts = [] + for arg in args: + param_shifts.append(recorded_shifts[id(arg)]) + if save_to_annex: - _save_to_annex(node, recorded_shifts) + _save_to_annex(stencil, recorded_shifts) - if __debug__: - ValidateRecordedShiftsAnnex().visit(node) + return param_shifts - if inputs_only: - assert isinstance(node, ir.StencilClosure) - inputs_shifts = {} - for inp in node.inputs: - inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)] - return inputs_shifts - return recorded_shifts +trace_stencil = TraceShifts.trace_stencil def _save_to_annex( @@ -369,3 +377,6 @@ def _save_to_annex( for child_node in node.pre_walk_values(): if id(child_node) in recorded_shifts: child_node.annex.recorded_shifts = recorded_shifts[id(child_node)] + + if __debug__: + ValidateRecordedShiftsAnnex().visit(node) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 300fc5bdaa..da62ebfe92 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -395,13 +395,15 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: self.itir, offset_provider=offset_provider ) for closure in itir_tmp.closures: # type: ignore[union-attr] - shifts = itir_transforms.trace_shifts.TraceShifts.apply(closure) - for k, v in shifts.items(): - if not isinstance(k, str): + params_shifts = itir_transforms.trace_shifts.trace_stencil( + closure.stencil, num_args=len(closure.inputs) + ) + for param, shifts in zip(closure.inputs, params_shifts): + if not isinstance(param.id, str): continue - if k not in sdfg.gt4py_program_input_fields: + if param.id not in sdfg.gt4py_program_input_fields: continue - sdfg.offset_providers_per_input_field.setdefault(k, []).extend(list(v)) + sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) return sdfg diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index d2dc3f6053..1cf662e221 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -11,163 +11,77 @@ from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts -def test_trivial(): - testee = ir.StencilClosure( - stencil=ir.SymRef(id="deref"), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {()}} +def test_trivial_stencil(): + expected = [{()}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(im.ref("deref"), num_args=1) assert actual == expected def test_shift(): - testee = ir.StencilClosure( - stencil=ir.Lambda( - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], - ), - args=[ir.SymRef(id="x")], - ) - ], - ), - params=[ir.Sym(id="x")], - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} + testee = im.lambda_("inp")(im.deref(im.shift("I", 1)("inp"))) + expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_lift(): - testee = ir.StencilClosure( - stencil=ir.Lambda( - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[ir.SymRef(id="deref")]), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], - ), - args=[ir.SymRef(id="x")], - ) - ], - ) - ], - ), - params=[ir.Sym(id="x")], - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} + testee = im.lambda_("inp")(im.deref(im.lift("deref")(im.shift("I", 1)("inp")))) + expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_neighbors(): - testee = ir.StencilClosure( - stencil=ir.Lambda( - expr=ir.FunCall( - fun=ir.SymRef(id="neighbors"), args=[ir.OffsetLiteral(value="O"), ir.SymRef(id="x")] - ), - params=[ir.Sym(id="x")], - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="O"), Sentinel.ALL_NEIGHBORS)}} + testee = im.lambda_("inp")(im.neighbors("O", "inp")) + expected = [{(ir.OffsetLiteral(value="O"), Sentinel.ALL_NEIGHBORS)}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_reduce(): - testee = ir.StencilClosure( - # λ(inp) → reduce(plus, 0.)(·inp) - stencil=ir.Lambda( - params=[ir.Sym(id="inp")], - expr=ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="plus"), im.literal_from_value(0.0)], - ), - args=[ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="inp")])], - ), - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {()}} + # λ(inp) → reduce(plus, 0.)(·inp) + testee = im.lambda_("inp")(im.call(im.call("reduce")("plus", 0.0))(im.deref("inp"))) + expected = [{()}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_shifted_literal(): "Test shifting an applied lift of a stencil returning a constant / literal works." - testee = ir.StencilClosure( - # λ(x) → ·⟪Iₒ, 1ₒ⟫((↑(λ() → 1))()) - stencil=im.lambda_("x")(im.deref(im.shift("I", 1)(im.lift(im.lambda_()(1))()))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": set()} + testee = im.lambda_("inp")(im.deref(im.shift("I", 1)(im.lift(im.lambda_()(1))()))) + expected = [set()] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_tuple_get(): - testee = ir.StencilClosure( - # λ(x, y) → ·{x, y}[1] - stencil=im.lambda_("x", "y")(im.deref(im.tuple_get(1, im.make_tuple("x", "y")))), - inputs=[ir.SymRef(id="inp1"), ir.SymRef(id="inp2")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp1": set(), "inp2": {()}} # never derefed # once derefed - - actual = TraceShifts.apply(testee) + # λ(x, y) → ·{x, y}[1] + testee = im.lambda_("x", "y")(im.deref(im.tuple_get(1, im.make_tuple("x", "y")))) + expected = [ + set(), # never derefed + {()}, # once derefed + ] + + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_trace_non_closure_input_arg(): x, y = im.sym("x"), im.sym("y") - testee = ir.StencilClosure( - # λ(x) → (λ(y) → ·⟪Iₒ, 1ₒ⟫(y))(⟪Iₒ, 2ₒ⟫(x)) - stencil=im.lambda_(x)( - im.call(im.lambda_(y)(im.deref(im.shift("I", 1)("y"))))(im.shift("I", 2)("x")) - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + # λ(x) → (λ(y) → ·⟪Iₒ, 1ₒ⟫(y))(⟪Iₒ, 2ₒ⟫(x)) + testee = im.lambda_(x)( + im.call(im.lambda_(y)(im.deref(im.shift("I", 1)("y"))))(im.shift("I", 2)("x")) ) - actual = TraceShifts.apply(testee, inputs_only=False) + actual = TraceShifts.trace_stencil(testee, save_to_annex=True) - assert actual[id(x)] == { + assert x.annex.recorded_shifts == { ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -175,78 +89,59 @@ def test_trace_non_closure_input_arg(): ir.OffsetLiteral(value=1), ) } - assert actual[id(y)] == {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} + assert y.annex.recorded_shifts == {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} def test_inner_iterator(): inner_shift = im.shift("I", 1)("x") - testee = ir.StencilClosure( - # λ(x) → ·⟪Iₒ, 1ₒ⟫(⟪Iₒ, 1ₒ⟫(x)) - stencil=im.lambda_("x")(im.deref(im.shift("I", 1)(inner_shift))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) + # λ(x) → ·⟪Iₒ, 1ₒ⟫(⟪Iₒ, 1ₒ⟫(x)) + testee = im.lambda_("x")(im.deref(im.shift("I", 1)(inner_shift))) expected = {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} - actual = TraceShifts.apply(testee, inputs_only=False) - assert actual[id(inner_shift)] == expected + actual = TraceShifts.trace_stencil(testee, save_to_annex=True) + assert inner_shift.annex.recorded_shifts == expected def test_tuple_get_on_closure_input(): - testee = ir.StencilClosure( - # λ(x) → (·⟪Iₒ, 1ₒ⟫(x))[0] - stencil=im.lambda_("x")(im.tuple_get(0, im.deref(im.shift("I", 1)("x")))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} + # λ(x) → (·⟪Iₒ, 1ₒ⟫(x))[0] + testee = im.lambda_("x")(im.tuple_get(0, im.deref(im.shift("I", 1)("x")))) + expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_if_tuple_branch_broadcasting(): - testee = ir.StencilClosure( - # λ(cond, inp) → (if ·cond then ·inp else {1, 2})[1] - stencil=im.lambda_("cond", "inp")( - im.tuple_get( - 1, - im.if_( - im.deref("cond"), - im.deref("inp"), - im.make_tuple(im.literal_from_value(1), im.literal_from_value(2)), - ), - ) - ), - inputs=[ir.SymRef(id="cond"), ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + # λ(cond, inp) → (if ·cond then ·inp else {1, 2})[1] + testee = im.lambda_("cond", "inp")( + im.tuple_get( + 1, + im.if_( + im.deref("cond"), + im.deref("inp"), + im.make_tuple(im.literal_from_value(1), im.literal_from_value(2)), + ), + ) ) - expected = {"cond": {()}, "inp": {()}} + expected = [ + {()}, # cond + {()}, # inp + ] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_if_of_iterators(): - testee = ir.StencilClosure( - # λ(cond, x) → ·⟪Iₒ, 1ₒ⟫(if ·cond then ⟪Iₒ, 2ₒ⟫(x) else ⟪Iₒ, 3ₒ⟫(x)) - stencil=im.lambda_("cond", "x")( - im.deref( - im.shift("I", 1)( - im.if_(im.deref("cond"), im.shift("I", 2)("x"), im.shift("I", 3)("x")) - ) - ) - ), - inputs=[ir.SymRef(id="cond"), ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + # λ(cond, x) → ·⟪Iₒ, 1ₒ⟫(if ·cond then ⟪Iₒ, 2ₒ⟫(x) else ⟪Iₒ, 3ₒ⟫(x)) + testee = im.lambda_("cond", "x")( + im.deref( + im.shift("I", 1)(im.if_(im.deref("cond"), im.shift("I", 2)("x"), im.shift("I", 3)("x"))) + ) ) - expected = { - "cond": {()}, - "inp": { + expected = [ + {()}, # cond + { # inp ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -260,37 +155,32 @@ def test_if_of_iterators(): ir.OffsetLiteral(value=1), ), }, - } + ] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_if_of_tuples_of_iterators(): - testee = ir.StencilClosure( - # λ(cond, x) → - # ·⟪Iₒ, 1ₒ⟫((if ·cond then {⟪Iₒ, 2ₒ⟫(x), ⟪Iₒ, 3ₒ⟫(x)} else {⟪Iₒ, 4ₒ⟫(x), ⟪Iₒ, 5ₒ⟫(x)})[0]) - stencil=im.lambda_("cond", "x")( - im.deref( - im.shift("I", 1)( - im.tuple_get( - 0, - im.if_( - im.deref("cond"), - im.make_tuple(im.shift("I", 2)("x"), im.shift("I", 3)("x")), - im.make_tuple(im.shift("I", 4)("x"), im.shift("I", 5)("x")), - ), - ) + # λ(cond, x) → + # ·⟪Iₒ, 1ₒ⟫((if ·cond then {⟪Iₒ, 2ₒ⟫(x), ⟪Iₒ, 3ₒ⟫(x)} else {⟪Iₒ, 4ₒ⟫(x), ⟪Iₒ, 5ₒ⟫(x)})[0]) + testee = im.lambda_("cond", "x")( + im.deref( + im.shift("I", 1)( + im.tuple_get( + 0, + im.if_( + im.deref("cond"), + im.make_tuple(im.shift("I", 2)("x"), im.shift("I", 3)("x")), + im.make_tuple(im.shift("I", 4)("x"), im.shift("I", 5)("x")), + ), ) ) - ), - inputs=[ir.SymRef(id="cond"), ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + ) ) - expected = { - "cond": {()}, - "inp": { + expected = [ + {()}, # cond + { # inp ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -304,9 +194,9 @@ def test_if_of_tuples_of_iterators(): ir.OffsetLiteral(value=1), ), }, - } + ] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected @@ -315,13 +205,8 @@ def test_non_derefed_iterator(): Test that even if an iterator is not derefed the resulting dict has an (empty) entry for it. """ non_derefed_it = im.shift("I", 1)("x") - testee = ir.StencilClosure( - # λ(x) → (λ(non_derefed_it) → ·x)(⟪Iₒ, 1ₒ⟫(x)) - stencil=im.lambda_("x")(im.let("non_derefed_it", non_derefed_it)(im.deref("x"))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) + # λ(x) → (λ(non_derefed_it) → ·x)(⟪Iₒ, 1ₒ⟫(x)) + testee = im.lambda_("x")(im.let("non_derefed_it", non_derefed_it)(im.deref("x"))) - actual = TraceShifts.apply(testee, inputs_only=False) - assert actual[id(non_derefed_it)] == set() + actual = TraceShifts.trace_stencil(testee, save_to_annex=True) + assert non_derefed_it.annex.recorded_shifts == set()