Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: GTIR-to-SDFG lowering of neighbors and reduce #1597

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
207 commits
Select commit Hold shift + click to select a range
b36fcb3
Skeleton for ITIR translation
edopao Apr 18, 2024
2020182
Minor edit
edopao Apr 18, 2024
5c6b6ba
Use Python callstack as a context stack for the ITIR visitor
edopao Apr 19, 2024
60e1c69
Format error
edopao Apr 19, 2024
073a0a4
Refactor tasklet codegen
edopao Apr 19, 2024
50be68f
Code refactoring
edopao Apr 19, 2024
4e2dc15
Add domain to field operator
edopao Apr 22, 2024
ea9da35
Minor edit
edopao Apr 22, 2024
daf7827
Remove hard-coded field shape
edopao Apr 23, 2024
9672b3b
Remove hard-coded target domain
edopao Apr 23, 2024
b6326b8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Apr 23, 2024
26f3790
Refactoring
edopao Apr 24, 2024
1efffa7
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Apr 24, 2024
f99fa84
Fix formatting
edopao Apr 24, 2024
d6e1088
More refactoring
edopao Apr 24, 2024
9854497
Minor edit
edopao Apr 24, 2024
37d83d7
Fix formatting
edopao Apr 24, 2024
29986ef
Use callable to build taskgraph
edopao Apr 29, 2024
390f3b4
Add draft of select operator
edopao Apr 29, 2024
de27419
Remove node mapping
edopao Apr 30, 2024
cd900f5
Remove node mapping (fix + test case)
edopao Apr 30, 2024
326cbb5
Add test case for inlined mathematic builtins
edopao Apr 30, 2024
a10b614
Go full functional (remove SDFGState member var)
edopao Apr 30, 2024
aef4265
Minor edit
edopao Apr 30, 2024
9e67dfe
Minor edit (1)
edopao Apr 30, 2024
4b4109e
Fix state handling
edopao Apr 30, 2024
495fd0a
Edit comments based on review
edopao Apr 30, 2024
0085194
Add test case for nested select
edopao Apr 30, 2024
41e2a44
Separate builtin translation from driver logic
edopao May 1, 2024
7148c5f
Improve code comments
edopao May 2, 2024
452399d
Avoid inheritance: pass dataflow builder as arg to builtin translator
edopao May 3, 2024
e404226
Codestyle review changes
edopao May 3, 2024
bb0dfac
Remove circular dependency for builtin translators
edopao May 6, 2024
412cd5d
Fix formatting
edopao May 6, 2024
651de5c
Minor edit
edopao May 6, 2024
dcf3eab
Add support to translate each builtin call to a tasklet node
edopao May 6, 2024
7e6909e
Resolve dace warnings
edopao May 7, 2024
2b07cc5
Remove bultin translator for domain expressions
edopao May 7, 2024
2370fa6
Remove bultin translator for domain expressions (1)
edopao May 7, 2024
8e801df
Refactor
edopao May 7, 2024
812a6e5
Minor edit
edopao May 7, 2024
1d0b50b
Extract ITIR visitor to separate class
edopao May 7, 2024
97a1d22
Code refactoring
edopao May 7, 2024
a30cc7d
Fix formatting
edopao May 7, 2024
f595b01
Add IteratorExpr type
edopao May 10, 2024
a6bcb6c
Indirection shift implemented as tasklet node
edopao May 10, 2024
738da27
Add ConnectivityExpr type
edopao May 13, 2024
e5494d8
Remove ConnectivityExpr type, use ValueExpr instead
edopao May 13, 2024
e9455e3
Changes in preparation for shift builtin
edopao May 13, 2024
cbf55de
Refactoring
edopao May 13, 2024
9d5b1ed
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 13, 2024
801704b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 13, 2024
c45c417
Add support for programs without computation (pure memlets)
edopao May 13, 2024
3f26d91
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 13, 2024
f173244
Merge remote-tracking branch 'origin/main' into dace-fieldview-shifts
edopao May 13, 2024
d67518a
Fix test
edopao May 13, 2024
783542f
Fix for chain of shift expressions shift(V2E(E2V(i_edge, x), y))(edges)
edopao May 14, 2024
1fa9de4
Support for multi-dimensional shift
edopao May 14, 2024
96338c2
Fix typo
edopao May 14, 2024
57e369f
Add support for cartesian shift with dynamic offset
edopao May 15, 2024
ec4714c
Add support for unstructured shift with dynamic offset
edopao May 15, 2024
46cb6c6
Code refactoring in test file
edopao May 15, 2024
c20a94d
Typo
edopao May 15, 2024
d1f7432
Code cleanup
edopao May 15, 2024
c4385c1
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 16, 2024
ed16fd4
Import updates from branch dace-fieldview-shifts
edopao May 16, 2024
9f7176f
Review comments
edopao May 16, 2024
4f40f42
Merge branch 'dace-fieldview' into dace-fieldview-shifts
edopao May 16, 2024
932db7c
Avoid tasklet-to-tasklet edge connections
edopao May 16, 2024
46febb0
Avoid tasklet-to-tasklet edge connections
edopao May 16, 2024
949bad7
Add support for in-out field parameters
edopao May 16, 2024
8890f95
Refactoring: import modules, not symbols
edopao May 17, 2024
87b71a6
Minor edit
edopao May 17, 2024
665a609
Remove internal package for builtin translators
edopao May 17, 2024
82fdf64
Add wrapper function to build SDFG
edopao May 17, 2024
e4718b0
Merge pull request #4 from edopao/dace-fieldview-refactor_imports
edopao May 17, 2024
47fcabe
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 17, 2024
51aaf0f
Add fieldview flavor of all test cases
edopao May 17, 2024
6ccecf1
Code changes imported from branch dace-fieldview-shifts
edopao May 17, 2024
e66b960
Code comments updated
edopao May 17, 2024
7f89a16
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 17, 2024
4c190bd
Remove support for inlined chained shift
edopao May 21, 2024
6052de2
Add support for neighbors builtin
edopao May 21, 2024
7300864
Add support for reduce builtin
edopao May 22, 2024
55adbd5
Refactoring
edopao May 23, 2024
ad21dc4
Add support for both inlined and fieldview neighbor reduction
edopao May 23, 2024
bb9123b
Minor edit
edopao May 23, 2024
0025d77
Code refactoring
edopao May 23, 2024
9926d7d
Add support for skip values ONLY for inlined GTIR
edopao May 23, 2024
172f19e
Masked array implementation based on connectivity table
edopao May 27, 2024
b1f4a47
Merge 2 different implementations of reduce
edopao May 27, 2024
63e6e92
Add support for reduce lambda function
edopao May 28, 2024
107e295
Add support for neighbors masked array returned by select statements
edopao May 29, 2024
3c71efa
Import changes from neighbors branch
edopao May 29, 2024
e369cac
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 29, 2024
d0bd277
Import changes from neighbors branch
edopao May 29, 2024
afb5ed1
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao May 29, 2024
2f75cfb
Add debuginfo for ir.Program and ir.Stmt nodes
edopao May 29, 2024
695db7c
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 29, 2024
074f0b2
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao May 29, 2024
085f307
Fix error in debuginfo
edopao May 29, 2024
f19960b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 29, 2024
841040e
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 29, 2024
b3df358
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao May 29, 2024
dc1434c
Fix error in debuginfo (1)
edopao May 29, 2024
eacde66
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao May 29, 2024
138a33c
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao May 29, 2024
3769fb5
Remove nested SDFG for neighbors builtin
edopao Jun 14, 2024
b1b5887
Remove masked array for skip values, rely on identity value
edopao Jun 26, 2024
a5b0f41
import changes from neighbors branch
edopao Jun 28, 2024
f7ac3d8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jun 28, 2024
01ff262
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jun 28, 2024
c61e796
import changes from neighbors branch
edopao Jun 28, 2024
5a457b2
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jun 28, 2024
600f23f
Merge remote-tracking branch 'origin/main' into dace-fieldview-neighbors
edopao Jul 3, 2024
f1a3c18
Minor edit
edopao Jul 3, 2024
9318011
Import changes from branch dace-fieldview-neighbors
edopao Jul 4, 2024
11efdeb
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jul 4, 2024
25b9048
Support field with start offset
edopao Jul 4, 2024
2dc6f97
Merge branch 'dace-fieldview' into dace-fieldview-shifts
edopao Jul 4, 2024
f6e5b7c
Add test coverage for temporary with start offset (cartesian shift)
edopao Jul 4, 2024
1b30016
Merge branch 'dace-fieldview-shifts' into dace-fieldview-neighbors
edopao Jul 4, 2024
d7312fa
Support field with start offset
edopao Jul 4, 2024
628c18b
Merge branch 'dace-fieldview' into dace-fieldview-shifts
edopao Jul 4, 2024
19a629e
Merge branch 'dace-fieldview-shifts' into dace-fieldview-neighbors
edopao Jul 4, 2024
c4f2738
Test IR updated for literal operand
edopao Jul 4, 2024
0fd0b65
Add test coverage to previous commit
edopao Jul 4, 2024
38d2720
Refactor PrimitiveTranslator interface
edopao Jul 4, 2024
e855ef9
Fix formatting
edopao Jul 5, 2024
4cff071
Fix for domain horzontal/vertical dims
edopao Jul 5, 2024
f642e85
Fix for type inference on single value expression
edopao Jul 5, 2024
f216a36
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jul 5, 2024
3a4094b
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 5, 2024
c41b657
Move fieldop map creation outside LambdaToTasklet
edopao Jul 5, 2024
defb55d
Import changes from dace-fieldview-neighbors
edopao Jul 5, 2024
fc9661c
Import changes from dace-fieldview-shifts
edopao Jul 5, 2024
e424d4e
Minor edit
edopao Jul 5, 2024
7ef1d56
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jul 5, 2024
b3c1aba
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 5, 2024
3a3685a
Code-style review comments
edopao Jul 9, 2024
3ae3030
Bugfix for temporary arrays with offset (shifts)
edopao Jul 9, 2024
f3a3dfd
Move PythonCodegen to separate module
edopao Jul 9, 2024
4dd665b
Remove sdfg_builder module
edopao Jul 9, 2024
1da297a
Remove constructor for PrimitiveTranslator derived classes
edopao Jul 10, 2024
25fb0a1
Extract subfuction from AsFieldOp translator
edopao Jul 10, 2024
66c5fcd
Address review comments
edopao Jul 10, 2024
d5abad4
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jul 10, 2024
1df1bc3
Apply convention for map variables
edopao Jul 10, 2024
2032b60
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jul 10, 2024
eba7469
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 10, 2024
72e0a9c
Fix circular import of modules needed for type-checking
edopao Jul 11, 2024
d6e4a3e
Fix direct symbol import
edopao Jul 11, 2024
c367e78
Change builtin translators to pure functions (no state)
edopao Jul 11, 2024
37c5021
Improve utility.get_domain based on review comments
edopao Jul 11, 2024
729a5bc
Minor edit
edopao Jul 11, 2024
80ad310
Rename module
edopao Jul 11, 2024
7f72794
Import changes from dace-fieldview-neighbors
edopao Jul 11, 2024
abf3918
Import changes from dace-fieldview-shifts
edopao Jul 11, 2024
a6d31fb
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jul 11, 2024
14c397d
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 11, 2024
6394b76
Initial implementation of let-lambdas
edopao Jul 12, 2024
a8a1bb7
Generate unique names for map/tasklet nodes
edopao Jul 12, 2024
cda8f7f
Import changes from branch dace-fieldview-let_lambdas
edopao Jul 12, 2024
9301dbe
Import changes from branch dace-fieldview-neighbors
edopao Jul 12, 2024
7f60cfe
Import changes from branch dace-fieldview-shifts
edopao Jul 12, 2024
699a88b
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jul 12, 2024
452497d
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 12, 2024
1c6e3a2
Merge remote-tracking branch 'origin/dace-fieldview-neighbors' into d…
edopao Jul 12, 2024
b3131db
Avoid direct import of symbols from module
edopao Jul 12, 2024
130c877
Address review comments
edopao Jul 12, 2024
7fbd7e1
Merge remote-tracking branch 'origin/dace-fieldview' into dace-fieldv…
edopao Jul 12, 2024
a43bc07
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 12, 2024
6d5b31b
Merge remote-tracking branch 'origin/dace-fieldview-neighbors' into d…
edopao Jul 12, 2024
714e086
Use canonical mar veriable for neighbors map
edopao Jul 12, 2024
605ecd2
Merge remote-tracking branch 'origin/dace-fieldview-neighbors' into d…
edopao Jul 12, 2024
7a0f0ad
Extend let-lambda test coverage
edopao Jul 12, 2024
a6b191c
Merge remote-tracking branch 'origin/main' into dace-fieldview-shifts
edopao Jul 12, 2024
ba99810
Merge remote-tracking branch 'origin/dace-fieldview-shifts' into dace…
edopao Jul 12, 2024
0874785
Merge remote-tracking branch 'origin/dace-fieldview-neighbors' into d…
edopao Jul 12, 2024
3534c06
Integrate type system into dace backend
edopao Jul 23, 2024
84e28dd
WIP neighbors type inference
edopao Jul 23, 2024
9e5fd2e
Remove menighbors-reduce from baseline
edopao Jul 24, 2024
3f8cd6d
Apply ir_maker op_as_fieldop to legacy tests
edopao Jul 24, 2024
f3ad657
Import changes from shifts branch
edopao Jul 24, 2024
a21c5a1
Merge remote-tracking branch 'origin/main' into dace-fieldview-let_la…
edopao Jul 25, 2024
9c10134
Minor edit
edopao Jul 25, 2024
0b05423
Review comments
edopao Jul 25, 2024
3eced05
Review comments + new testcase
edopao Jul 26, 2024
9ae57d1
Working solution
edopao Jul 26, 2024
94f1529
Import changes from let-lambdas branch
edopao Jul 25, 2024
015b924
First working prototype for neighbors-reduce
edopao Jul 26, 2024
4cc1c82
Merge remote-tracking branch 'origin/dace-fieldview-neighbors-let_red…
edopao Jul 29, 2024
4ac22e9
Utility for common pattern matching
edopao Jul 29, 2024
fd737c8
Merge fixes
edopao Jul 29, 2024
676d49b
Minor edit
edopao Jul 29, 2024
0f904ae
Merge remote-tracking branch 'origin/dace-fieldview-let_lambdas-rebas…
edopao Jul 29, 2024
88e4574
Remove assert on KeyError
edopao Jul 29, 2024
7e414f2
Merge remote-tracking branch 'origin/dace-fieldview-let_lambdas-rebas…
edopao Jul 29, 2024
37afe2d
Add code comment
edopao Jul 29, 2024
8472719
Merge remote-tracking branch 'origin/dace-fieldview-let_lambdas-rebas…
edopao Jul 29, 2024
6183305
Use let-symbol for reduce identity
edopao Jul 29, 2024
924e413
Merge remote-tracking branch 'origin/main' into dace-fieldview-neighb…
edopao Jul 29, 2024
ff8064a
Minor edit
edopao Jul 29, 2024
ab68406
Edit code comments
edopao Jul 29, 2024
42e8bf7
Update code comments
edopao Jul 30, 2024
2c7b6d3
Update testcases
edopao Jul 30, 2024
41438e7
Split lowering of Literal and SymRef as 2 different primitive transla…
edopao Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
)


