Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next][dace]: make if_ always execute branch exclusively #1846

Merged
merged 6 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 94 additions & 106 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Any,
Dict,
Final,
Iterable,
List,
Optional,
Protocol,
Expand All @@ -29,7 +30,7 @@

from gt4py import eve
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import builtins, ir as gtir
from gt4py.next.iterator import builtins as gtir_builtins, ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import symbol_ref_utils
from gt4py.next.program_processors.runners.dace import (
Expand Down Expand Up @@ -336,7 +337,6 @@ class LambdaToDataflow(eve.NodeVisitor):
sdfg: dace.SDFG
state: dace.SDFGState
subgraph_builder: gtir_sdfg.DataflowBuilder
scan_carry_symbol: Optional[gtir.Sym]
input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: [])
symbol_map: dict[
str,
Expand Down Expand Up @@ -603,6 +603,7 @@ def _visit_if_branch_arg(
if_branch_state: dace.SDFGState,
param_name: str,
arg: IteratorExpr | DataExpr,
deref_on_input_memlet: bool,
if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr],
) -> IteratorExpr | ValueExpr:
"""
Expand All @@ -613,35 +614,56 @@ def _visit_if_branch_arg(
if_branch_state: The state inside the nested SDFG where the if branch is lowered.
param_name: The parameter name of the input argument.
arg: The input argument expression.
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet.
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. This means that the values are copied into a temporary storage which is passed into the nested SDFG.

I think the description of this option should have a bit more information.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change you propose is not correct. We are not allocating temporary storage, we are just narrowing the memlet subset.

if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function.
"""
use_full_shape = False
if isinstance(arg, (MemletExpr, ValueExpr)):
arg_desc = arg.dc_node.desc(self.sdfg)
arg_expr = arg
arg_node = arg.dc_node
arg_desc = arg_node.desc(self.sdfg)
if isinstance(arg, MemletExpr):
assert arg.subset.num_elements() == 1
arg_desc = dace.data.Scalar(arg_desc.dtype)
else:
assert isinstance(arg_desc, dace.data.Scalar)
elif isinstance(arg, IteratorExpr):
arg_node = arg.field
arg_desc = arg_node.desc(self.sdfg)
arg_expr = MemletExpr(arg_node, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc))
arg_desc = arg.field.desc(self.sdfg)
if deref_on_input_memlet:
# If the iterator is just dereferenced inside the branch state,
# we can access the array outside the nested SDFG and pass the
# local data. This approach makes the data dependencies of nested
# structures more explicit and thus makes it easier for MapFusion
# to correctly infer the data dependencies.
memlet_subset = arg.get_memlet_subset(self.sdfg)
arg_expr = MemletExpr(arg.field, arg.gt_dtype, memlet_subset)
else:
# In order to shift the iterator inside the branch dataflow,
# we have to pass the full array shape.
arg_expr = MemletExpr(
arg.field, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc)
)
use_full_shape = True
else:
raise TypeError(f"Unexpected {arg} as input argument.")

if param_name in if_sdfg.arrays:
inner_desc = if_sdfg.data(param_name)
assert not inner_desc.transient
else:
if use_full_shape:
inner_desc = arg_desc.clone()
inner_desc.transient = False
elif isinstance(arg.gt_dtype, ts.ScalarType):
inner_desc = dace.data.Scalar(arg_desc.dtype)
else:
# for list of values, we retrieve the local size from the corresponding offset
assert arg.gt_dtype.offset_type is not None
offset_provider_type = self.subgraph_builder.get_offset_provider_type(
arg.gt_dtype.offset_type.value
)
assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
inner_desc = dace.data.Array(arg_desc.dtype, [offset_provider_type.max_neighbors])

if param_name in if_sdfg.arrays:
# the data desciptor was added by the visitor of the other branch expression
assert if_sdfg.data(param_name) == inner_desc
else:
if_sdfg.add_datadesc(param_name, inner_desc)
if_sdfg_input_memlets[param_name] = arg_expr

