Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Lowering foast/past to GTIR #1569

Merged
merged 33 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
aeff778
nothing
havogt Jun 30, 2024
6919596
basic opbinary
havogt Jun 30, 2024
87883a6
all tests updated
havogt Jul 3, 2024
01a6e07
test cleanup
havogt Jul 5, 2024
b949ebd
add to_gtir path to past_to_itir
havogt Jul 5, 2024
be21b78
update past_to_gtir_test
havogt Jul 10, 2024
be00112
gtir_embedded backend
havogt Jul 10, 2024
5230aeb
Merge remote-tracking branch 'upstream/main' into gtir_lowering
havogt Jul 11, 2024
96efba2
Merge branch 'main' into gtir_lowering
havogt Aug 2, 2024
5d0d16e
add broadcast, max_over, min_over
havogt Aug 5, 2024
bca403f
math builtins
havogt Aug 5, 2024
ce8ea56
cleanup merge conflict
havogt Aug 5, 2024
17f891d
add where lowering
havogt Aug 6, 2024
c77380a
cond cases
havogt Aug 7, 2024
3b4c4a2
implement astype lowering
havogt Aug 9, 2024
ab42da2
fix process_elements for itir lowering
havogt Aug 9, 2024
ef40997
fix pure scalar op lowering
havogt Aug 14, 2024
bff0278
Merge remote-tracking branch 'upstream/main' into gtir_lowering
havogt Aug 14, 2024
1bd0797
improve typing, unary
havogt Aug 14, 2024
f7792ff
as_offset
havogt Aug 14, 2024
0ea0f79
remove old license
havogt Aug 14, 2024
70dd170
disable gtir embedded
havogt Aug 14, 2024
99f86c6
cleanup use of dimension
havogt Aug 14, 2024
3458aee
cleanup
havogt Aug 15, 2024
d46f308
Apply suggestions from code review
havogt Aug 23, 2024
960e173
Merge remote-tracking branch 'upstream/main' into gtir_lowering
havogt Aug 23, 2024
cdc8963
fix import
havogt Aug 23, 2024
e17ef1c
address review comments
havogt Aug 23, 2024
789a215
undo unneeded change
havogt Aug 23, 2024
156b884
domain in itir maker
havogt Aug 23, 2024
9d2c034
TODO to FIXME
havogt Aug 23, 2024
efd3587
address review comments
havogt Aug 26, 2024
c94e1e9
add testcase for tuple in frozen namespace
havogt Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
385 changes: 385 additions & 0 deletions src/gt4py/next/ffront/foast_to_gtir.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

# FIXME[#1582](havogt): remove after refactoring to GTIR

import dataclasses
from typing import Any, Callable, Optional

Expand Down Expand Up @@ -332,7 +334,7 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
f"Call to object of type '{type(node.func.type).__name__}' not understood."
havogt marked this conversation as resolved.
Show resolved Hide resolved
)

def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, new_type = node.args[0], node.args[1].id
return lowering_utils.process_elements(
Expand All @@ -343,7 +345,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
obj.type,
)

def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
condition, true_value, false_value = node.args

lowered_condition = self.visit(condition, **kwargs)
Expand Down
29 changes: 16 additions & 13 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
from typing import Any, Callable, TypeVar

from gt4py.eve import utils as eve_utils
Expand Down Expand Up @@ -99,7 +100,7 @@ def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall:
# TODO(tehrengruber): The code quality of this function is poor. We should rewrite it.
def process_elements(
process_func: Callable[..., itir.Expr],
objs: itir.Expr | list[itir.Expr],
objs: itir.Expr | Iterable[itir.Expr],
current_el_type: ts.TypeSpec,
) -> itir.FunCall:
"""
Expand All @@ -113,34 +114,36 @@ def process_elements(
are not used and thus not relevant.
"""
if isinstance(objs, itir.Expr):
objs = [objs]
objs = (objs,)

_current_el_exprs = [
im.ref(f"__val_{eve_utils.content_hash(obj)}") for i, obj in enumerate(objs)
]
body = _process_elements_impl(process_func, _current_el_exprs, current_el_type)