def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `reduce(λ(...) → ...)(...)`."""
return (
isinstance(arg, itir.FunCall)
and isinstance(arg.fun, itir.FunCall)
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "reduce"
)


def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `shift(λ(...) → ...)(...)`."""
return (
Expand Down
15 changes: 4 additions & 11 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gt4py.eve import NodeTranslator, traits
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import inline_lambdas


Expand All @@ -29,14 +30,6 @@ def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]:
)


def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:
return (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.FunCall)
and node.fun.fun == ir.SymRef(id="reduce")
)


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Expand Down Expand Up @@ -71,7 +64,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:

def visit_FunCall(self, node: ir.FunCall, **kwargs):
node = self.generic_visit(node)
if _is_map(node) or _is_reduce(node):
if _is_map(node) or cpm.is_applied_reduce(node):
if any(_is_map(arg) for arg in node.args):
first_param = (
0 if _is_map(node) else 1
Expand All @@ -83,7 +76,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
inlined_args = []
new_params = []
new_args = []
if _is_reduce(node):
if cpm.is_applied_reduce(node):
# param corresponding to reduce acc
inlined_args.append(ir.SymRef(id=outer_op.params[0].id))
new_params.append(outer_op.params[0])
Expand Down Expand Up @@ -119,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
)
else: # _is_reduce(node)
else: # is_applied_reduce(node)
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]),
args=new_args,
Expand Down
9 changes: 3 additions & 6 deletions src/gt4py/next/iterator/transforms/unroll_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift


