From 509038837b70de0fd764230cd67efc713f33903d Mon Sep 17 00:00:00 2001
From: Edoardo Paone
Date: Mon, 3 Feb 2025 08:00:30 +0100
Subject: [PATCH] re-enable check for exclusive if

---
 .../runners/dace/gtir_dataflow.py             | 76 ++++++++++++++++++-
 1 file changed, 75 insertions(+), 1 deletion(-)

diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
index 34161ff266..04d362b834 100644
--- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
+++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
@@ -1637,13 +1637,87 @@ def _visit_tuple_get(
         tuple_fields = self.visit(node.args[1])
         return tuple_fields[index]
 
+    def requires_exclusive_if(self, node: gtir.FunCall) -> bool:
+        """
+        Check whether an `if_` call has to be lowered with exclusive branch execution.
+
+        The meaning of the `if_` builtin function is ambiguous in GTIR.
+        In some contexts, it corresponds to a ternary operator where, depending on
+        the condition result, only one branch or the other should be executed,
+        because one of them is invalid. The typical case is the use of `if_` to
+        decide whether it is possible or not to access a shifted iterator, for
+        example when the condition expression calls `can_deref`.
+        The ternary operator is also used in iterator view, where the field arguments
+        are not necessarily both defined on the entire output domain (this behavior
+        should not appear in field view, because there the user code should use
+        `concat_where` instead of `where` for such cases). It is difficult to catch
+        such behavior, because it would require knowing the exact domain of all
+        fields, which is not known at compile time. However, the iterator view
+        behavior should only appear inside scan field operators.
+        A different usage of `if_` expressions is selecting one argument value or
+        the other, where both arguments are defined on the output domain, therefore
+        always valid.
+        In order to simplify the SDFG and facilitate the optimization stage, we
+        try to avoid the ternary operator form when not needed. The reason is that
+        exclusive branch execution is represented in the SDFG as a conditional
+        state transition, which prevents fusion.
+
+        Args:
+            node: Call to the `if_` builtin; `args[0]` is the condition, `args[1:3]` the branches.
+
+        Returns:
+            True if one of the checks finds that only one of the two branches may be
+            executed, False otherwise.
+        """
+        assert cpm.is_call_to(node, "if_")
+        assert len(node.args) == 3
+
+        def mapped_symbols(expr: Any) -> set[str]:
+            # names of the symbols referenced in `expr` that are present in `self.symbol_map`
+            return (
+                eve.walk_values(expr)
+                .if_isinstance(gtir.SymRef)
+                .map(lambda sym: str(sym.id))
+                .filter(lambda name: name in self.symbol_map)
+                .to_set()
+            )
+
+        condition_vars = mapped_symbols(node.args[0])
+        if not condition_vars:
+            # both checks below need a symbol shared with the condition expression
+            return False
+
+        # first, check if any branch contains shift expressions that depend on condition variables
+        for arg in node.args[1:3]:
+            shift_nodes = eve.walk_values(arg).filter(cpm.is_applied_shift).to_set()
+            # require exclusive branch execution if any shift expression in one of
+            # the `if_` branches accesses a variable used in the condition expression
+            if any(
+                not condition_vars.isdisjoint(mapped_symbols(shift_node))
+                for shift_node in shift_nodes
+            ):
+                return True
+
+        # secondly, check whether the `if_` branches access different sets of fields
+        # and this happens inside a scan field operator
+        if self.scan_carry_symbol is None or str(self.scan_carry_symbol.id) not in condition_vars:
+            return False
+
+        # the `if_` node is inside a scan stencil expression whose condition reads the scan
+        # carry value: the branches are exclusive if they access different sets of fields
+        br1_fields, br2_fields = (
+            {var for var in mapped_symbols(arg) if isinstance(self.symbol_map[var], MemletExpr)}
+            for arg in node.args[1:3]
+        )
+        return br1_fields != br2_fields
+
     def visit_FunCall(
         self, node: gtir.FunCall
     ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]:
         if cpm.is_call_to(node, "deref"):
             return self._visit_deref(node)
 
-        elif cpm.is_call_to(node, "if_"):
+        elif cpm.is_call_to(node, "if_") and self.requires_exclusive_if(node):
             return self._visit_if(node)
 
         elif cpm.is_call_to(node, "neighbors"):