return im.let(*((f"__val_{eve_utils.content_hash(obj)}", obj) for i, obj in enumerate(objs)))( # type: ignore[arg-type] # mypy not smart enough
body
let_ids = tuple(f"__val_{eve_utils.content_hash(obj)}" for obj in objs)
body = _process_elements_impl(
process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type
)

return im.let(*(zip(let_ids, objs, strict=True)))(body)


T = TypeVar("T", bound=itir.Expr, covariant=True)


def _process_elements_impl(
process_func: Callable[..., itir.Expr], _current_el_exprs: list[T], current_el_type: ts.TypeSpec
process_func: Callable[..., itir.Expr],
_current_el_exprs: Iterable[T],
current_el_type: ts.TypeSpec,
) -> itir.Expr:
if isinstance(current_el_type, ts.TupleType):
result = im.make_tuple(
*[
*(
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
_process_elements_impl(
process_func,
[im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs],
tuple(
im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs
),
current_el_type.types[i],
)
for i in range(len(current_el_type.types))
]
)
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
Expand Down
72 changes: 59 additions & 13 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

@dataclasses.dataclass(frozen=True)
class PastToItir(workflow.ChainableWorkflowMixin):
to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR

def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars)
offsets_and_dimensions = transform_utils._filter_closure_vars_by_type(
Expand All @@ -48,7 +50,10 @@ def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]

itir_program = ProgramLowering.apply(
inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type
inp.past_node,
function_definitions=lowered_funcs,
grid_type=grid_type,
to_gtir=self.to_gtir,
)

if config.DEBUG or "debug" in inp.kwargs:
Expand Down Expand Up @@ -108,6 +113,7 @@ def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]:
raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.")


@dataclasses.dataclass
class ProgramLowering(
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
Expand Down Expand Up @@ -144,6 +150,9 @@ class ProgramLowering(
[Sym(id=SymbolName('inp')), Sym(id=SymbolName('out')), Sym(id=SymbolName('__inp_size_0')), Sym(id=SymbolName('__out_size_0'))]
"""

grid_type: common.GridType
to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR

# TODO(tehrengruber): enable doctests again. For unknown / obscure reasons
# the above doctest fails when executed using `pytest --doctest-modules`.

Expand All @@ -153,11 +162,11 @@ def apply(
node: past.Program,
function_definitions: list[itir.FunctionDefinition],
grid_type: common.GridType,
to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR
) -> itir.FencilDefinition:
return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions)

def __init__(self, grid_type: common.GridType):
self.grid_type = grid_type
return cls(grid_type=grid_type, to_gtir=to_gtir).visit(
node, function_definitions=function_definitions
)

def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
"""Generate symbols for each field param and dimension."""
Expand Down Expand Up @@ -186,7 +195,7 @@ def visit_Program(
*,
function_definitions: list[itir.FunctionDefinition],
**kwargs: Any,
) -> itir.FencilDefinition:
) -> itir.FencilDefinition | itir.Program:
# The ITIR does not support dynamically getting the size of a field. As
# a workaround we add additional arguments to the fencil definition
# containing the size of all fields. The caller of a program is (e.g.
Expand All @@ -197,15 +206,50 @@ def visit_Program(
if any("domain" not in body_entry.kwargs for body_entry in node.body):
params = params + self._gen_size_params_from_program(node)

closures: list[itir.StencilClosure] = []
for stmt in node.body:
closures.append(self._visit_stencil_call(stmt, **kwargs))
if self.to_gtir:
set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body]
return itir.Program(
id=node.id,
function_definitions=function_definitions,
params=params,
declarations=[],
body=set_ats,
)
else:
closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body]
return itir.FencilDefinition(
id=node.id,
function_definitions=function_definitions,
params=params,
closures=closures,
)

def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt:
assert isinstance(node.kwargs["out"].type, ts.TypeSpec)
assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType)

node_kwargs = {**node.kwargs}
domain = node_kwargs.pop("domain", None)
output, lowered_domain = self._visit_stencil_call_out_arg(
node_kwargs.pop("out"), domain, **kwargs
)

assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType))

return itir.FencilDefinition(
id=node.id, function_definitions=function_definitions, params=params, closures=closures
args, node_kwargs = type_info.canonicalize_arguments(
node.func.type, node.args, node_kwargs, use_signature_ordering=True
)

lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs)

return itir.SetAt(
expr=im.call(node.func.id)(*lowered_args, *lowered_kwargs.values()),
domain=lowered_domain,
target=output,
)

def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure:
# FIXME[#1582](havogt): remove after refactoring to GTIR
def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure:
havogt marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(node.kwargs["out"].type, ts.TypeSpec)
assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType)

Expand Down Expand Up @@ -241,7 +285,9 @@ def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClo
else:
# field operators return a tuple of iterators, deref element-wise
stencil_body = lowering_utils.process_elements(
im.deref, im.call(node.func.id)(*stencil_args), node.func.type.definition.returns
im.deref,
im.call(node.func.id)(*stencil_args),
node.func.type.definition.returns,
)

return itir.StencilClosure(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def field_from_typespec(
(NumPyArrayField(... dtype=int32...), NumPyArrayField(... dtype=float32...))
"""

@utils.tree_map(collection_type=ts.TupleType, result_collection_type=tuple)
@utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=tuple)
def impl(type_: ts.ScalarType) -> common.MutableField:
res = common._field(
xp.empty(domain.shape, dtype=xp.dtype(type_translation.as_dtype(type_).scalar_type)),
Expand Down
37 changes: 29 additions & 8 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import typing
from typing import Callable, Iterable, Optional, Union
from typing import Callable, Optional, Union

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Dict, Tuple
Expand Down Expand Up @@ -81,6 +81,8 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti
return ref(literal_or_expr)
elif core_defs.is_scalar_type(literal_or_expr):
return literal_from_value(literal_or_expr)
elif literal_or_expr is None:
return itir.NoneLiteral()
assert isinstance(literal_or_expr, itir.Expr)
return literal_or_expr

Expand Down Expand Up @@ -241,10 +243,15 @@ def tuple_get(index: str | int, tuple_expr):


def if_(cond, true_val, false_val):
"""Create a not_ FunCall, shorthand for ``call("if_")(expr)``."""
"""Create a if_ FunCall, shorthand for ``call("if_")(expr)``."""
return call("if_")(cond, true_val, false_val)


def cond(cond, true_val, false_val):
"""Create a cond FunCall, shorthand for ``call("cond")(expr)``."""
return call("cond")(cond, true_val, false_val)


def lift(expr):
"""Create a lift FunCall, shorthand for ``call(call("lift")(expr))``."""
return call(call("lift")(expr))
Expand All @@ -266,7 +273,7 @@ class let:
def __init__(self, var: str | itir.Sym, init_form: itir.Expr | str): ...

@typing.overload
def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr | str]]): ...
def __init__(self, *args: tuple[str | itir.Sym, itir.Expr | str]): ...

