diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 04d362b834..e6f33208e3 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -14,6 +14,7 @@ Any, Dict, Final, + Iterable, List, Optional, Protocol, @@ -29,7 +30,7 @@ from gt4py import eve from gt4py.next import common as gtx_common, utils as gtx_utils -from gt4py.next.iterator import builtins, ir as gtir +from gt4py.next.iterator import builtins as gtir_builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.program_processors.runners.dace import ( @@ -336,7 +337,6 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - scan_carry_symbol: Optional[gtir.Sym] input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) symbol_map: dict[ str, @@ -533,14 +533,17 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") - # default case: deref a field with one or more dimensions + # handle default case below: deref a field with one or more dimensions + + # when the indices are all dace symbolic expressions, the deref is lowered + # to a memlet, where the index is the memlet subset if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): # when all indices are symbolic expressions, we can perform direct field access through a memlet field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) - # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, - # either indirection through connectivity table or dynamic cartesian offset. + # when any of the indices is a runtime value (either a dynamic cartesian + # offset or a connectivity offset), the deref is lowered to a tasklet assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) assert len(field_desc.shape) == len(arg_expr.field_domain) field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] @@ -559,7 +562,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: for dim, index in field_indices ) deref_node = self._add_tasklet( - "runtime_deref", + "deref", {"field"} | set(index_connectors), {"val"}, code=f"val = field[{index_internals}]", @@ -603,6 +606,7 @@ def _visit_if_branch_arg( if_branch_state: dace.SDFGState, param_name: str, arg: IteratorExpr | DataExpr, + deref_on_input_memlet: bool, if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], ) -> IteratorExpr | ValueExpr: """ @@ -613,35 +617,56 @@ def _visit_if_branch_arg( if_branch_state: The state inside the nested SDFG where the if branch is lowered. param_name: The parameter name of the input argument. arg: The input argument expression. + deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. """ + use_full_shape = False if isinstance(arg, (MemletExpr, ValueExpr)): + arg_desc = arg.dc_node.desc(self.sdfg) arg_expr = arg - arg_node = arg.dc_node - arg_desc = arg_node.desc(self.sdfg) - if isinstance(arg, MemletExpr): - assert arg.subset.num_elements() == 1 - arg_desc = dace.data.Scalar(arg_desc.dtype) - else: - assert isinstance(arg_desc, dace.data.Scalar) elif isinstance(arg, IteratorExpr): - arg_node = arg.field - arg_desc = arg_node.desc(self.sdfg) - arg_expr = MemletExpr(arg_node, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc)) + arg_desc = arg.field.desc(self.sdfg) + if deref_on_input_memlet: + # If the iterator is just dereferenced inside the branch state, + # we can access the array outside the nested SDFG and pass the + # local data. This approach makes the data dependencies of nested + # structures more explicit and thus makes it easier for MapFusion + # to correctly infer the data dependencies. + memlet_subset = arg.get_memlet_subset(self.sdfg) + arg_expr = MemletExpr(arg.field, arg.gt_dtype, memlet_subset) + else: + # In order to shift the iterator inside the branch dataflow, + # we have to pass the full array shape. + arg_expr = MemletExpr( + arg.field, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc) + ) + use_full_shape = True else: raise TypeError(f"Unexpected {arg} as input argument.") - if param_name in if_sdfg.arrays: - inner_desc = if_sdfg.data(param_name) - assert not inner_desc.transient - else: + if use_full_shape: inner_desc = arg_desc.clone() inner_desc.transient = False + elif isinstance(arg.gt_dtype, ts.ScalarType): + inner_desc = dace.data.Scalar(arg_desc.dtype) + else: + # for list of values, we retrieve the local size from the corresponding offset + assert arg.gt_dtype.offset_type is not None + offset_provider_type = self.subgraph_builder.get_offset_provider_type( + arg.gt_dtype.offset_type.value + ) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + inner_desc = dace.data.Array(arg_desc.dtype, [offset_provider_type.max_neighbors]) + + if param_name in if_sdfg.arrays: + # the data desciptor was added by the visitor of the other branch expression + assert if_sdfg.data(param_name) == inner_desc + else: if_sdfg.add_datadesc(param_name, inner_desc) if_sdfg_input_memlets[param_name] = arg_expr inner_node = if_branch_state.add_access(param_name) - if isinstance(arg, IteratorExpr): + if isinstance(arg, IteratorExpr) and use_full_shape: return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) else: return ValueExpr(inner_node, arg.gt_dtype) @@ -652,6 +677,7 @@ def _visit_if_branch( if_branch_state: dace.SDFGState, expr: gtir.Expr, if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + direct_deref_iterators: Iterable[str], ) -> tuple[ list[DataflowInputEdge], tuple[DataflowOutputEdge | tuple[Any, ...], ...], @@ -666,6 +692,7 @@ def _visit_if_branch( if_branch_state: The state inside the nested SDFG where the if branch is lowered. expr: The if branch expression to lower. if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. Returns: A tuple containing: @@ -682,15 +709,29 @@ def _visit_if_branch( ptype = get_tuple_type(arg) # type: ignore[arg-type] psymbol = im.sym(pname, ptype) psymbol_tree = gtir_sdfg_utils.make_symbol_tree(pname, ptype) + deref_on_input_memlet = pname in direct_deref_iterators inner_arg = gtx_utils.tree_map( - lambda tsym, targ: self._visit_if_branch_arg( - if_sdfg, if_branch_state, tsym.id, targ, if_sdfg_input_memlets + lambda tsym, + targ, + deref_on_input_memlet=deref_on_input_memlet: self._visit_if_branch_arg( + if_sdfg, + if_branch_state, + tsym.id, + targ, + deref_on_input_memlet, + if_sdfg_input_memlets, ) )(psymbol_tree, arg) else: psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + deref_on_input_memlet = pname in direct_deref_iterators inner_arg = self._visit_if_branch_arg( - if_sdfg, if_branch_state, pname, arg, if_sdfg_input_memlets + if_sdfg, + if_branch_state, + pname, + arg, + deref_on_input_memlet, + if_sdfg_input_memlets, ) lambda_args.append(inner_arg) lambda_params.append(psymbol) @@ -742,11 +783,6 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A Lowers an if-expression with exclusive branch execution into a nested SDFG, in which each branch is lowered into a dataflow in a separate state and the if-condition is represented as the inter-state edge condition. - - Exclusive branch execution for local if expressions is meant to be used - in iterator view. Iterator view is required ONLY inside scan field operators. - For regular field operators, the fieldview behavior of if-expressions - corresponds to a local select, therefore it should be lowered to a tasklet. """ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: @@ -805,9 +841,41 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp nsdfg.add_scalar("__cond", dace.dtypes.bool) input_memlets["__cond"] = condition_value + # Collect all field iterators that are shifted inside any of the then/else + # branch expressions. Iterator shift expressions require the field argument + # as iterator, therefore the corresponding array has to be passed with full + # shape into the nested SDFG where the if_ expression is lowered. When the + # branch expression simply does `deref` on the iterator, without any shifting, + # it corresponds to a direct element access. Such `deref` expressions can + # be lowered outside the nested SDFG, so that just the local value (a scalar + # or a list of values) is passed as input to the nested SDFG. + shifted_iterator_symbols = set() + for branch_expr in node.args[1:3]: + for shift_node in eve.walk_values(branch_expr).filter( + lambda x: cpm.is_applied_shift(x) + ): + shifted_iterator_symbols |= ( + eve.walk_values(shift_node) + .if_isinstance(gtir.SymRef) + .map(lambda x: str(x.id)) + .filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr)) + .to_set() + ) + iterator_symbols = { + sym_name + for sym_name, sym_type in self.symbol_map.items() + if isinstance(sym_type, IteratorExpr) + } + direct_deref_iterators = ( + set(symbol_ref_utils.collect_symbol_refs(node.args[1:3], iterator_symbols)) + - shifted_iterator_symbols + ) + for nstate, arg in zip([tstate, fstate], node.args[1:3]): # visit each if-branch in the corresponding state of the nested SDFG - in_edges, output_tree = self._visit_if_branch(nsdfg, nstate, arg, input_memlets) + in_edges, output_tree = self._visit_if_branch( + nsdfg, nstate, arg, input_memlets, direct_deref_iterators + ) for edge in in_edges: edge.connect(map_entry=None) @@ -1511,7 +1579,7 @@ def _make_unstructured_shift( def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type IndexDType: Final = gtx_dace_utils.as_dace_type( - ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) + ts.ScalarType(kind=getattr(ts.ScalarKind, gtir_builtins.INTEGER_INDEX_BUILTIN.upper())) ) assert isinstance(node.fun, gtir.FunCall) @@ -1637,87 +1705,13 @@ def _visit_tuple_get( tuple_fields = self.visit(node.args[1]) return tuple_fields[index] - def requires_exclusive_if(self, node: gtir.FunCall) -> bool: - """ - The meaning of `if_` builtin function is unclear in GTIR. - In some context, it corresponds to a ternary operator where, depending on - the condition result, only one branch or the other should be executed, - because one of them is invalid. The typical case is the use of `if_` to - decide whether it is possible or not to access a shifted iterator, for - example when the condition expression calls `can_deref`. - The ternary operator is also used in iterator view, where the field arguments - are not necessarily both defined on the entire output domain (this behavior - should not appear in field view, because there the user code should use - `concat_where` instead of `where` for such cases). It is difficult to catch - such behavior, because it would require to know the exact domain of all - fields, which is not known at compile time. However, the iterator view - behavior should only appear inside scan field operators. - A different usage of `if_` expressions is selecting one argument value or - the other, where both arguments are defined on the output domain, therefore - always valid. - In order to simplify the SDFG and facilitate the optimization stage, we - try to avoid the ternary operator form when not needed. The reason is that - exclusive branch execution is represented in the SDFG as a conditional - state transition, which prevents fusion. - """ - assert cpm.is_call_to(node, "if_") - assert len(node.args) == 3 - - condition_vars = ( - eve.walk_values(node.args[0]) - .if_isinstance(gtir.SymRef) - .map(lambda node: str(node.id)) - .filter(lambda x: x in self.symbol_map) - .to_set() - ) - - # first, check if any argument contains shift expressions that depend on the condition variables - for arg in node.args[1:3]: - shift_nodes = ( - eve.walk_values(arg).filter(lambda node: cpm.is_applied_shift(node)).to_set() - ) - for shift_node in shift_nodes: - shift_vars = ( - eve.walk_values(shift_node) - .if_isinstance(gtir.SymRef) - .map(lambda node: str(node.id)) - .filter(lambda x: x in self.symbol_map) - .to_set() - ) - # require exclusive branch execution if any shift expression one of - # the if branches accesses a variable used in the condition expression - depend_vars = condition_vars.intersection(shift_vars) - if len(depend_vars) != 0: - return True - - # secondly, check whether the `if_` branches access different sets of fields - # and this happens inside a scan field operator - if self.scan_carry_symbol is not None: - # the `if_` node is inside a scan stencil expression - scan_carry_var = str(self.scan_carry_symbol.id) - if scan_carry_var in condition_vars: - br1_vars, br2_vars = ( - eve.walk_values(arg) - .if_isinstance(gtir.SymRef) - .map(lambda node: str(node.id)) - .filter(lambda x: isinstance(self.symbol_map.get(x, None), MemletExpr)) - .to_set() - for arg in node.args[1:3] - ) - if br1_vars != br2_vars: - # the two branches of the `if_` expression access different sets of fields, - # depending on the scan carry value - return True - - return False - def visit_FunCall( self, node: gtir.FunCall ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) - elif cpm.is_call_to(node, "if_") and self.requires_exclusive_if(node): + elif cpm.is_call_to(node, "if_"): return self._visit_if(node) elif cpm.is_call_to(node, "neighbors"): @@ -1854,7 +1848,6 @@ def translate_lambda_to_dataflow( | ValueExpr | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] ], - scan_carry_symbol: Optional[gtir.Sym] = None, ) -> tuple[ list[DataflowInputEdge], tuple[DataflowOutputEdge | tuple[Any, ...], ...], @@ -1873,15 +1866,13 @@ def translate_lambda_to_dataflow( sdfg_builder: Helper class to build the dataflow inside the given SDFG. node: Lambda node to visit. args: Arguments passed to lambda node. - scan_carry_symbol: When set, the lowering of `if_` expression will consider - using the ternary operator form with exclusive branch execution. Returns: A tuple of two elements: - List of connections for data inputs to the dataflow. - Tree representation of output data connections. """ - taskgen = LambdaToDataflow(sdfg, state, sdfg_builder, scan_carry_symbol) + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) lambda_output = taskgen.visit_let(node, args) if isinstance(lambda_output, DataflowOutputEdge): diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index 743b4d33e4..da10d4bddd 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -414,12 +414,7 @@ def init_scan_carry(sym: gtir.Sym) -> None: # stil inside the 'compute' state, generate the dataflow representing the stencil # to be applied on the horizontal domain lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( - nsdfg, - compute_state, - lambda_translator, - lambda_node, - stencil_args, - scan_carry_symbol=scan_carry_symbol, + nsdfg, compute_state, lambda_translator, lambda_node, stencil_args ) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region