inner_node = if_branch_state.add_access(param_name)
if isinstance(arg, IteratorExpr):
if isinstance(arg, IteratorExpr) and use_full_shape:
return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices)
else:
return ValueExpr(inner_node, arg.gt_dtype)
Expand All @@ -652,6 +674,7 @@ def _visit_if_branch(
if_branch_state: dace.SDFGState,
expr: gtir.Expr,
if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr],
direct_deref_iterators: Iterable[str],
) -> tuple[
list[DataflowInputEdge],
tuple[DataflowOutputEdge | tuple[Any, ...], ...],
Expand All @@ -666,6 +689,7 @@ def _visit_if_branch(
if_branch_state: The state inside the nested SDFG where the if branch is lowered.
expr: The if branch expression to lower.
if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function.
direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift.

Returns:
A tuple containing:
Expand All @@ -682,15 +706,29 @@ def _visit_if_branch(
ptype = get_tuple_type(arg) # type: ignore[arg-type]
psymbol = im.sym(pname, ptype)
psymbol_tree = gtir_sdfg_utils.make_symbol_tree(pname, ptype)
deref_on_input_memlet = pname in direct_deref_iterators
inner_arg = gtx_utils.tree_map(
lambda tsym, targ: self._visit_if_branch_arg(
if_sdfg, if_branch_state, tsym.id, targ, if_sdfg_input_memlets
lambda tsym,
targ,
deref_on_input_memlet=deref_on_input_memlet: self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
tsym.id,
targ,
deref_on_input_memlet,
if_sdfg_input_memlets,
)
)(psymbol_tree, arg)
else:
psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr]
deref_on_input_memlet = pname in direct_deref_iterators
inner_arg = self._visit_if_branch_arg(
if_sdfg, if_branch_state, pname, arg, if_sdfg_input_memlets
if_sdfg,
if_branch_state,
pname,
arg,
deref_on_input_memlet,
if_sdfg_input_memlets,
)
lambda_args.append(inner_arg)
lambda_params.append(psymbol)
Expand Down Expand Up @@ -742,11 +780,6 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A
Lowers an if-expression with exclusive branch execution into a nested SDFG,
in which each branch is lowered into a dataflow in a separate state and
the if-condition is represented as the inter-state edge condition.

Exclusive branch execution for local if expressions is meant to be used
in iterator view. Iterator view is required ONLY inside scan field operators.
For regular field operators, the fieldview behavior of if-expressions
corresponds to a local select, therefore it should be lowered to a tasklet.
"""

def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr:
Expand Down Expand Up @@ -805,9 +838,41 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
nsdfg.add_scalar("__cond", dace.dtypes.bool)
input_memlets["__cond"] = condition_value

# Collect all field iterators that are shifted inside any of the then/else
# branch expressions. Iterator shift expressions require the field argument
# as iterator, therefore the corresponding array has to be passed with full
# shape into the nested SDFG where the if_ expression is lowered. When the
# branch expression simply does `deref` on the iterator, without any shifting,
# it corresponds to a direct element access. Such `deref` expressions can
# be lowered outside the nested SDFG, so that just the local value (a scalar
# or a list of values) is passed as input to the nested SDFG.
shifted_iterator_symbols = set()
for branch_expr in node.args[1:3]:
for shift_node in eve.walk_values(branch_expr).filter(
lambda x: cpm.is_applied_shift(x)
):
shifted_iterator_symbols |= (
eve.walk_values(shift_node)
.if_isinstance(gtir.SymRef)
.map(lambda x: str(x.id))
.filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr))
.to_set()
)
iterator_symbols = {
sym_name
for sym_name, sym_type in self.symbol_map.items()
if isinstance(sym_type, IteratorExpr)
}
direct_deref_iterators = (
set(symbol_ref_utils.collect_symbol_refs(node.args[1:3], iterator_symbols))
- shifted_iterator_symbols
)

for nstate, arg in zip([tstate, fstate], node.args[1:3]):
# visit each if-branch in the corresponding state of the nested SDFG
in_edges, output_tree = self._visit_if_branch(nsdfg, nstate, arg, input_memlets)
in_edges, output_tree = self._visit_if_branch(
nsdfg, nstate, arg, input_memlets, direct_deref_iterators
)
for edge in in_edges:
edge.connect(map_entry=None)

Expand Down Expand Up @@ -1511,7 +1576,7 @@ def _make_unstructured_shift(
def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr:
# convert builtin-index type to dace type
IndexDType: Final = gtx_dace_utils.as_dace_type(
ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))
ts.ScalarType(kind=getattr(ts.ScalarKind, gtir_builtins.INTEGER_INDEX_BUILTIN.upper()))
)

assert isinstance(node.fun, gtir.FunCall)
Expand Down Expand Up @@ -1637,87 +1702,13 @@ def _visit_tuple_get(
tuple_fields = self.visit(node.args[1])
return tuple_fields[index]

def requires_exclusive_if(self, node: gtir.FunCall) -> bool:
"""
The meaning of `if_` builtin function is unclear in GTIR.
In some context, it corresponds to a ternary operator where, depending on
the condition result, only one branch or the other should be executed,
because one of them is invalid. The typical case is the use of `if_` to
decide whether it is possible or not to access a shifted iterator, for
example when the condition expression calls `can_deref`.
The ternary operator is also used in iterator view, where the field arguments
are not necessarily both defined on the entire output domain (this behavior
should not appear in field view, because there the user code should use
`concat_where` instead of `where` for such cases). It is difficult to catch
such behavior, because it would require to know the exact domain of all
fields, which is not known at compile time. However, the iterator view
behavior should only appear inside scan field operators.
A different usage of `if_` expressions is selecting one argument value or
the other, where both arguments are defined on the output domain, therefore
always valid.
In order to simplify the SDFG and facilitate the optimization stage, we
try to avoid the ternary operator form when not needed. The reason is that
exclusive branch execution is represented in the SDFG as a conditional
state transition, which prevents fusion.
"""
assert cpm.is_call_to(node, "if_")
assert len(node.args) == 3

condition_vars = (
eve.walk_values(node.args[0])
.if_isinstance(gtir.SymRef)
.map(lambda node: str(node.id))
.filter(lambda x: x in self.symbol_map)
.to_set()
)

# first, check if any argument contains shift expressions that depend on the condition variables
for arg in node.args[1:3]:
shift_nodes = (
eve.walk_values(arg).filter(lambda node: cpm.is_applied_shift(node)).to_set()
)
for shift_node in shift_nodes:
shift_vars = (
eve.walk_values(shift_node)
.if_isinstance(gtir.SymRef)
.map(lambda node: str(node.id))
.filter(lambda x: x in self.symbol_map)
.to_set()
)
# require exclusive branch execution if any shift expression one of
# the if branches accesses a variable used in the condition expression
depend_vars = condition_vars.intersection(shift_vars)
if len(depend_vars) != 0:
return True

# secondly, check whether the `if_` branches access different sets of fields
# and this happens inside a scan field operator
if self.scan_carry_symbol is not None:
# the `if_` node is inside a scan stencil expression
scan_carry_var = str(self.scan_carry_symbol.id)
if scan_carry_var in condition_vars:
br1_vars, br2_vars = (
eve.walk_values(arg)
.if_isinstance(gtir.SymRef)
.map(lambda node: str(node.id))
.filter(lambda x: isinstance(self.symbol_map.get(x, None), MemletExpr))
.to_set()
for arg in node.args[1:3]
)
if br1_vars != br2_vars:
# the two branches of the `if_` expression access different sets of fields,
# depending on the scan carry value
return True

return False

def visit_FunCall(
self, node: gtir.FunCall
) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]:
if cpm.is_call_to(node, "deref"):
return self._visit_deref(node)

elif cpm.is_call_to(node, "if_") and self.requires_exclusive_if(node):
elif cpm.is_call_to(node, "if_"):
return self._visit_if(node)

elif cpm.is_call_to(node, "neighbors"):
Expand Down Expand Up @@ -1854,7 +1845,6 @@ def translate_lambda_to_dataflow(
| ValueExpr
| tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...]
],
scan_carry_symbol: Optional[gtir.Sym] = None,
) -> tuple[
list[DataflowInputEdge],
tuple[DataflowOutputEdge | tuple[Any, ...], ...],
Expand All @@ -1873,15 +1863,13 @@ def translate_lambda_to_dataflow(
sdfg_builder: Helper class to build the dataflow inside the given SDFG.
node: Lambda node to visit.
args: Arguments passed to lambda node.
scan_carry_symbol: When set, the lowering of `if_` expression will consider
using the ternary operator form with exclusive branch execution.

Returns:
A tuple of two elements:
- List of connections for data inputs to the dataflow.
- Tree representation of output data connections.
"""
taskgen = LambdaToDataflow(sdfg, state, sdfg_builder, scan_carry_symbol)
taskgen = LambdaToDataflow(sdfg, state, sdfg_builder)
lambda_output = taskgen.visit_let(node, args)

if isinstance(lambda_output, DataflowOutputEdge):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,12 +414,7 @@ def init_scan_carry(sym: gtir.Sym) -> None:
# stil inside the 'compute' state, generate the dataflow representing the stencil
# to be applied on the horizontal domain
lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow(
nsdfg,
compute_state,
lambda_translator,
lambda_node,
stencil_args,
scan_carry_symbol=scan_carry_symbol,
nsdfg, compute_state, lambda_translator, lambda_node, stencil_args
)
# connect the dataflow input directly to the source data nodes, without passing through a map node;
# the reason is that the map for horizontal domain is outside the scan loop region
Expand Down