Skip to content

Commit

Permalink
feat[next][dace]: GTIR-to-SDFG lowering of let-lambdas (using NestedS…
Browse files Browse the repository at this point in the history
…DFG) (#1601)

This PR proposes an alternative design for the lowering of let-lambdas
to SDFG, than the design provided in #1589. The previous PR added a
separate state for the definition of let-symbols: the disadvantage of
that solution is that the generated SDFG does not clearly represent the
data dependency. The alternative design proposed in this PR uses a
nested SDFG to represent some local symbols in the lambda scope. The
lambda function is lowered inside the nested SDFG. We rely on the
simplify pass to remove unnecessary nested SDFGs.
  • Loading branch information
edopao authored Sep 3, 2024
1 parent 28c1ca8 commit 7824e38
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
LetSymbol: TypeAlias = tuple[gtir.Literal | gtir.SymRef, ts.FieldType | ts.ScalarType]
TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]


Expand All @@ -43,7 +42,7 @@ def __call__(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
"""Creates the dataflow subgraph representing a GTIR primitive function.
Expand All @@ -55,9 +54,9 @@ def __call__(
sdfg: The SDFG where the primitive subgraph should be instantiated
state: The SDFG state where the result of the primitive function should be made available
sdfg_builder: The object responsible for visiting child nodes of the primitive node.
let_symbols: Mapping of symbols (i.e. lambda parameters and/or local constants
like the identity value in a reduction context) to temporary fields
or symbolic expressions.
reduce_identity: The value of the reduction identity, in case the primitive node
is visited in the context of a reduction expression. This value is used
by the `neighbors` primitive to provide the default value of skip neighbors.
Returns:
A list of data access nodes and the associated GT4Py data type, which provide
Expand All @@ -75,13 +74,13 @@ def _parse_arg_expr(
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
let_symbols: dict[str, LetSymbol],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
fields: list[TemporaryData] = sdfg_builder.visit(
node,
sdfg=sdfg,
head_state=state,
let_symbols=let_symbols,
reduce_identity=reduce_identity,
)

assert len(fields) == 1
Expand Down Expand Up @@ -116,7 +115,6 @@ def _create_temporary_field(
],
node_type: ts.FieldType,
output_desc: dace.data.Data,
output_field_type: ts.DataType,
) -> tuple[dace.nodes.AccessNode, ts.FieldType]:
domain_dims, domain_lbs, domain_ubs = zip(*domain)
field_dims = list(domain_dims)
Expand Down Expand Up @@ -153,7 +151,7 @@ def translate_as_field_op(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `as_fieldop` builtin function."""
assert isinstance(node, gtir.FunCall)
Expand All @@ -172,36 +170,17 @@ def translate_as_field_op(
domain = dace_fieldview_util.get_domain(domain_expr)
assert isinstance(node.type, ts.FieldType)

reduce_identity: Optional[gtir_to_tasklet.SymbolExpr] = None
if cpm.is_applied_reduce(stencil_expr.expr):
# 'reduce' is a reserved keyword of the DSL and we will never find a user-defined symbol
# with this name. Since 'reduce' will never collide with a user-defined symbol, it is safe
# to use it internally to store the reduce identity value as a let-symbol.
if "reduce" in let_symbols:
if reduce_identity is not None:
raise NotImplementedError("nested reductions not supported.")

# the reduce identity value is used to fill the skip values in neighbors list
_, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr)

# we store the reduce identity value as a constant let-symbol
let_symbols = let_symbols | {
"reduce": (
gtir.Literal(value=str(reduce_identity.value), type=stencil_expr.expr.type),
reduce_identity.dtype,
)
}

elif "reduce" in let_symbols:
# a parent node is a reduction node, so we are visiting the current node in the context of a reduction
reduce_symbol, _ = let_symbols["reduce"]
assert isinstance(reduce_symbol, gtir.Literal)
reduce_identity = gtir_to_tasklet.SymbolExpr(
reduce_symbol.value, dace_fieldview_util.as_dace_type(reduce_symbol.type)
)

# first visit the list of arguments and build a symbol map
stencil_args = [
_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, let_symbols) for arg in node.args
_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, reduce_identity)
for arg in node.args
]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
Expand All @@ -221,9 +200,7 @@ def translate_as_field_op(
last_node_connector = None

# allocate local temporary storage for the result field
field_node, field_type = _create_temporary_field(
sdfg, state, domain, node.type, output_desc, output_expr.dtype
)
field_node, field_type = _create_temporary_field(sdfg, state, domain, node.type, output_desc)

# assume tasklet with single output
output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain]
Expand Down Expand Up @@ -265,7 +242,7 @@ def translate_cond(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `cond` builtin function."""
assert cpm.is_call_to(node, "cond")
Expand Down Expand Up @@ -307,13 +284,13 @@ def translate_cond(
true_expr,
sdfg=sdfg,
head_state=true_state,
let_symbols=let_symbols,
reduce_identity=reduce_identity,
)
false_br_args = sdfg_builder.visit(
false_expr,
sdfg=sdfg,
head_state=false_state,
let_symbols=let_symbols,
reduce_identity=reduce_identity,
)

output_nodes = []
Expand Down Expand Up @@ -381,7 +358,7 @@ def translate_literal(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.Literal` node."""
assert isinstance(node, gtir.Literal)
Expand All @@ -397,24 +374,13 @@ def translate_symbol_ref(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, gtir.SymRef)

sym_value = str(node.id)
if sym_value in let_symbols:
let_node, sym_type = let_symbols[sym_value]
if isinstance(let_node, gtir.Literal):
# this branch handles the case a let-symbol is mapped to some constant value
return sdfg_builder.visit(let_node)
# The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary
# data container. These symbols are visited and initialized in a state
# that preceeds the current state, therefore a new access node needs to
# be created in the state where they are accessed.
sym_value = str(let_node.id)
else:
sym_type = sdfg_builder.get_symbol_type(sym_value)
sym_type = sdfg_builder.get_symbol_type(sym_value)

# Create new access node in current state. It is possible that multiple
# access nodes are created in one state for the same data container.
Expand All @@ -434,5 +400,6 @@ def translate_symbol_ref(
__primitive_translators: list[PrimitiveTranslator] = [
translate_as_field_op,
translate_cond,
translate_literal,
translate_symbol_ref,
]
Loading

0 comments on commit 7824e38

Please sign in to comment.