-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Prepare TraceShift pass for GTIR #1592
Changes from 5 commits
8785f39
d0808d8
5a973e0
6ff3338
b180be9
2652d82
8f5f7dc
77c900b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,11 +15,12 @@ | |
import enum | ||
import sys | ||
from collections.abc import Callable | ||
from typing import Any, Final, Iterable, Literal | ||
from typing import Any, Final, Iterable, Literal, Optional | ||
|
||
from gt4py import eve | ||
from gt4py.eve import NodeTranslator, PreserveLocationVisitor | ||
from gt4py.next.iterator import ir | ||
from gt4py.next.iterator.ir_utils import ir_makers as im | ||
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift | ||
|
||
|
||
|
@@ -82,7 +83,7 @@ def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]) | |
|
||
|
||
# for performance reasons (`isinstance` is slow otherwise) we don't use abc here | ||
class IteratorTracer: | ||
class Tracer: | ||
def deref(self): | ||
raise NotImplementedError() | ||
|
||
|
@@ -91,13 +92,13 @@ def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): | |
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class IteratorArgTracer(IteratorTracer): | ||
class ArgTracer(Tracer): | ||
arg: ir.Expr | ir.Sym | ||
shift_recorder: ShiftRecorder | ForwardingShiftRecorder | ||
offsets: tuple[ir.OffsetLiteral, ...] = () | ||
|
||
def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): | ||
return IteratorArgTracer( | ||
return ArgTracer( | ||
arg=self.arg, shift_recorder=self.shift_recorder, offsets=self.offsets + tuple(offsets) | ||
) | ||
|
||
|
@@ -109,8 +110,8 @@ def deref(self): | |
# This class is only needed because we currently allow conditionals on iterators. Since this is | ||
# not supported in the C++ backend it can likely be removed again in the future. | ||
@dataclasses.dataclass(frozen=True) | ||
class CombinedTracer(IteratorTracer): | ||
its: tuple[IteratorTracer, ...] | ||
class CombinedTracer(Tracer): | ||
its: tuple[Tracer, ...] | ||
|
||
def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): | ||
return CombinedTracer(tuple(_shift(*offsets)(it) for it in self.its)) | ||
|
@@ -143,16 +144,16 @@ def _can_deref(x): | |
|
||
def _shift(*offsets): | ||
def apply(arg): | ||
assert isinstance(arg, IteratorTracer) | ||
assert isinstance(arg, Tracer) | ||
return arg.shift(offsets) | ||
|
||
return apply | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class AppliedLift(IteratorTracer): | ||
class AppliedLift(Tracer): | ||
stencil: Callable | ||
its: tuple[IteratorTracer, ...] | ||
its: tuple[Tracer, ...] | ||
|
||
def shift(self, offsets): | ||
return AppliedLift(self.stencil, tuple(_shift(*offsets)(it) for it in self.its)) | ||
|
@@ -163,7 +164,7 @@ def deref(self): | |
|
||
def _lift(f): | ||
def apply(*its): | ||
if not all(isinstance(it, IteratorTracer) for it in its): | ||
if not all(isinstance(it, Tracer) for it in its): | ||
raise AssertionError("All arguments must be iterators.") | ||
return AppliedLift(f, its) | ||
|
||
|
@@ -190,20 +191,18 @@ def apply(*args): | |
|
||
|
||
def _primitive_constituents( | ||
val: Literal[Sentinel.VALUE] | IteratorTracer | tuple, | ||
) -> Iterable[Literal[Sentinel.VALUE] | IteratorTracer]: | ||
if val is Sentinel.VALUE or isinstance(val, IteratorTracer): | ||
val: Literal[Sentinel.VALUE] | Tracer | tuple, | ||
) -> Iterable[Literal[Sentinel.VALUE] | Tracer]: | ||
if val is Sentinel.VALUE or isinstance(val, Tracer): | ||
yield val | ||
elif isinstance(val, tuple): | ||
for el in val: | ||
if isinstance(el, tuple): | ||
yield from _primitive_constituents(el) | ||
elif el is Sentinel.VALUE or isinstance(el, IteratorTracer): | ||
elif el is Sentinel.VALUE or isinstance(el, Tracer): | ||
yield el | ||
else: | ||
raise AssertionError( | ||
"Expected a `Sentinel.VALUE`, `IteratorTracer` or tuple thereof." | ||
) | ||
raise AssertionError("Expected a `Sentinel.VALUE`, `Tracer` or tuple thereof.") | ||
else: | ||
raise ValueError() | ||
|
||
|
@@ -226,9 +225,7 @@ def _if(cond: Literal[Sentinel.VALUE], true_branch, false_branch): | |
result.append(_if(Sentinel.VALUE, el_true_branch, el_false_branch)) | ||
return tuple(result) | ||
|
||
is_iterator_arg = tuple( | ||
isinstance(arg, IteratorTracer) for arg in (cond, true_branch, false_branch) | ||
) | ||
is_iterator_arg = tuple(isinstance(arg, Tracer) for arg in (cond, true_branch, false_branch)) | ||
if is_iterator_arg == (False, True, True): | ||
return CombinedTracer((true_branch, false_branch)) | ||
assert is_iterator_arg == (False, False, False) and all( | ||
|
@@ -248,7 +245,15 @@ def _tuple_get(index, tuple_val): | |
return Sentinel.VALUE | ||
|
||
|
||
def _as_fieldop(stencil, domain=None): | ||
def applied_as_fieldop(*args): | ||
return stencil(*args) | ||
|
||
return applied_as_fieldop | ||
|
||
|
||
_START_CTX: Final = { | ||
"as_fieldop": _as_fieldop, | ||
"deref": _deref, | ||
"can_deref": _can_deref, | ||
"shift": _shift, | ||
|
@@ -292,11 +297,11 @@ def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any: | |
|
||
def visit(self, node, **kwargs): | ||
result = super().visit(node, **kwargs) | ||
if isinstance(result, IteratorTracer): | ||
if isinstance(result, Tracer): | ||
assert isinstance(node, (ir.Sym, ir.Expr)) | ||
|
||
self.shift_recorder.register_node(node) | ||
result = IteratorArgTracer( | ||
result = ArgTracer( | ||
arg=node, shift_recorder=ForwardingShiftRecorder(result, self.shift_recorder) | ||
) | ||
return result | ||
|
@@ -305,10 +310,10 @@ def visit_Lambda(self, node: ir.Lambda, *, ctx: dict[str, Any]) -> Callable: | |
def fun(*args): | ||
new_args = [] | ||
for param, arg in zip(node.params, args, strict=True): | ||
if isinstance(arg, IteratorTracer): | ||
if isinstance(arg, Tracer): | ||
self.shift_recorder.register_node(param) | ||
new_args.append( | ||
IteratorArgTracer( | ||
ArgTracer( | ||
arg=param, | ||
shift_recorder=ForwardingShiftRecorder(arg, self.shift_recorder), | ||
) | ||
|
@@ -322,46 +327,49 @@ def fun(*args): | |
|
||
return fun | ||
|
||
def visit_StencilClosure(self, node: ir.StencilClosure): | ||
tracers = [] | ||
for inp in node.inputs: | ||
self.shift_recorder.register_node(inp) | ||
tracers.append(IteratorArgTracer(arg=inp, shift_recorder=self.shift_recorder)) | ||
|
||
result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) | ||
assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) | ||
return node | ||
|
||
@classmethod | ||
def apply( | ||
cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False | ||
) -> ( | ||
dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]] | ||
def trace_stencil( | ||
cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False | ||
): | ||
# If we get a lambda we can deduce the number of arguments. | ||
if isinstance(stencil, ir.Lambda): | ||
assert num_args is None or num_args == len(stencil.params) | ||
num_args = len(stencil.params) | ||
if not isinstance(num_args, int): | ||
raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") | ||
assert isinstance(num_args, int) | ||
|
||
args = [im.ref(f"__arg{i}") for i in range(num_args)] | ||
|
||
old_recursionlimit = sys.getrecursionlimit() | ||
sys.setrecursionlimit(100000000) | ||
|
||
instance = cls() | ||
instance.visit(node) | ||
|
||
# initialize shift recorder & context with all built-ins and the iterator argument tracers | ||
ctx: dict[str, Any] = {**_START_CTX} | ||
for arg in args: | ||
instance.shift_recorder.register_node(arg) | ||
ctx[arg.id] = ArgTracer(arg=arg, shift_recorder=instance.shift_recorder) | ||
|
||
# actually trace stencil | ||
instance.visit(im.call(stencil)(*args), ctx=ctx) | ||
|
||
sys.setrecursionlimit(old_recursionlimit) | ||
|
||
recorded_shifts = instance.shift_recorder.recorded_shifts | ||
|
||
param_shifts = [] | ||
for arg in args: | ||
param_shifts.append(recorded_shifts[id(arg)]) | ||
|
||
if save_to_annex: | ||
_save_to_annex(node, recorded_shifts) | ||
_save_to_annex(stencil, recorded_shifts) | ||
|
||
if __debug__: | ||
ValidateRecordedShiftsAnnex().visit(node) | ||
return param_shifts | ||
|
||
if inputs_only: | ||
assert isinstance(node, ir.StencilClosure) | ||
inputs_shifts = {} | ||
for inp in node.inputs: | ||
inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)] | ||
return inputs_shifts | ||
|
||
return recorded_shifts | ||
trace_stencil = TraceShifts.trace_stencil | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is it also a classmethod? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To allow subclassing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you just make this up or do you have a use-case in mind? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To stick to your wording: I made this up. I can change it if you like, I can also make up unlikely cases where this might be useful. |
||
|
||
|
||
def _save_to_annex( | ||
|
@@ -370,3 +378,6 @@ def _save_to_annex( | |
for child_node in node.pre_walk_values(): | ||
if id(child_node) in recorded_shifts: | ||
child_node.annex.recorded_shifts = recorded_shifts[id(child_node)] | ||
|
||
if __debug__: | ||
ValidateRecordedShiftsAnnex().visit(node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's special with lambdas? what else can it be?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For lambdas we can deduce the number of arguments and the
num_args
parameter is optional. This is essentially a convenience feature for testing. I've added a comment.