def __init__(self, *args):
if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args):
Expand Down Expand Up @@ -356,6 +363,20 @@ def lifted_neighbors(offset, it) -> itir.Expr:
return lift(lambda_("it")(neighbors(offset, "it")))(it)


def as_fieldop_neighbors(
offset: str | itir.OffsetLiteral, it: str | itir.Expr, domain: Optional[itir.FunCall] = None
) -> itir.Expr:
"""
Create a fieldop for neighbors call.

Examples
--------
>>> str(as_fieldop_neighbors("off", "a"))
'(⇑(λ(it) → neighbors(offₒ, it)))(a)'
"""
return as_fieldop(lambda_("it")(neighbors(offset, "it")), domain)(it)


def promote_to_const_iterator(expr: str | itir.Expr) -> itir.Expr:
"""
Create a lifted nullary lambda that captures `expr`.
Expand Down Expand Up @@ -394,11 +415,6 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
return _impl


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))


def domain(
grid_type: Union[common.GridType, str],
ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]],
Expand Down Expand Up @@ -485,3 +501,8 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its)

return _impl


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.
new_args.append(arg)

if bound_vars:
return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough
return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args)))
return None

def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]:
Expand Down Expand Up @@ -298,7 +298,7 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]:
inner_vars[arg_sym] = arg
if outer_vars:
return self.fp_transform(
im.let(*outer_vars.items())( # type: ignore[arg-type] # mypy not smart enough
im.let(*outer_vars.items())(
self.fp_transform(im.let(*inner_vars.items())(original_inner_expr))
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
im.call(node.fun)(*new_args), eligible_params=eligible_params
)
# TODO(tehrengruber): propagate let outwards
return im.let(*bound_scalars.items())(new_node) # type: ignore[arg-type] # mypy not smart enough
return im.let(*bound_scalars.items())(new_node)

return node
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/propagate_deref.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def visit_FunCall(self, node: ir.FunCall):
if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]):
fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let
args: list[ir.Expr] = node.args[0].args
node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough
node = im.let(*zip(fun.params, args, strict=True))(im.deref(fun.expr))
elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"):
cond, true_branch, false_branch = node.args[0].args
return im.if_(cond, im.deref(true_branch), im.deref(false_branch))
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/otf/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

@dataclasses.dataclass(frozen=True)
class ProgramCall:
"""Iterator IR representaion of a program together with arguments to be passed to it."""
"""ITIR/GTIR representation of a program together with arguments to be passed to it."""

program: itir.FencilDefinition
program: itir.FencilDefinition | itir.Program
args: tuple[Any, ...]
kwargs: dict[str, Any]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def __call__(
self, inp: stages.ProgramCall
) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]:
"""Generate GTFN C++ code from the ITIR definition."""
program: itir.FencilDefinition = inp.program
program = inp.program
assert isinstance(program, itir.FencilDefinition)

# handle regular parameters and arguments of the program (i.e. what the user defined in
# the program)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __call__(
self, inp: stages.ProgramCall
) -> stages.ProgramSource[languages.SDFG, LanguageSettings]:
"""Generate DaCe SDFG file from the ITIR definition."""
program: itir.FencilDefinition = inp.program
program = inp.program
assert isinstance(program, itir.FencilDefinition)
arg_types = [tt.from_value(arg) for arg in inp.args]

sdfg = self.generate_sdfg(
Expand Down
Loading
Loading