diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
index 1219ac108f..6d3c60a9de 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
@@ -31,7 +31,6 @@
 IteratorIndexDType: TypeAlias = dace.int32  # type of iterator indexes
-LetSymbol: TypeAlias = tuple[gtir.Literal | gtir.SymRef, ts.FieldType | ts.ScalarType]
 TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]
@@ -43,7 +42,7 @@ def __call__(
         sdfg: dace.SDFG,
         state: dace.SDFGState,
         sdfg_builder: gtir_to_sdfg.SDFGBuilder,
-        let_symbols: dict[str, LetSymbol],
+        reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
     ) -> list[TemporaryData]:
         """Creates the dataflow subgraph representing a GTIR primitive function.
@@ -55,9 +54,9 @@ def __call__(
             sdfg: The SDFG where the primitive subgraph should be instantiated
             state: The SDFG state where the result of the primitive function should be made available
             sdfg_builder: The object responsible for visiting child nodes of the primitive node.
-            let_symbols: Mapping of symbols (i.e. lambda parameters and/or local constants
-                like the identity value in a reduction context) to temporary fields
-                or symbolic expressions.
+            reduce_identity: The value of the reduction identity, in case the primitive node
+                is visited in the context of a reduction expression. This value is used
+                by the `neighbors` primitive to provide the default value for skip neighbors.

         Returns:
             A list of data access nodes and the associated GT4Py data type, which provide
@@ -75,13 +74,13 @@ def _parse_arg_expr(
     domain: list[
         tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
     ],
-    let_symbols: dict[str, LetSymbol],
+    reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
 ) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
     fields: list[TemporaryData] = sdfg_builder.visit(
         node,
         sdfg=sdfg,
         head_state=state,
-        let_symbols=let_symbols,
+        reduce_identity=reduce_identity,
     )

     assert len(fields) == 1
@@ -116,7 +115,6 @@ def _create_temporary_field(
     ],
     node_type: ts.FieldType,
     output_desc: dace.data.Data,
-    output_field_type: ts.DataType,
 ) -> tuple[dace.nodes.AccessNode, ts.FieldType]:
     domain_dims, domain_lbs, domain_ubs = zip(*domain)
     field_dims = list(domain_dims)
@@ -153,7 +151,7 @@ def translate_as_field_op(
     sdfg: dace.SDFG,
     state: dace.SDFGState,
     sdfg_builder: gtir_to_sdfg.SDFGBuilder,
-    let_symbols: dict[str, LetSymbol],
+    reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
 ) -> list[TemporaryData]:
     """Generates the dataflow subgraph for the `as_fieldop` builtin function."""
     assert isinstance(node, gtir.FunCall)
@@ -172,36 +170,17 @@ def translate_as_field_op(
     domain = dace_fieldview_util.get_domain(domain_expr)
     assert isinstance(node.type, ts.FieldType)

-    reduce_identity: Optional[gtir_to_tasklet.SymbolExpr] = None
     if cpm.is_applied_reduce(stencil_expr.expr):
-        # 'reduce' is a reserved keyword of the DSL and we will never find a user-defined symbol
-        # with this name. Since 'reduce' will never collide with a user-defined symbol, it is safe
-        # to use it internally to store the reduce identity value as a let-symbol.
- if "reduce" in let_symbols: + if reduce_identity is not None: raise NotImplementedError("nested reductions not supported.") # the reduce identity value is used to fill the skip values in neighbors list _, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr) - # we store the reduce identity value as a constant let-symbol - let_symbols = let_symbols | { - "reduce": ( - gtir.Literal(value=str(reduce_identity.value), type=stencil_expr.expr.type), - reduce_identity.dtype, - ) - } - - elif "reduce" in let_symbols: - # a parent node is a reduction node, so we are visiting the current node in the context of a reduction - reduce_symbol, _ = let_symbols["reduce"] - assert isinstance(reduce_symbol, gtir.Literal) - reduce_identity = gtir_to_tasklet.SymbolExpr( - reduce_symbol.value, dace_fieldview_util.as_dace_type(reduce_symbol.type) - ) - # first visit the list of arguments and build a symbol map stencil_args = [ - _parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, let_symbols) for arg in node.args + _parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, reduce_identity) + for arg in node.args ] # represent the field operator as a mapped tasklet graph, which will range over the field domain @@ -221,9 +200,7 @@ def translate_as_field_op( last_node_connector = None # allocate local temporary storage for the result field - field_node, field_type = _create_temporary_field( - sdfg, state, domain, node.type, output_desc, output_expr.dtype - ) + field_node, field_type = _create_temporary_field(sdfg, state, domain, node.type, output_desc) # assume tasklet with single output output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain] @@ -265,7 +242,7 @@ def translate_cond( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, - let_symbols: dict[str, LetSymbol], + reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], ) -> list[TemporaryData]: """Generates the dataflow subgraph for the `cond` builtin function.""" assert cpm.is_call_to(node, "cond") @@ -307,13 +284,13 @@ def translate_cond( true_expr, sdfg=sdfg, head_state=true_state, - let_symbols=let_symbols, + reduce_identity=reduce_identity, ) false_br_args = sdfg_builder.visit( false_expr, sdfg=sdfg, head_state=false_state, - let_symbols=let_symbols, + reduce_identity=reduce_identity, ) output_nodes = [] @@ -381,7 +358,7 @@ def translate_literal( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, - let_symbols: dict[str, LetSymbol], + reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], ) -> list[TemporaryData]: """Generates the dataflow subgraph for a `ir.Literal` node.""" assert isinstance(node, gtir.Literal) @@ -397,24 +374,13 @@ def translate_symbol_ref( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, - let_symbols: dict[str, LetSymbol], + reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], ) -> list[TemporaryData]: """Generates the dataflow subgraph for a `ir.SymRef` node.""" assert isinstance(node, gtir.SymRef) sym_value = str(node.id) - if sym_value in let_symbols: - let_node, sym_type = let_symbols[sym_value] - if isinstance(let_node, gtir.Literal): - # this branch handles the case a let-symbol is mapped to some constant value - return sdfg_builder.visit(let_node) - # The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary - # data container. 
-        # that preceeds the current state, therefore a new access node needs to
-        # be created in the state where they are accessed.
-        sym_value = str(let_node.id)
-    else:
-        sym_type = sdfg_builder.get_symbol_type(sym_value)
+    sym_type = sdfg_builder.get_symbol_type(sym_value)

     # Create new access node in current state. It is possible that multiple
     # access nodes are created in one state for the same data container.
@@ -434,5 +400,6 @@
 __primitive_translators: list[PrimitiveTranslator] = [
     translate_as_field_op,
     translate_cond,
+    translate_literal,
     translate_symbol_ref,
 ]
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py
index bf1a017312..75a5aa07e3 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py
@@ -16,9 +16,10 @@
 import abc
 import dataclasses
-from typing import Any, Dict, List, Protocol, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Protocol, Sequence, Set, Tuple, Union

 import dace
+import dace.transformation.dataflow as dace_dataflow

 from gt4py import eve
 from gt4py.eve import concepts
@@ -28,6 +29,7 @@
 from gt4py.next.iterator.type_system import inference as gtir_type_inference
 from gt4py.next.program_processors.runners.dace_fieldview import (
     gtir_builtin_translators,
+    gtir_to_tasklet,
     utility as dace_fieldview_util,
 )
 from gt4py.next.type_system import type_specifications as ts, type_translation as tt
@@ -183,10 +185,6 @@ def _add_storage(
         else:
             raise RuntimeError(f"Data type '{type(symbol_type)}' not supported.")

-        # TODO: unclear why mypy complains about incompatible types
-        assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType))
-        self.global_symbols[name] = symbol_type
-
     def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]:
         """
         Add temporary storage (aka transient) for data containers used as GTIR temporaries.