Expand Down Expand Up @@ -60,16 +61,12 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]:
return [_get_partial_offset_tag(arg) for arg in _get_neighbors_args(reduce_args)]


def _is_reduce(node: itir.FunCall) -> TypeGuard[itir.FunCall]:
return isinstance(node.fun, itir.FunCall) and node.fun.fun == itir.SymRef(id="reduce")


def _get_connectivity(
applied_reduce_node: itir.FunCall,
offset_provider: dict[str, common.Dimension | common.Connectivity],
) -> common.Connectivity:
"""Return single connectivity that is compatible with the arguments of the reduce."""
if not _is_reduce(applied_reduce_node):
if not cpm.is_applied_reduce(applied_reduce_node):
raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.")

connectivities: list[common.Connectivity] = []
Expand Down Expand Up @@ -158,6 +155,6 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr:

def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Expr:
node = self.generic_visit(node, **kwargs)
if _is_reduce(node):
if cpm.is_applied_reduce(node):
return self._visit_reduce(node, **kwargs)
return node
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.iterator.type_system import type_specifications as gtir_ts
from gt4py.next.program_processors.runners.dace_fieldview import (
gtir_python_codegen,
gtir_to_tasklet,
Expand All @@ -38,7 +38,7 @@


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
LetSymbol: TypeAlias = tuple[str, ts.FieldType | ts.ScalarType]
LetSymbol: TypeAlias = tuple[gtir.Literal | gtir.SymRef, ts.FieldType | ts.ScalarType]
TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]


