Skip to content

Commit

Permalink
re-enable check for exclusive if
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Feb 4, 2025
1 parent 7e622f9 commit 5090388
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,13 +1637,87 @@ def _visit_tuple_get(
tuple_fields = self.visit(node.args[1])
return tuple_fields[index]

def requires_exclusive_if(self, node: gtir.FunCall) -> bool:
"""
The meaning of `if_` builtin function is unclear in GTIR.
In some context, it corresponds to a ternary operator where, depending on
the condition result, only one branch or the other should be executed,
because one of them is invalid. The typical case is the use of `if_` to
decide whether it is possible or not to access a shifted iterator, for
example when the condition expression calls `can_deref`.
The ternary operator is also used in iterator view, where the field arguments
are not necessarily both defined on the entire output domain (this behavior
should not appear in field view, because there the user code should use
`concat_where` instead of `where` for such cases). It is difficult to catch
such behavior, because it would require to know the exact domain of all
fields, which is not known at compile time. However, the iterator view
behavior should only appear inside scan field operators.
A different usage of `if_` expressions is selecting one argument value or
the other, where both arguments are defined on the output domain, therefore
always valid.
In order to simplify the SDFG and facilitate the optimization stage, we
try to avoid the ternary operator form when not needed. The reason is that
exclusive branch execution is represented in the SDFG as a conditional
state transition, which prevents fusion.
"""
assert cpm.is_call_to(node, "if_")
assert len(node.args) == 3

condition_vars = (
eve.walk_values(node.args[0])
.if_isinstance(gtir.SymRef)
.map(lambda node: str(node.id))
.filter(lambda x: x in self.symbol_map)
.to_set()
)

# first, check if any argument contains shift expressions that depend on the condition variables
for arg in node.args[1:3]:
shift_nodes = (
eve.walk_values(arg).filter(lambda node: cpm.is_applied_shift(node)).to_set()
)
for shift_node in shift_nodes:
shift_vars = (
eve.walk_values(shift_node)
.if_isinstance(gtir.SymRef)
.map(lambda node: str(node.id))
.filter(lambda x: x in self.symbol_map)
.to_set()
)
# require exclusive branch execution if any shift expression one of
# the if branches accesses a variable used in the condition expression
depend_vars = condition_vars.intersection(shift_vars)
if len(depend_vars) != 0:
return True

# secondly, check whether the `if_` branches access different sets of fields
# and this happens inside a scan field operator
if self.scan_carry_symbol is not None:
# the `if_` node is inside a scan stencil expression
scan_carry_var = str(self.scan_carry_symbol.id)
if scan_carry_var in condition_vars:
br1_vars, br2_vars = (
eve.walk_values(arg)
.if_isinstance(gtir.SymRef)
.map(lambda node: str(node.id))
.filter(lambda x: isinstance(self.symbol_map.get(x, None), MemletExpr))
.to_set()
for arg in node.args[1:3]
)
if br1_vars != br2_vars:
# the two branches of the `if_` expression access different sets of fields,
# depending on the scan carry value
return True

return False

def visit_FunCall(
self, node: gtir.FunCall
) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]:
if cpm.is_call_to(node, "deref"):
return self._visit_deref(node)

elif cpm.is_call_to(node, "if_"):
elif cpm.is_call_to(node, "if_") and self.requires_exclusive_if(node):
return self._visit_if(node)

elif cpm.is_call_to(node, "neighbors"):
Expand Down

0 comments on commit 5090388

Please sign in to comment.