@@ -211,7 +209,7 @@ def _visit_expression(
         to have the same memory layout as the target array.
         """
         results: list[gtir_builtin_translators.TemporaryData] = self.visit(
-            node, sdfg=sdfg, head_state=head_state, let_symbols={}
+            node, sdfg=sdfg, head_state=head_state, reduce_identity=None
        )

         field_nodes = []
@@ -227,6 +225,32 @@ def _visit_expression(
         return field_nodes

+    def _add_sdfg_params(self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym]) -> None:
+        """Helper function to add storage for node parameters and connectivity tables."""
+
+        # add non-transient arrays and/or SDFG symbols for the program arguments
+        for param in node_params:
+            pname = str(param.id)
+            assert isinstance(param.type, (ts.FieldType, ts.ScalarType))
+            self._add_storage(sdfg, pname, param.type, transient=False)
+            self.global_symbols[pname] = param.type
+
+        # add SDFG storage for connectivity tables
+        for offset, offset_provider in dace_fieldview_util.filter_connectivities(
+            self.offset_provider
+        ).items():
+            scalar_kind = tt.get_scalar_kind(offset_provider.index_type)
+            local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL)
+            type_ = ts.FieldType(
+                [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
+            )
+            # We store all connectivity tables as transient arrays here; later, while building
+            # the field operator expressions, we change to non-transient (i.e. allocated externally)
+            # the tables that are actually used. This way, we avoid adding SDFG arguments for
+            # the connectivity tables that are not used. The remaining unused transient arrays
+            # are removed by the dace simplify pass.
+            self._add_storage(sdfg, dace_fieldview_util.connectivity_identifier(offset), type_)
+
     def visit_Program(self, node: gtir.Program) -> dace.SDFG:
         """Translates `ir.Program` to `dace.SDFG`.
@@ -254,26 +278,7 @@ def visit_Program(
         else:
             head_state = entry_state

-        # add non-transient arrays and/or SDFG symbols for the program arguments
-        for param in node.params:
-            assert isinstance(param.type, ts.DataType)
-            self._add_storage(sdfg, str(param.id), param.type, transient=False)
-
-        # add SDFG storage for connectivity tables
-        for offset, offset_provider in dace_fieldview_util.filter_connectivities(
-            self.offset_provider
-        ).items():
-            scalar_kind = tt.get_scalar_kind(offset_provider.index_type)
-            local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL)
-            type_ = ts.FieldType(
-                [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
-            )
-            # We store all connectivity tables as transient arrays here; later, while building
-            # the field operator expressions, we change to non-transient (i.e. allocated extrenally)
-            # the tables that are actually used. This way, we avoid adding SDFG arguments for
-            # the connectivity tables that are not used. The remaining unused transient arrays
-            # are removed by the dace simplify pass.
-            self._add_storage(sdfg, dace_fieldview_util.connectivity_identifier(offset), type_)
+        self._add_sdfg_params(sdfg, node.params)

         # visit one statement at a time and expand the SDFG from the current head state
         for i, stmt in enumerate(node.body):
@@ -331,47 +336,34 @@ def visit_FunCall(
         node: gtir.FunCall,
         sdfg: dace.SDFG,
         head_state: dace.SDFGState,
-        let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
+        reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
     ) -> list[gtir_builtin_translators.TemporaryData]:
         # use specialized dataflow builder classes for each builtin function
         if cpm.is_call_to(node, "cond"):
             return gtir_builtin_translators.translate_cond(
-                node, sdfg, head_state, self, let_symbols
+                node, sdfg, head_state, self, reduce_identity
             )
         elif cpm.is_call_to(node.fun, "as_fieldop"):
             return gtir_builtin_translators.translate_as_field_op(
-                node, sdfg, head_state, self, let_symbols
+                node, sdfg, head_state, self, reduce_identity
             )
         elif isinstance(node.fun, gtir.Lambda):
-            # We use a separate state to ensure that the lambda arguments are evaluated
-            # before the computation starts. This is required in case the let-symbols
-            # are used in conditional branch execution, which happens in different states.
-            lambda_state = sdfg.add_state_before(head_state, f"{head_state.label}_symbols")
-
             node_args = []
             for arg in node.args:
                 node_args.extend(
                     self.visit(
                         arg,
                         sdfg=sdfg,
-                        head_state=lambda_state,
-                        let_symbols=let_symbols,
+                        head_state=head_state,
+                        reduce_identity=reduce_identity,
                     )
                 )

-            # some cleanup: remove isolated nodes for program arguments in lambda state
-            isolated_node_args = [node for node, _ in node_args if lambda_state.degree(node) == 0]
-            assert all(
-                isinstance(node, dace.nodes.AccessNode) and node.data in self.global_symbols
-                for node in isolated_node_args
-            )
-            lambda_state.remove_nodes_from(isolated_node_args)
-
             return self.visit(
                 node.fun,
                 sdfg=sdfg,
                 head_state=head_state,
-                let_symbols=let_symbols,
+                reduce_identity=reduce_identity,
                 args=node_args,
             )
         else:
@@ -382,39 +374,122 @@ def visit_Lambda(
         node: gtir.Lambda,
         sdfg: dace.SDFG,
         head_state: dace.SDFGState,
-        let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
+        reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
         args: list[gtir_builtin_translators.TemporaryData],
     ) -> list[gtir_builtin_translators.TemporaryData]:
         """
-        Translates a `Lambda` node to a tasklet subgraph in the current SDFG state.
+        Translates a `Lambda` node to a nested SDFG in the current state.

         All arguments to lambda functions are fields (i.e. `as_fieldop`, field or scalar `gtir.SymRef`,
-        nested let-lambdas thereof). The dictionary called `let_symbols` maps the lambda parameters
-        to symbols, e.g. temporary fields or program arguments. If the lambda has a parameter whose name
-        is already present in `let_symbols`, i.e. a paramater with the same name as a previously defined
-        symbol, the parameter will shadow the previous symbol during traversal of the lambda expression.
+        nested let-lambdas thereof). The reason for creating a nested SDFG is to define local symbols
+        (the lambda parameters) that map to parent fields, either program arguments or temporary fields.
+
+        If the lambda has a parameter whose name is already present in `GTIRToSDFG.global_symbols`,
+        i.e. a lambda parameter with the same name as a symbol in scope, the parameter will shadow
+        the previous symbol during traversal of the lambda expression.
""" - lambda_symbols = let_symbols | { - str(p.id): (gtir.SymRef(id=temp_node.data), type_) - for p, (temp_node, type_) in zip(node.params, args, strict=True) + + lambda_args_mapping = {str(p.id): arg for p, arg in zip(node.params, args, strict=True)} + + # inherit symbols from parent scope but eventually override with local symbols + lambda_symbols = self.global_symbols | { + pname: type_ for pname, (_, type_) in lambda_args_mapping.items() } - return self.visit( + nsdfg = dace.SDFG(f"{sdfg.label}_nested") + nstate = nsdfg.add_state("lambda") + + self._add_sdfg_params( + nsdfg, + [gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items()], + ) + + lambda_nodes = GTIRToSDFG(self.offset_provider, lambda_symbols.copy()).visit( node.expr, - sdfg=sdfg, - head_state=head_state, - let_symbols=lambda_symbols, + sdfg=nsdfg, + head_state=nstate, + reduce_identity=reduce_identity, ) + connectivity_arrays = { + dace_fieldview_util.connectivity_identifier(offset) + for offset in dace_fieldview_util.filter_connectivities(self.offset_provider) + } + nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} + + input_memlets = {} + for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): + if nsdfg_datadesc.transient: + continue + datadesc: Optional[dace.dtypes.Array] = None + if nsdfg_dataname in lambda_args_mapping: + src_node, _ = lambda_args_mapping[nsdfg_dataname] + dataname = src_node.data + datadesc = src_node.desc(sdfg) + + nsdfg_symbols_mapping |= { + str(nested_symbol): parent_symbol + for nested_symbol, parent_symbol in zip( + [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], + [*datadesc.shape, *datadesc.strides], + strict=True, + ) + if isinstance(nested_symbol, dace.symbol) + } + else: + dataname = nsdfg_dataname + datadesc = sdfg.arrays[nsdfg_dataname] + # ensure that connectivity tables are non-transient arrays in parent SDFG + if dataname in connectivity_arrays: + datadesc.transient = False + + if datadesc: + input_memlets[nsdfg_dataname] = dace.Memlet.from_array(dataname, datadesc) + + nsdfg_node = head_state.add_nested_sdfg( + nsdfg, + parent=sdfg, + inputs=set(input_memlets.keys()), + outputs=set(node.data for node, _ in lambda_nodes), + symbol_mapping=nsdfg_symbols_mapping, + debuginfo=dace_fieldview_util.debug_info(node, default=sdfg.debuginfo), + ) + + for connector, memlet in input_memlets.items(): + if connector in lambda_args_mapping: + src_node, _ = lambda_args_mapping[connector] + else: + src_node = head_state.add_access(memlet.data) + + head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) + + results = [] + for lambda_node, type_ in lambda_nodes: + connector = lambda_node.data + desc = lambda_node.desc(nsdfg) + # make lambda result non-transient and map it to external temporary + desc.transient = False + # isolated access node will make validation fail + if nstate.degree(lambda_node) == 0: + nstate.remove_node(lambda_node) + temp, _ = sdfg.add_temp_transient_like(desc) + dst_node = head_state.add_access(temp) + head_state.add_edge( + nsdfg_node, connector, dst_node, None, dace.Memlet.from_array(temp, desc) + ) + results.append((dst_node, type_)) + + return results + def visit_Literal( self, node: gtir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState, - let_symbols: dict[str, gtir_builtin_translators.LetSymbol], + reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], ) -> list[gtir_builtin_translators.TemporaryData]: return gtir_builtin_translators.translate_literal( - node, sdfg, head_state, self, let_symbols={} + node, sdfg, 
         )

     def visit_SymRef(
@@ -422,10 +497,10 @@
         self,
         node: gtir.SymRef,
         sdfg: dace.SDFG,
         head_state: dace.SDFGState,
-        let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
+        reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
     ) -> list[gtir_builtin_translators.TemporaryData]:
         return gtir_builtin_translators.translate_symbol_ref(
-            node, sdfg, head_state, self, let_symbols
+            node, sdfg, head_state, self, reduce_identity=None
         )
@@ -452,5 +527,9 @@ def build_sdfg_from_gtir(
     sdfg = sdfg_genenerator.visit(program)
     assert isinstance(sdfg, dace.SDFG)

+    # nested-SDFGs for let-lambda may contain unused symbols, in which case
+    # we can remove unnecessary data connectors (not done by dace simplify pass)
+    sdfg.apply_transformations_repeated(dace_dataflow.PruneConnectors)
+
     sdfg.simplify()
     return sdfg
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py
index 55b037b1f6..0bd86fbf09 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py
@@ -104,7 +104,7 @@ def test_gtir_cast():
     IFTYPE_FLOAT32 = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))
     IFTYPE_BOOL = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL))
     testee = gtir.Program(
-        id="test_gtir_cast",
+        id="gtir_cast",
         function_definitions=[],
         params=[
             gtir.Sym(id="x", type=IFTYPE),
@@ -335,7 +335,7 @@ def test_gtir_cond():
             gtir.SetAt(
                 expr=im.op_as_fieldop("plus", domain)(
                     "x",
-                    im.call("cond")(
+                    im.cond(
                         im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")),
                         im.op_as_fieldop("plus", domain)("y", "scalar"),
                         im.op_as_fieldop("plus", domain)("w", "scalar"),
@@ -376,10 +376,10 @@ def test_gtir_cond_nested():
         declarations=[],
         body=[
             gtir.SetAt(
-                expr=im.call("cond")(
+                expr=im.cond(
                     gtir.SymRef(id="pred_1"),
                     im.op_as_fieldop("plus", domain)("x", 1.0),
-                    im.call("cond")(
+                    im.cond(
                         gtir.SymRef(id="pred_2"),
                         im.op_as_fieldop("plus", domain)("x", 2.0),
                         im.op_as_fieldop("plus", domain)("x", 3.0),
@@ -594,9 +594,7 @@ def test_gtir_cartesian_shift_right():
     sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS)

-    FSYMBOLS_tmp = FSYMBOLS.copy()
-    FSYMBOLS_tmp["__x_offset_stride_0"] = 1
-    sdfg(a, a_offset, b, **FSYMBOLS_tmp)
+    sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_stride_0=1)

     assert np.allclose(a[:-OFFSET] + DELTA, b[OFFSET:])
@@ -905,12 +903,7 @@ def test_gtir_neighbors_as_output():
         declarations=[],
         body=[
             gtir.SetAt(
-                expr=im.call(
-                    im.call("as_fieldop")(
-                        im.lambda_("it")(im.neighbors("V2E", "it")),
-                        vertex_domain,
-                    )
-                )("edges"),
+                expr=im.as_fieldop_neighbors("V2E", "edges", vertex_domain),
                 domain=v2e_domain,
                 target=gtir.SymRef(id="v2e_field"),
             )
@@ -963,14 +956,7 @@ def test_gtir_reduce():
             ),
             vertex_domain,
         )
-    )(
-        im.call(
-            im.call("as_fieldop")(
-                im.lambda_("it")(im.neighbors("V2E", "it")),
-                vertex_domain,
-            )
-        )("edges")
-    )
+    )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain))

     connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"]
     assert isinstance(connectivity_V2E, gtx_common.NeighborTable)
@@ -1038,14 +1024,7 @@ def test_gtir_reduce_with_skip_values():
             ),
             vertex_domain,
         )
-    )(
-        im.call(
-            im.call("as_fieldop")(
-                im.lambda_("it")(im.neighbors("V2E", "it")),
-                vertex_domain,
-            )
-        )("edges")
-    )
+    )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain))
"edges", vertex_domain)) connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) @@ -1129,18 +1108,8 @@ def test_gtir_reduce_dot_product(): ) )( im.op_as_fieldop("multiplies", vertex_domain)( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.neighbors("V2E", "it")), - vertex_domain, - ) - )("edges"), - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.neighbors("V2E", "it")), - vertex_domain, - ) - )("edges"), + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ), ), domain=vertex_domain, @@ -1196,16 +1165,10 @@ def test_gtir_reduce_with_cond_neighbors(): ), vertex_domain, )( - im.call("cond")( + im.cond( gtir.SymRef(id="pred"), - im.as_fieldop( - im.lambda_("it")(im.neighbors("V2E_FULL", "it")), - vertex_domain, - )("edges"), - im.as_fieldop( - im.lambda_("it")(im.neighbors("V2E", "it")), - vertex_domain, - )("edges"), + im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) ), domain=vertex_domain, @@ -1309,6 +1272,73 @@ def test_gtir_let_lambda(): assert np.allclose(b, a * 8) +def test_gtir_let_lambda_with_connectivity(): + C2E_neighbor_idx = 1 + C2V_neighbor_idx = 2 + cell_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Cell.value), 0, "ncells"), + ) + + connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + assert isinstance(connectivity_C2E, gtx_common.NeighborTable) + connectivity_C2V = SIMPLE_MESH_OFFSET_PROVIDER["C2V"] + assert isinstance(connectivity_C2V, gtx_common.NeighborTable) + + testee = gtir.Program( + id="let_lambda_with_connectivity", + function_definitions=[], + params=[ + gtir.Sym(id="cells", type=CFTYPE), + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="ncells", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "x1", + im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("C2E", C2E_neighbor_idx)("it"))), + cell_domain, + )("edges"), + )( + im.let( + "x2", + im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("C2V", C2V_neighbor_idx)("it"))), + cell_domain, + )("vertices"), + )(im.op_as_fieldop("plus", cell_domain)("x1", "x2")) + ), + domain=cell_domain, + target=gtir.SymRef(id="cells"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v = np.random.rand(SIMPLE_MESH.num_vertices) + c = np.empty(SIMPLE_MESH.num_cells) + ref = ( + e[connectivity_C2E.table[:, C2E_neighbor_idx]] + + v[connectivity_C2V.table[:, C2V_neighbor_idx]] + ) + + sdfg( + cells=c, + edges=e, + vertices=v, + connectivity_C2E=connectivity_C2E.table, + connectivity_C2V=connectivity_C2V.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(c, ref) + + def test_gtir_let_lambda_with_cond(): domain = im.call("cartesian_domain")( im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") @@ -1327,7 +1357,7 @@ def test_gtir_let_lambda_with_cond(): gtir.SetAt( expr=im.let("x1", "x")( im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.call("cond")( + im.cond( gtir.SymRef(id="pred"), im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"),