Expand All @@ -62,7 +62,9 @@ def __call__(
sdfg: The SDFG where the primitive subgraph should be instantiated
state: The SDFG state where the result of the primitive function should be made available
sdfg_builder: The object responsible for visiting child nodes of the primitive node.
let_symbols: Mapping of symbols (i.e. lambda parameters) to known temporary fields.
let_symbols: Mapping of symbols (i.e. lambda parameters and/or local constants
like the identity value in a reduction context) to temporary fields
or symbolic expressions.

Returns:
A list of data access nodes and the associated GT4Py data type, which provide
Expand Down Expand Up @@ -105,11 +107,12 @@ def _parse_arg_expr(
)
for dim, _, _ in domain
}
return gtir_to_tasklet.IteratorExpr(
data_node,
arg_type.dims,
indices,
dims = arg_type.dims + (
# we add an extra anonymous dimension in the iterator definition to enable
# dereferencing elements in `ListType`
[gtx_common.Dimension("")] if isinstance(arg_type.dtype, gtir_ts.ListType) else []
)
return gtir_to_tasklet.IteratorExpr(data_node, dims, indices)


def _create_temporary_field(
Expand All @@ -134,27 +137,20 @@ def _create_temporary_field(
field_offset = [-lb for lb in domain_lbs]

if isinstance(output_desc, dace.data.Array):
# extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`)
assert isinstance(output_field_type, ts.FieldType)
if isinstance(node_type.dtype, itir_ts.ListType):
raise NotImplementedError
else:
field_dtype = node_type.dtype
assert output_field_type.dtype == field_dtype
field_dims.extend(output_field_type.dims)
assert isinstance(node_type.dtype, gtir_ts.ListType)
field_dtype = node_type.dtype.element_type
# extend the result arrays with the local dimensions added by the field operator (e.g. `neighbors`)
field_shape.extend(output_desc.shape)
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
else:
assert isinstance(output_desc, dace.data.Scalar)
assert isinstance(output_field_type, ts.ScalarType)
field_dtype = node_type.dtype
assert output_field_type == field_dtype

# allocate local temporary storage for the result field
temp_name, _ = sdfg.add_temp_transient(
field_shape, dace_fieldview_util.as_dace_type(field_dtype), offset=field_offset
)
field_node = state.add_access(temp_name)
field_type = ts.FieldType(field_dims, field_dtype)
field_type = ts.FieldType(field_dims, node_type.dtype)

return field_node, field_type

Expand All @@ -169,6 +165,7 @@ def translate_as_field_op(
"""Generates the dataflow subgraph for the `as_fieldop` builtin function."""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, ts.FieldType)

fun_node = node.fun
assert len(fun_node.args) == 2
Expand All @@ -182,13 +179,40 @@ def translate_as_field_op(
domain = dace_fieldview_util.get_domain(domain_expr)
assert isinstance(node.type, ts.FieldType)

reduce_identity: Optional[gtir_to_tasklet.SymbolExpr] = None
if cpm.is_applied_reduce(stencil_expr.expr):
# 'reduce' is a reserved keyword of the DSL and we will never find a user-defined symbol
# with this name. Since 'reduce' will never collide with a user-defined symbol, it is safe
# to use it internally to store the reduce identity value as a let-symbol.
if "reduce" in let_symbols:
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("nested reductions not supported.")
edopao marked this conversation as resolved.
Show resolved Hide resolved

# the reduce identity value is used to fill the skip values in neighbors list
_, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr)

# we store the reduce identity value as a constant let-symbol
let_symbols = let_symbols | {
"reduce": (
gtir.Literal(value=str(reduce_identity.value), type=stencil_expr.expr.type),
reduce_identity.dtype,
)
}

elif "reduce" in let_symbols:
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
# a parent node is a reduction node, so we are visiting the current node in the context of a reduction
reduce_symbol, _ = let_symbols["reduce"]
assert isinstance(reduce_symbol, gtir.Literal)
reduce_identity = gtir_to_tasklet.SymbolExpr(
reduce_symbol.value, dace_fieldview_util.as_dace_type(reduce_symbol.type)
)

# first visit the list of arguments and build a symbol map
stencil_args = [
_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, let_symbols) for arg in node.args
]
edopao marked this conversation as resolved.
Show resolved Hide resolved

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder)
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder, reduce_identity)
input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args)
assert isinstance(output_expr, gtir_to_tasklet.ValueExpr)
output_desc = output_expr.node.desc(sdfg)
Expand All @@ -205,7 +229,7 @@ def translate_as_field_op(

# allocate local temporary storage for the result field
field_node, field_type = _create_temporary_field(
sdfg, state, domain, node.type, output_desc, output_expr.field_type
sdfg, state, domain, node.type, output_desc, output_expr.dtype
)

# assume tasklet with single output
Expand Down Expand Up @@ -327,6 +351,54 @@ def translate_cond(
return output_nodes


def _get_symbolic_value(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
symbolic_expr: dace.symbolic.SymExpr,
scalar_type: ts.ScalarType,
temp_name: Optional[str] = None,
) -> dace.nodes.AccessNode:
tasklet_node = sdfg_builder.add_tasklet(
"get_value",
state,
{},
{"__out"},
f"__out = {symbolic_expr}",
)
temp_name, _ = sdfg.add_scalar(
f"__{temp_name or 'tmp'}",
dace_fieldview_util.as_dace_type(scalar_type),
find_new_name=True,
transient=True,
)
data_node = state.add_access(temp_name)
state.add_edge(
tasklet_node,
"__out",
data_node,
None,
dace.Memlet(data=temp_name, subset="0"),
)
return data_node


def translate_literal(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.Literal` node."""
assert isinstance(node, gtir.Literal)

data_type = node.type
data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type)

return [(data_node, data_type)]


def translate_symbol_ref(
node: gtir.Node,
sdfg: dace.SDFG,
Expand All @@ -335,57 +407,33 @@ def translate_symbol_ref(
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, (gtir.Literal, gtir.SymRef))

data_type: ts.FieldType | ts.ScalarType
if isinstance(node, gtir.Literal):
sym_value = node.value
data_type = node.type
temp_name = "literal"
assert isinstance(node, gtir.SymRef)

sym_value = str(node.id)
if sym_value in let_symbols:
let_node, sym_type = let_symbols[sym_value]
if isinstance(let_node, gtir.Literal):
# this branch handles the case a let-symbol is mapped to some constant value
return sdfg_builder.visit(let_node)
# The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary
# data container. These symbols are visited and initialized in a state
# that preceeds the current state, therefore a new access node needs to
# be created in the state where they are accessed.
sym_value = str(let_node.id)
else:
sym_value = str(node.id)
if sym_value in let_symbols:
# The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary
# data container. These symbols are visited and initialized in a state
# that preceeds the current state, therefore a new access node is created
# everytime they are accessed. It is therefore possible that multiple access
# nodes are created in one state for the same data container. We rely
# on the simplify to remove duplicated access nodes.
sym_value, data_type = let_symbols[sym_value]
else:
data_type = sdfg_builder.get_symbol_type(sym_value)
temp_name = sym_value

if isinstance(data_type, ts.FieldType):
# add access node to current state
sym_node = state.add_access(sym_value)
sym_type = sdfg_builder.get_symbol_type(sym_value)

# Create new access node in current state. It is possible that multiple
# access nodes are created in one state for the same data container.
# We rely on the dace simplify pass to remove duplicated access nodes.
if isinstance(sym_type, ts.FieldType):
sym_node = state.add_access(sym_value)
else:
# scalar symbols are passed to the SDFG as symbols: build tasklet node
# to write the symbol to a scalar access node
tasklet_node = sdfg_builder.add_tasklet(
f"get_{temp_name}",
state,
{},
{"__out"},
f"__out = {sym_value}",
)
temp_name, _ = sdfg.add_scalar(
f"__{temp_name}",
dace_fieldview_util.as_dace_type(data_type),
find_new_name=True,
transient=True,
)
sym_node = state.add_access(temp_name)
state.add_edge(
tasklet_node,
"__out",
sym_node,
None,
dace.Memlet(data=sym_node.data, subset="0"),
sym_node = _get_symbolic_value(
sdfg, state, sdfg_builder, sym_value, sym_type, temp_name=sym_value
)

return [(sym_node, data_type)]
return [(sym_node, sym_type)]


if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def visit_Lambda(
symbol, the parameter will shadow the previous symbol during traversal of the lambda expression.
"""
lambda_symbols = let_symbols | {
str(p.id): (temp_node.data, type_)
str(p.id): (gtir.SymRef(id=temp_node.data), type_)
for p, (temp_node, type_) in zip(node.params, args, strict=True)
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
}

Expand All @@ -404,7 +404,7 @@ def visit_Literal(
head_state: dace.SDFGState,
let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
) -> list[gtir_builtin_translators.TemporaryData]:
return gtir_builtin_translators.translate_symbol_ref(
return gtir_builtin_translators.translate_literal(
node, sdfg, head_state, self, let_symbols={}
)

Expand Down
Loading
Loading