diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py index 0fe776c3ee..6b2a32c063 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -39,31 +39,28 @@ def get_domain_indices( - dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None + dims: Sequence[gtx_common.Dimension], origin: Optional[Sequence[dace.symbolic.SymExpr]] ) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying - an optional offset in each dimension as start index. + an optional origin in each dimension as start index. Args: dims: The field dimensions. - offsets: The range start index in each dimension. + origin: The domain start index in each dimension. If set to `None`, assume all zeros. Returns: A list of indices for field access in dace arrays. As this list is returned as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before being used in memlet subset because ranges are better supported throughout DaCe. 
""" - index_variables = [dace.symbolic.SymExpr(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims] - if offsets is None: - return dace_subsets.Indices(index_variables) - else: - return dace_subsets.Indices( - [ - index - offset if offset != 0 else index - for index, offset in zip(index_variables, offsets, strict=True) - ] - ) + index_variables = [ + dace.symbolic.pystr_to_symbolic(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims + ] + origin = [0] * len(index_variables) if origin is None else origin + return dace_subsets.Indices( + [index - start_index for index, start_index in zip(index_variables, origin, strict=True)] + ) @dataclasses.dataclass(frozen=True) @@ -78,18 +75,58 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. - offset: List of index offsets, in each dimension, when the dimension range - does not start from zero; assume zero offset, if not set. + origin: Tuple of start indices, in each dimension, for `FieldType` data. + Pass an empty tuple for `ScalarType` data or zero-dimensional fields. """ dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType - offset: Optional[list[dace.symbolic.SymExpr]] + origin: tuple[dace.symbolic.SymbolicType, ...] + + def __post_init__(self) -> None: + """Implements a sanity check on the constructed data type.""" + assert ( + len(self.origin) == 0 + if isinstance(self.gt_type, ts.ScalarType) + else len(self.origin) == len(self.gt_type.dims) + ) + + def map_to_parent_sdfg( + self, + sdfg_builder: gtir_sdfg.SDFGBuilder, + inner_sdfg: dace.SDFG, + outer_sdfg: dace.SDFG, + outer_sdfg_state: dace.SDFGState, + symbol_mapping: dict[str, dace.symbolic.SymbolicType], + ) -> FieldopData: + """ + Make the data descriptor which 'self' refers to, and which is located inside + a NestedSDFG, available in its parent SDFG. 
- def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: - """Create a copy of this data descriptor with a different access node.""" - assert data_node != self.dc_node - return FieldopData(data_node, self.gt_type, self.offset) + Thus, it turns 'self' into a non-transient array and creates a new data + descriptor inside the parent SDFG, with same shape and strides. + """ + inner_desc = self.dc_node.desc(inner_sdfg) + assert inner_desc.transient + inner_desc.transient = False + + if isinstance(self.gt_type, ts.ScalarType): + outer, outer_desc = sdfg_builder.add_temp_scalar(outer_sdfg, inner_desc.dtype) + outer_origin = [] + else: + outer, outer_desc = sdfg_builder.add_temp_array_like(outer_sdfg, inner_desc) + # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. + dace.symbolic.safe_replace( + symbol_mapping, + lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + ) + # Same applies to the symbols used as field origin (the domain range start) + outer_origin = [ + gtx_dace_utils.safe_replace_symbolic(val, symbol_mapping) for val in self.origin + ] + + outer_node = outer_sdfg_state.add_access(outer) + return FieldopData(outer_node, self.gt_type, tuple(outer_origin)) def get_local_view( self, domain: FieldopDomain @@ -97,18 +134,20 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) + dc_node=self.dc_node, + gt_dtype=self.gt_type, + subset=dace_subsets.Range.from_string("0"), ) if isinstance(self.gt_type, ts.FieldType): domain_dims = [dim for dim, _, _ in domain] - domain_indices = get_domain_indices(domain_dims) + domain_indices = get_domain_indices(domain_dims, origin=None) it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { dim: gtir_dataflow.SymbolExpr(index, 
INDEX_DTYPE) for dim, index in zip(domain_dims, domain_indices) } - field_domain = [ - (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + field_origin = [ + (dim, dace.symbolic.SymExpr(0) if self.origin is None else self.origin[i]) for i, dim in enumerate(self.gt_type.dims) ] # The property below is ensured by calling `make_field()` to construct `FieldopData`. @@ -116,11 +155,48 @@ def get_local_view( # to `ListType` element type, while the field domain consists of all global dimensions. assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, field_domain, it_indices + self.dc_node, self.gt_type.dtype, field_origin, it_indices ) raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") + def get_symbol_mapping( + self, dataname: str, sdfg: dace.SDFG + ) -> dict[str, dace.symbolic.SymExpr]: + """ + Helper method to create the symbol mapping for array storage in a nested SDFG. + + Args: + dataname: Name of the data container inside the nested SDFG. + sdfg: The parent SDFG where the `FieldopData` object lives. + + Returns: + Mapping from symbols in nested SDFG to the corresponding symbolic values + in the parent SDFG. This includes the range start and stop symbols (used + to calculate the array shape as range 'stop - start') and the strides. 
+ """ + if isinstance(self.gt_type, ts.ScalarType): + return {} + ndims = len(self.gt_type.dims) + outer_desc = self.dc_node.desc(sdfg) + assert isinstance(outer_desc, dace.data.Array) + # origin and size of the local dimension, in case of a field with `ListType` data, + # are assumed to be compile-time values (not symbolic), therefore the start and + # stop range symbols of the inner field only extend over the global dimensions + return ( + {gtx_dace_utils.range_start_symbol(dataname, i): (self.origin[i]) for i in range(ndims)} + | { + gtx_dace_utils.range_stop_symbol(dataname, i): ( + self.origin[i] + outer_desc.shape[i] + ) + for i in range(ndims) + } + | { + gtx_dace_utils.field_stride_symbol_name(dataname, i): stride + for i, stride in enumerate(outer_desc.strides) + } + ) + FieldopDomain: TypeAlias = list[ tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] @@ -141,6 +217,33 @@ def get_local_view( """Data type used for field indexing.""" +def get_arg_symbol_mapping( + dataname: str, arg: FieldopResult, sdfg: dace.SDFG +) -> dict[str, dace.symbolic.SymExpr]: + """ + Helper method to build the mapping from inner to outer SDFG of all symbols + used for storage of a field or a tuple of fields. + + Args: + dataname: The storage name inside the nested SDFG. + arg: The argument field in the parent SDFG. + sdfg: The parent SDFG where the argument field lives. + + Returns: + A mapping from inner symbol names to values or symbolic definitions + in the parent SDFG. + """ + if isinstance(arg, FieldopData): + return arg.get_symbol_mapping(dataname, sdfg) + + symbol_mapping: dict[str, dace.symbolic.SymExpr] = {} + for i, elem in enumerate(arg): + dataname_elem = f"{dataname}_{i}" + symbol_mapping |= get_arg_symbol_mapping(dataname_elem, elem, sdfg) + + return symbol_mapping + + def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: """ Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`. 
@@ -239,7 +342,7 @@ def get_field_layout( Returns: A tuple of three lists containing: - the domain dimensions - - the domain offset in each dimension + - the domain origin, that is the start indices in all dimensions - the domain size in each dimension """ domain_dims, domain_lbs, domain_ubs = zip(*domain) @@ -278,9 +381,9 @@ def _create_field_operator_impl( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) # the memory layout of the output field follows the field operator compute domain - domain_dims, domain_offset, domain_shape = get_field_layout(domain) - domain_indices = get_domain_indices(domain_dims, domain_offset) - domain_subset = dace_subsets.Range.from_indices(domain_indices) + field_dims, field_origin, field_shape = get_field_layout(domain) + field_indices = get_domain_indices(field_dims, field_origin) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): if output_edge.result.gt_dtype != output_type.dtype: @@ -288,8 +391,6 @@ def _create_field_operator_impl( f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." ) assert isinstance(dataflow_output_desc, dace.data.Scalar) - field_shape = domain_shape - field_subset = domain_subset else: assert isinstance(output_type.dtype, ts.ListType) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) @@ -301,8 +402,8 @@ def _create_field_operator_impl( assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. 
`neighbors`) assert output_edge.result.gt_dtype.offset_type is not None - field_shape = [*domain_shape, dataflow_output_desc.shape[0]] - field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) + field_shape = [*field_shape, dataflow_output_desc.shape[0]] + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype) @@ -312,9 +413,7 @@ def _create_field_operator_impl( output_edge.connect(map_exit, field_node, field_subset) return FieldopData( - field_node, - ts.FieldType(domain_dims, output_edge.result.gt_dtype), - offset=(domain_offset if set(domain_offset) != {0} else None), + field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin) ) @@ -535,7 +634,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg_builder.add_temp_array_like(sdfg, inner_desc) outer_node = state.add_access(outer) - return inner_data.make_copy(outer_node) + return FieldopData(outer_node, inner_data.gt_type, inner_data.origin) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -696,7 +795,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type, offset=None) + return FieldopData(data_node, data_type, origin=()) def translate_make_tuple( @@ -818,7 +917,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type, offset=None) + return FieldopData(temp_node, node.type, origin=()) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 59d1a0087a..584ce849e1 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ 
b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -807,6 +807,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) input_memlets: dict[str, MemletExpr | ValueExpr] = {} + nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None # define scalar or symbol for the condition value inside the nested SDFG if isinstance(condition_value, SymbolExpr): @@ -845,12 +846,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + # all free symbols are mapped to the symbols available in parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + if isinstance(condition_value, SymbolExpr): + nsdfg_symbols_mapping["__cond"] = condition_value.value nsdfg_node = self.state.add_nested_sdfg( nsdfg, self.sdfg, inputs=set(input_memlets.keys()), outputs=outputs, - symbol_mapping=None, # implicitly map all free symbols to the symbols available in parent SDFG + symbol_mapping=nsdfg_symbols_mapping, ) for inner, input_expr in input_memlets.items(): @@ -1504,7 +1509,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Range.from_string(f"{origin_index.value}, {offset_expr.value}"), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 56a67510e7..763c292836 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -74,19 +74,16 @@ } -def 
builtin_cast(*args: Any) -> str: - val, target_type = args +def builtin_cast(val: str, target_type: str) -> str: assert target_type in builtins.TYPE_BUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) -def builtin_if(*args: Any) -> str: - cond, true_val, false_val = args +def builtin_if(cond: str, true_val: str, false_val: str) -> str: return f"{true_val} if {cond} else {false_val}" -def builtin_tuple_get(*args: Any) -> str: - index, tuple_name = args +def builtin_tuple_get(index: str, tuple_name: str) -> str: return f"{tuple_name}_{index}" @@ -99,7 +96,7 @@ def make_const_list(arg: str) -> str: return arg -GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { +GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = { "cast_": builtin_cast, "if_": builtin_if, "make_const_list": make_const_list, diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index ec88cd8f84..791440c37a 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -22,7 +22,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional +import itertools +from typing import TYPE_CHECKING, Any, Iterable import dace from dace import subsets as dace_subsets @@ -108,19 +109,19 @@ def _create_scan_field_operator_impl( assert isinstance(dataflow_output_desc, dace.data.Array) # the memory layout of the output field follows the field operator compute domain - domain_dims, domain_offset, domain_shape = gtir_translators.get_field_layout(domain) - domain_indices = gtir_translators.get_domain_indices(domain_dims, domain_offset) - domain_subset = dace_subsets.Range.from_indices(domain_indices) + field_dims, field_origin, field_shape = gtir_translators.get_field_layout(domain) + field_indices = gtir_translators.get_domain_indices(field_dims, 
field_origin) + field_subset = dace_subsets.Range.from_indices(field_indices) # the vertical dimension used as scan column is computed by the `LoopRegion` # inside the map scope, therefore it is excluded from the map range - scan_dim_index = [sdfg_builder.is_column_axis(dim) for dim in domain_dims].index(True) + scan_dim_index = [sdfg_builder.is_column_axis(dim) for dim in field_dims].index(True) # the map scope writes the full-shape dimension corresponding to the scan column field_subset = ( - dace_subsets.Range(domain_subset[:scan_dim_index]) + dace_subsets.Range(field_subset[:scan_dim_index]) + dace_subsets.Range.from_string(f"0:{dataflow_output_desc.shape[0]}") - + dace_subsets.Range(domain_subset[scan_dim_index + 1 :]) + + dace_subsets.Range(field_subset[scan_dim_index + 1 :]) ) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): @@ -130,7 +131,6 @@ def _create_scan_field_operator_impl( f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." ) field_dtype = output_edge.result.gt_dtype - field_shape = domain_shape # the scan field operator computes a column of scalar values assert len(dataflow_output_desc.shape) == 1 else: @@ -146,7 +146,7 @@ def _create_scan_field_operator_impl( assert len(dataflow_output_desc.shape) == 2 # the lines below extend the array with the local dimension added by the field operator assert output_edge.result.gt_dtype.offset_type is not None - field_shape = [*domain_shape, dataflow_output_desc.shape[1]] + field_shape = [*field_shape, dataflow_output_desc.shape[1]] field_subset = field_subset + dace_subsets.Range.from_string( f"0:{dataflow_output_desc.shape[1]}" ) @@ -158,7 +158,7 @@ def _create_scan_field_operator_impl( # the inner and outer strides have to match scan_output_stride = field_desc.strides[scan_dim_index] # also consider the stride of the local dimension, in case the scan field operator computes a list - local_strides = field_desc.strides[len(domain_dims) :] + local_strides = 
field_desc.strides[len(field_dims) :] assert len(local_strides) == (1 if isinstance(output_edge.result.gt_dtype, ts.ListType) else 0) new_inner_strides = [scan_output_stride, *local_strides] dataflow_output_desc.set_shape(dataflow_output_desc.shape, new_inner_strides) @@ -168,9 +168,7 @@ def _create_scan_field_operator_impl( output_edge.connect(map_exit, field_node, field_subset) return gtir_translators.FieldopData( - field_node, - ts.FieldType(domain_dims, output_edge.result.gt_dtype), - offset=(domain_offset if set(domain_offset) != {0} else None), + field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin) ) @@ -271,7 +269,6 @@ def _lower_lambda_to_nested_sdfg( domain: gtir_translators.FieldopDomain, init_data: gtir_translators.FieldopResult, lambda_symbols: dict[str, ts.DataType], - lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], scan_forward: bool, scan_carry_symbol: gtir.Sym, ) -> tuple[dace.SDFG, gtir_translators.FieldopResult]: @@ -297,8 +294,6 @@ def _lower_lambda_to_nested_sdfg( init_data: The data produced in the field operator context that is used to initialize the scan carry value. lambda_symbols: List of symbols used as parameters of the stencil expressions. - lambda_field_offsets: Mapping from symbol name to field origin, - `None` if field origin is 0 in all dimensions. scan_forward: When True, the loop should range starting from the origin; when False, traverse towards origin. scan_carry_symbol: The symbol used in the stencil expression to carry the @@ -317,9 +312,7 @@ def _lower_lambda_to_nested_sdfg( # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. 
nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo) - lambda_translator = sdfg_builder.setup_nested_context( - lambda_node, nsdfg, lambda_symbols, lambda_field_offsets - ) + lambda_translator = sdfg_builder.setup_nested_context(lambda_node, nsdfg, lambda_symbols) # use the vertical dimension in the domain as scan dimension scan_domain = [ @@ -474,7 +467,7 @@ def connect_scan_output( ) output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) - return gtir_translators.FieldopData(output_node, output_type, offset=scan_lower_bound) + return gtir_translators.FieldopData(output_node, output_type, origin=(scan_lower_bound,)) # write the stencil result (value on one vertical level) into a 1D field # with full vertical shape representing one column @@ -603,24 +596,36 @@ def translate_scan( for p, arg_type in zip(stencil_expr.params, lambda_arg_types, strict=True) } + # lower the scan stencil expression in a separate SDFG context + nsdfg, lambda_output = _lower_lambda_to_nested_sdfg( + stencil_expr, + sdfg, + sdfg_builder, + domain, + init_data, + lambda_symbols, + scan_forward, + im.sym(scan_carry, scan_carry_type), + ) + # visit the arguments to be passed to the lambda expression # this must be executed before visiting the lambda expression, in order to populate # the data descriptor with the correct field domain offsets for field arguments lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] - lambda_args_mapping = { - _scan_input_name(scan_carry): init_data, - } | { - str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) - } + lambda_args_mapping = [ + (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), + ] + [ + (im.sym(param.id, arg.gt_type), arg) + for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + ] + + lambda_arg_nodes = dict( + itertools.chain( 
+ *[gtir_translators.flatten_tuples(psym.id, arg) for psym, arg in lambda_args_mapping] + ) + ) - # parse the dataflow input and output symbols - lambda_flat_args: dict[str, gtir_translators.FieldopData] = {} - # the field offset is set to `None` when it is zero in all dimensions - lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} - for param, outer_arg in lambda_args_mapping.items(): - tuple_fields = gtir_translators.flatten_tuples(param, outer_arg) - lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} - lambda_flat_args |= dict(tuple_fields) + # parse the dataflow output symbols if isinstance(scan_carry_type, ts.TupleType): lambda_flat_outs = { str(sym.id): sym.type @@ -631,46 +636,23 @@ def translate_scan( else: lambda_flat_outs = {_scan_output_name(scan_carry): scan_carry_type} - # lower the scan stencil expression in a separate SDFG context - nsdfg, lambda_output = _lower_lambda_to_nested_sdfg( - stencil_expr, - sdfg, - sdfg_builder, - domain, - init_data, - lambda_symbols, - lambda_field_offsets, - scan_forward, - im.sym(scan_carry, scan_carry_type), - ) - # build the mapping of symbols from nested SDFG to field operator context nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} - for inner_dataname, outer_arg in lambda_flat_args.items(): - inner_desc = nsdfg.data(inner_dataname) - outer_desc = outer_arg.dc_node.desc(sdfg) - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*inner_desc.shape, *inner_desc.strides], - [*outer_desc.shape, *outer_desc.strides], - strict=True, - ) - if dace.symbolic.issymbolic(nested_symbol) - } + for psym, arg in lambda_args_mapping: + nsdfg_symbols_mapping |= gtir_translators.get_arg_symbol_mapping(psym.id, arg, sdfg) # the scan nested SDFG is ready: it is instantiated in the field operator context # where the map scope over the horizontal domain lives nsdfg_node = state.add_nested_sdfg( nsdfg, sdfg, 
- inputs=set(lambda_flat_args.keys()), + inputs=set(lambda_arg_nodes.keys()), outputs=set(lambda_flat_outs.keys()), symbol_mapping=nsdfg_symbols_mapping, ) lambda_input_edges = [] - for input_connector, outer_arg in lambda_flat_args.items(): + for input_connector, outer_arg in lambda_arg_nodes.items(): arg_desc = outer_arg.dc_node.desc(sdfg) input_subset = dace_subsets.Range.from_array(arg_desc) input_edge = gtir_dataflow.MemletInputEdge( diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py index b306a59305..a58e8bcf8a 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py @@ -16,7 +16,6 @@ import abc import dataclasses -import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -34,6 +33,7 @@ from gt4py.next.program_processors.runners.dace import ( gtir_builtin_translators, gtir_sdfg_utils, + transformations as gtx_transformations, utils as gtx_dace_utils, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -121,7 +121,9 @@ class SDFGBuilder(DataflowBuilder, Protocol): @abc.abstractmethod def make_field( - self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + self, + data_node: dace.nodes.AccessNode, + data_type: ts.FieldType | ts.ScalarType, ) -> gtir_builtin_translators.FieldopData: """Retrieve the field data descriptor including the domain offset information.""" ... @@ -142,7 +144,6 @@ def setup_nested_context( expr: gtir.Expr, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType], - field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], ) -> SDFGBuilder: """ Create an SDFG context to translate a nested expression, indipendent @@ -156,7 +157,6 @@ def setup_nested_context( expr: The nested expresson to be lowered. 
sdfg: The SDFG where to lower the nested expression. global_symbols: Mapping from symbol name to GTIR data type. - field_offsets: Mapping from symbol name to field origin, `None` if field origin is 0 in all dimensions. Returns: A visitor object implementing the `SDFGBuilder` protocol. @@ -201,6 +201,24 @@ def _collect_symbols_in_domain_expressions( ) +def _make_access_index_for_field( + domain: gtir_builtin_translators.FieldopDomain, data: gtir_builtin_translators.FieldopData +) -> dace.subsets.Range: + """Helper method to build a memlet subset of a field over the given domain.""" + # convert domain expression to dictionary to ease access to the dimensions, + # since the access indices have to follow the order of dimensions in field domain + if isinstance(data.gt_type, ts.FieldType) and len(data.gt_type.dims) != 0: + assert data.origin is not None + domain_ranges = {dim: (lb, ub) for dim, lb, ub in domain} + return dace.subsets.Range( + (domain_ranges[dim][0] - origin, domain_ranges[dim][1] - origin - 1, 1) + for dim, origin in zip(data.gt_type.dims, data.origin, strict=True) + ) + else: + assert len(domain) == 0 + return dace.subsets.Range.from_string("0") + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. 
@@ -217,10 +235,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType column_axis: Optional[gtx_common.Dimension] - global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=dict) - field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( - default_factory=dict - ) + global_symbols: dict[str, ts.DataType] map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -232,18 +247,23 @@ def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderType return self.offset_provider_type[offset] def make_field( - self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + self, + data_node: dace.nodes.AccessNode, + data_type: ts.FieldType | ts.ScalarType, ) -> gtir_builtin_translators.FieldopData: """ - Helper method to build the field data type associated with an access node in the SDFG. + Helper method to build the field data type associated with a data access node. - In case of `ScalarType` data, the descriptor is constructed with `offset=None`. + In case of `ScalarType` data, the `FieldopData` is constructed with an empty origin tuple (`origin=()`). In case of `FieldType` data, the field origin is added to the data descriptor. Besides, if the `FieldType` contains a local dimension, the descriptor is converted to a canonical form where the field domain consists of all global dimensions (the grid axes) and the field data type is `ListType`, with `offset_type` equal to the field local dimension. + TODO(edopao): consider refactoring this method and moving it to a type module + close to the `FieldopData` type declaration. + Args: data_node: The access node to the SDFG data storage. data_type: The GT4Py data descriptor, which can either come from a field parameter @@ -253,8 +273,7 @@ def make_field( The descriptor associated with the SDFG data storage, filled with field origin.
""" if isinstance(data_type, ts.ScalarType): - return gtir_builtin_translators.FieldopData(data_node, data_type, offset=None) - domain_offset = self.field_offsets.get(data_node.data, None) + return gtir_builtin_translators.FieldopData(data_node, data_type, origin=()) local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] if len(local_dims) == 0: # do nothing: the field domain consists of all global dimensions @@ -279,7 +298,11 @@ def make_field( raise NotImplementedError( "Fields with more than one local dimension are not supported." ) - return gtir_builtin_translators.FieldopData(data_node, field_type, domain_offset) + field_origin = tuple( + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.range_start_symbol(data_node.data, axis)) + for axis in range(len(field_type.dims)) + ) + return gtir_builtin_translators.FieldopData(data_node, field_type, field_origin) def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -293,11 +316,8 @@ def setup_nested_context( expr: gtir.Expr, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType], - field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], ) -> SDFGBuilder: - nsdfg_builder = GTIRToSDFG( - self.offset_provider_type, self.column_axis, global_symbols, field_offsets - ) + nsdfg_builder = GTIRToSDFG(self.offset_provider_type, self.column_axis, global_symbols) nsdfg_params = [ gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() ] @@ -321,28 +341,45 @@ def unique_tasklet_name(self, name: str) -> str: def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] - ) -> tuple[list[dace.symbol], list[dace.symbol]]: + ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]: """ Parse field dimensions and allocate symbols for array shape and strides. 
For local dimensions, the size is known at compile-time and therefore the corresponding array shape dimension is set to an integer literal value. + This method is only called for non-transient arrays, which require symbolic + memory layout. The memory layout of transient arrays, used for temporary + fields, is left to the DaCe default (row major, not necessarily the optimal + one) and might be changed during optimization. + Returns: Two lists of symbols, one for the shape and the other for the strides of the array. """ - dc_dtype = gtir_builtin_translators.INDEX_DTYPE neighbor_table_types = gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) - shape = [ - ( - neighbor_table_types[dim.value].max_neighbors - if dim.kind == gtx_common.DimensionKind.LOCAL - else dace.symbol(gtx_dace_utils.field_size_symbol_name(name, i), dc_dtype) - ) - for i, dim in enumerate(dims) - ] + shape = [] + for i, dim in enumerate(dims): + if dim.kind == gtx_common.DimensionKind.LOCAL: + # for local dimension, the size is taken from the associated connectivity type + shape.append(neighbor_table_types[dim.value].max_neighbors) + elif gtx_dace_utils.is_connectivity_identifier(name, self.offset_provider_type): + # we use symbolic size for the global dimension of a connectivity + shape.append( + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_size_symbol_name(name, i)) + ) + else: + # the size of global dimensions for a regular field is the symbolic + # expression of domain range 'stop - start' + shape.append( + dace.symbolic.pystr_to_symbolic( + "{} - {}".format( + gtx_dace_utils.range_stop_symbol(name, i), + gtx_dace_utils.range_start_symbol(name, i), + ) + ) + ) strides = [ - dace.symbol(gtx_dace_utils.field_stride_symbol_name(name, i), dc_dtype) + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_stride_symbol_name(name, i)) for i in range(len(dims)) ] return shape, strides @@ -470,7 +507,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, 
sdfg.make_array_memlet(field.dc_node.data) ) - return field.make_copy(temp_node) + return gtir_builtin_translators.FieldopData(temp_node, field.gt_type, field.origin) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -498,7 +535,6 @@ def _add_sdfg_params( sdfg_args += self._add_storage( sdfg, symbolic_arguments, pname, param.type, transient=False ) - self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables for offset, connectivity_type in gtx_dace_utils.filter_connectivity_types( @@ -532,13 +568,6 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: The temporary data is global, therefore available everywhere in the SDFG but not outside. Then, all statements are translated, one after the other. """ - if node.function_definitions: - raise NotImplementedError("Functions expected to be inlined as lambda calls.") - - # Since program field arguments are passed to the SDFG as full-shape arrays, - # there is no offset that needs to be compensated. 
- assert len(self.field_offsets) == 0 - sdfg = dace.SDFG(node.id) sdfg.debuginfo = gtir_sdfg_utils.debug_info(node) @@ -605,10 +634,8 @@ def visit_SetAt( # in case the statement returns more than one field target_fields = self._visit_expression(stmt.target, sdfg, state, use_temp=False) - # convert domain expression to dictionary to ease access to dimension boundaries - domain = { - dim: (lb, ub) for dim, lb, ub in gtir_builtin_translators.extract_domain(stmt.domain) - } + # visit the domain expression + domain = gtir_builtin_translators.extract_domain(stmt.domain) expr_input_args = { sym_id @@ -626,22 +653,9 @@ def visit_SetAt( target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient - if isinstance(target.gt_type, ts.FieldType): - target_subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims - ) - source_subset = ( - target_subset - if source.offset is None - else ",".join( - f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" - for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) - ) - ) - else: - assert len(domain) == 0 - target_subset = "0" - source_subset = "0" + assert source.gt_type == target.gt_type + source_subset = _make_access_index_for_field(domain, source) + target_subset = _make_access_index_for_field(domain, target) if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -725,15 +739,11 @@ def visit_Lambda( i.e. a lambda parameter with the same name as a symbol in scope, the parameter will shadow the previous symbol during traversal of the lambda expression. 
""" - lambda_args_mapping = [ - (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) - ] - lambda_arg_nodes = dict( itertools.chain( *[ - gtir_builtin_translators.flatten_tuples(pname, arg) - for pname, arg in lambda_args_mapping + gtir_builtin_translators.flatten_tuples(psym.id, arg) + for psym, arg in zip(node.params, args, strict=True) ] ) ) @@ -743,44 +753,16 @@ def visit_Lambda( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: gtir_builtin_translators.get_tuple_type(arg) + psym.id: gtir_builtin_translators.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type - for pname, arg in lambda_args_mapping + for psym, arg in zip(node.params, args, strict=True) } - def get_field_domain_offset( - p_name: str, p_type: ts.DataType - ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: - if isinstance(p_type, ts.FieldType): - if p_name in lambda_arg_nodes: - arg = lambda_arg_nodes[p_name] - assert isinstance(arg, gtir_builtin_translators.FieldopData) - return {p_name: arg.offset} - elif field_domain_offset := self.field_offsets.get(p_name, None): - return {p_name: field_domain_offset} - elif isinstance(p_type, ts.TupleType): - tsyms = gtir_sdfg_utils.flatten_tuple_fields(p_name, p_type) - return functools.reduce( - lambda field_offsets, sym: ( - field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] - ), - tsyms, - {}, - ) - return {} - - # populate mapping from field name to domain offset - lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} - for p_name, p_type in lambda_symbols.items(): - lambda_field_offsets |= get_field_domain_offset(p_name, p_type) - # lower let-statement lambda node as a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=sdfg.debuginfo) - lambda_translator = self.setup_nested_context( - node.expr, 
nsdfg, lambda_symbols, lambda_field_offsets - ) + lambda_translator = self.setup_nested_context(node.expr, nsdfg, lambda_symbols) nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( @@ -800,7 +782,6 @@ def get_field_domain_offset( } input_memlets = {} - nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): if nsdfg_datadesc.transient: continue @@ -809,15 +790,6 @@ def get_field_domain_offset( src_node = lambda_arg_nodes[nsdfg_dataname].dc_node dataname = src_node.data datadesc = src_node.desc(sdfg) - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], - [*datadesc.shape, *datadesc.strides], - strict=True, - ) - if dace.symbolic.issymbolic(nested_symbol) - } else: dataname = nsdfg_dataname datadesc = sdfg.arrays[nsdfg_dataname] @@ -855,6 +827,13 @@ def get_field_domain_offset( if output_data.dc_node.desc(nsdfg).transient } + # map free symbols to parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + for sym, arg in zip(node.params, args, strict=True): + nsdfg_symbols_mapping |= gtir_builtin_translators.get_arg_symbol_mapping( + sym.id, arg, sdfg + ) + nsdfg_node = head_state.add_nested_sdfg( nsdfg, parent=sdfg, @@ -888,33 +867,34 @@ def construct_output_for_nested_sdfg( arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. """ inner_desc = inner_data.dc_node.desc(nsdfg) + inner_dataname = inner_data.dc_node.data if inner_desc.transient: # Transient data nodes only exist within the nested SDFG. In order to return some result data, # the corresponding data container inside the nested SDFG has to be changed to non-transient, # that is externally allocated, as required by the SDFG IR. 
An output edge will write the result # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. - inner_desc.transient = False - outer, outer_desc = self.add_temp_array_like(sdfg, inner_desc) - # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. - dace.symbolic.safe_replace( - nsdfg_symbols_mapping, - lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + outer_data = inner_data.map_to_parent_sdfg( + self, nsdfg, sdfg, head_state, nsdfg_symbols_mapping ) - connector = inner_data.dc_node.data - outer_node = head_state.add_access(outer) head_state.add_edge( - nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) + nsdfg_node, + inner_dataname, + outer_data.dc_node, + None, + sdfg.make_array_memlet(outer_data.dc_node.data), ) - outer_data = inner_data.make_copy(outer_node) - elif inner_data.dc_node.data in lambda_arg_nodes: + elif inner_dataname in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - outer_data = lambda_arg_nodes[inner_data.dc_node.data] + outer_data = lambda_arg_nodes[inner_dataname] else: - outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = inner_data.make_copy(outer_node) + # This must be a symbol captured from the lambda parent scope. + outer_node = head_state.add_access(inner_dataname) + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_type, inner_data.origin + ) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. 
@@ -941,10 +921,50 @@ def visit_SymRef( return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) +def _remove_field_origin_symbols(ir: gtir.Program, sdfg: dace.SDFG) -> None: + """ + Helper function to remove the origin symbols used in program field arguments, + that is only for non-transient data descriptors in the top-level SDFG. + The start symbol of field domain range is set to constant value 0, thus removing + the corresponding free symbol. These values are propagated to all nested SDFGs. + + This function is only used by `build_sdfg_from_gtir()` when the option flag + `disable_field_origin_on_program_arguments` is set to True. + """ + + # collect symbols used as range start for all program arguments + range_start_symbols: dict[str, dace.symbolic.SymExpr] = {} + for p in ir.params: + if isinstance(p.type, ts.TupleType): + psymbols = [ + sym + for sym in gtir_sdfg_utils.flatten_tuple_fields(p.id, p.type) + if isinstance(sym.type, ts.FieldType) + ] + elif isinstance(p.type, ts.FieldType): + psymbols = [p] + else: + psymbols = [] + for psymbol in psymbols: + assert isinstance(psymbol.type, ts.FieldType) + if len(psymbol.type.dims) == 0: + # zero-dimensional field + continue + dataname = str(psymbol.id) + # set all range start symbols to constant value 0 + range_start_symbols |= { + gtx_dace_utils.range_start_symbol(dataname, i): 0 + for i in range(len(psymbol.type.dims)) + } + # we set all range start symbols to 0 in the top-level SDFG and propagate them to nested SDFGs + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, range_start_symbols, validate=True) + + def build_sdfg_from_gtir( ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType, column_axis: Optional[gtx_common.Dimension] = None, + disable_field_origin_on_program_arguments: bool = False, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG.
@@ -956,11 +976,15 @@ def build_sdfg_from_gtir( ir: The GTIR program node to be lowered to SDFG offset_provider_type: The definitions of offset providers used by the program node column_axis: Vertical dimension used for column scan expressions. + disable_field_origin_on_program_arguments: When True, the field range in all dimensions is assumed to start from 0 Returns: An SDFG in the DaCe canonical form (simplified) """ + if ir.function_definitions: + raise NotImplementedError("Functions expected to be inlined as lambda calls.") + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) @@ -970,11 +994,15 @@ def build_sdfg_from_gtir( # Here we find new names for invalid symbols present in the IR. ir = gtir_sdfg_utils.replace_invalid_symbols(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis) + global_symbols = {str(p.id): p.type for p in ir.params if isinstance(p.type, ts.DataType)} + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis, global_symbols) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct dace_sdfg_utils.inline_loop_blocks(sdfg) + if disable_field_origin_on_program_arguments: + _remove_field_origin_symbols(ir, sdfg) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 78016db0a9..a381346a1e 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -85,11 +85,13 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: if not hasattr(self.backend.executor, "step") else self.backend.executor.step, ) # We know which backend we are using, but we don't know if the compile workflow is cached. 
- # TODO(ricoh): switch 'itir_transforms_off=True' because we ran them separately previously + # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. sdfg = dace.SDFG.from_json( - compile_workflow.translation.replace(itir_transforms_off=True)(gtir_stage).source_code + compile_workflow.translation.replace( + disable_itir_transforms=True, disable_field_origin_on_program_arguments=True + )(gtir_stage).source_code ) self.sdfg_closure_cache["arrays"] = sdfg.arrays diff --git a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index 7f221a5a41..09720ddf3c 100644 --- a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Optional import dace import numpy as np @@ -24,32 +24,40 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str) -> Any: +def _convert_arg(arg: Any) -> tuple[Any, Optional[gtx_common.Domain]]: if not isinstance(arg, gtx_common.Field): - return arg + return arg, None if len(arg.domain.dims) == 0: # Pass zero-dimensional fields as scalars. - return arg.as_scalar() - # field domain offsets are not supported - non_zero_offsets = [ - (dim, dim_range) - for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges, strict=True) - if dim_range.start != 0 - ] - if non_zero_offsets: - dim, dim_range = non_zero_offsets[0] - raise RuntimeError( - f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." 
- ) - return arg.ndarray + return arg.as_scalar(), None + return arg.ndarray, arg.domain def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names - return { - sdfg_param: _convert_arg(arg, sdfg_param) - for sdfg_param, arg in zip(sdfg_params, args, strict=True) - } + sdfg_arguments = {} + range_symbols: dict[str, int] = {} + for sdfg_param, arg in zip(sdfg_params, args, strict=True): + sdfg_arg, domain = _convert_arg(arg) + sdfg_arguments[sdfg_param] = sdfg_arg + if domain: + assert gtx_common.Domain.is_finite(domain) + range_symbols |= { + gtx_dace_utils.range_start_symbol(sdfg_param, i): r.start + for i, r in enumerate(domain.ranges) + } + range_symbols |= { + gtx_dace_utils.range_stop_symbol(sdfg_param, i): r.stop + for i, r in enumerate(domain.ranges) + } + # sanity check in case range symbols are passed as explicit program arguments + for range_symbol, value in range_symbols.items(): + if (sdfg_arg := sdfg_arguments.get(range_symbol, None)) is not None: + if sdfg_arg != value: + raise ValueError( + f"Received program argument {range_symbol} with value {sdfg_arg}, expected {value}." 
+ ) + return sdfg_arguments | range_symbols def _ensure_is_on_device( @@ -150,18 +158,16 @@ def get_sdfg_args( dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_field_strides = _get_stride_args(sdfg.arrays, dace_field_args) dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) - dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = _get_stride_args(sdfg.arrays, dace_field_args) dace_conn_strides = _get_stride_args(sdfg.arrays, dace_conn_args) all_args = { **dace_args, **dace_conn_args, - **dace_shapes, **dace_conn_shapes, - **dace_strides, **dace_conn_strides, + **dace_field_strides, } if check_args: diff --git a/src/gt4py/next/program_processors/runners/dace/utils.py b/src/gt4py/next/program_processors/runners/dace/utils.py index cca0c001e7..5fdace73a9 100644 --- a/src/gt4py/next/program_processors/runners/dace/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/utils.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Literal +from typing import Final, Literal, Mapping, Union import dace @@ -73,6 +73,16 @@ def field_stride_symbol_name(field_name: str, axis: int) -> str: return field_symbol_name(field_name, axis, "stride") +def range_start_symbol(field_name: str, axis: int) -> str: + """Format name of start symbol for domain range, as expected by GTIR.""" + return f"__{field_name}_{axis}_range_0" + + +def range_stop_symbol(field_name: str, axis: int) -> str: + """Format name of stop symbol for domain range, as expected by GTIR.""" + return f"__{field_name}_{axis}_range_1" + + def is_field_symbol(name: str) -> bool: return FIELD_SYMBOL_RE.match(name) is not None @@ -90,3 +100,29 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } + + +def safe_replace_symbolic( + 
val: dace.symbolic.SymbolicType, + symbol_mapping: Mapping[ + Union[dace.symbolic.SymbolicType, str], Union[dace.symbolic.SymbolicType, str] + ], +) -> dace.symbolic.SymbolicType: + """ + Replace free symbols in a dace symbolic expression, using `safe_replace()` + in order to avoid clashes in case the new symbol value is also a free symbol + in the original expression. + + Args: + val: The symbolic expression where to apply the replacement. + symbol_mapping: The mapping table for symbol replacement. + + Returns: + A new symbolic expression as result of symbol replacement. + """ + # The list `x` is needed because `subs()` returns a new object and can not handle + # replacement dicts of the form `{'x': 'y', 'y': 'x'}`. + # The utility `safe_replace()` will call `subs()` twice in case of such dicts. + x = [val] + dace.symbolic.safe_replace(symbol_mapping, lambda m, xx=x: xx.append(xx[-1].subs(m))) + return x[-1] diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 96be93de5e..6e1b3a6f32 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -35,7 +35,8 @@ class DaCeTranslator( ): device_type: core_defs.DeviceType auto_optimize: bool - itir_transforms_off: bool = False + disable_itir_transforms: bool = False + disable_field_origin_on_program_arguments: bool = False def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -50,10 +51,13 @@ def generate_sdfg( auto_opt: bool, on_gpu: bool, ) -> dace.SDFG: - if not self.itir_transforms_off: + if not self.disable_itir_transforms: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( - ir, common.offset_provider_to_type(offset_provider), column_axis + ir,
common.offset_provider_to_type(offset_provider), + column_axis, + disable_field_origin_on_program_arguments=self.disable_field_origin_on_program_arguments, ) if auto_opt: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 522250cafc..a96d967430 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -152,7 +152,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), (USES_COMPOSITE_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 8fe0634302..3ba376b08f 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -37,9 +37,6 @@ def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(edopao): add support for range symbols in field domain and re-enable this test - pytest.skip("Requires support for field domain range.") - backend = cartesian_case.backend in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() @@ -62,7 +59,9 @@ def sdfg(): tmp_field, out_field ) - sdfg() + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + sdfg() assert np.allclose( gtx.field_utils.asnumpy(out_field)[2:-2, 2:-2], @@ -85,9 +84,6 @@ def 
test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(edopao): add support for range symbols in field domain and re-enable this test - pytest.skip("Requires support for field domain range.") - allocator, backend = unstructured_case.allocator, unstructured_case.backend if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE): @@ -139,16 +135,18 @@ def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> in # DaCe strides: number of elements to jump return arg.strides[axis] // arg.itemsize - cSDFG( - a, - out, - offset_provider, - rows=3, - cols=2, - connectivity_E2V=e2v, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), - ) + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + cSDFG( + a, + out, + offset_provider, + rows=3, + cols=2, + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), + ) e2v_np = e2v.asnumpy() assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) @@ -160,8 +158,8 @@ def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> in allocator=allocator, ) offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): cSDFG( a, out, diff --git 
a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 4edaf9f85f..e5e2f18608 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -77,13 +77,9 @@ def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 ) -@pytest.mark.skipif(dace is None, reason="DaCe not found") def test_halo_exchange_helper_attrs(unstructured): local_int = gtx.int - # TODO(edopao): add support for range symbols in field domain and re-enable this test - pytest.skip("Requires support for field domain range.") - @gtx.field_operator(backend=unstructured.backend) def testee_op( a: gtx.Field[[Vertex, KDim], gtx.int], diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index e44e92013f..1726956332 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -61,6 +61,7 @@ def hdiff(inp, coeff, out, x, y): set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) +@pytest.mark.uses_lift @pytest.mark.uses_origin def test_hdiff(hdiff_reference, program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py new file mode 100644 index 0000000000..eec68a6486 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py @@ -0,0 +1,21 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. 
+# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Test utility functions of the dace backend module.""" + +import pytest + +dace = pytest.importorskip("dace") + +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils + + +def test_safe_replace_symbolic(): + assert gtx_dace_utils.safe_replace_symbolic( + dace.symbolic.pystr_to_symbolic("x*x + y"), symbol_mapping={"x": "y", "y": "x"} + ) == dace.symbolic.pystr_to_symbolic("y*y + x") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 7431ad2b4a..8ebb240339 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -13,6 +13,7 @@ """ import functools +from typing import Any, Callable import numpy as np import pytest @@ -52,13 +53,13 @@ SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( - __w_size_0=N, + __w_0_range_1=N, __w_stride_0=1, - __x_size_0=N, + __x_0_range_1=N, __x_stride_0=1, - __y_size_0=N, + __y_0_range_1=N, __y_stride_0=1, - __z_size_0=N, + __z_0_range_1=N, __z_stride_0=1, size=N, ) @@ -69,31 +70,39 @@ def make_mesh_symbols(mesh: MeshDescriptor): ncells=mesh.num_cells, nedges=mesh.num_edges, nvertices=mesh.num_vertices, - __cells_size_0=mesh.num_cells, + __cells_0_range_1=mesh.num_cells, __cells_stride_0=1, - __edges_size_0=mesh.num_edges, + __edges_0_range_1=mesh.num_edges, __edges_stride_0=1, - __vertices_size_0=mesh.num_vertices, + __vertices_0_range_1=mesh.num_vertices, __vertices_stride_0=1, - __connectivity_C2E_size_0=mesh.num_cells, + __connectivity_C2E_0_range_1=mesh.num_cells, 
__connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, - __connectivity_C2V_size_0=mesh.num_cells, + __connectivity_C2V_0_range_1=mesh.num_cells, __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, - __connectivity_E2V_size_0=mesh.num_edges, + __connectivity_E2V_0_range_1=mesh.num_edges, __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, - __connectivity_V2E_size_0=mesh.num_vertices, + __connectivity_V2E_0_range_1=mesh.num_vertices, __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) +def build_dace_sdfg( + ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType +) -> Callable[..., Any]: + return dace_backend.build_sdfg_from_gtir( + ir, offset_provider_type, disable_field_origin_on_program_arguments=True + ) + + def test_gtir_broadcast(): val = np.random.rand() domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) @@ -116,7 +125,7 @@ def test_gtir_broadcast(): a = np.empty(N, dtype=np.float64) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, **FSYMBOLS) np.testing.assert_array_equal(a, val) @@ -152,7 +161,7 @@ def test_gtir_cast(): b = a.astype(np.float32) c = np.empty_like(a, dtype=np.bool_) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) np.testing.assert_array_equal(c, True) @@ -180,7 +189,7 @@ def test_gtir_copy_self(): a = 
np.random.rand(N) ref = a.copy() - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, **FSYMBOLS) assert np.allclose(a, ref) @@ -211,7 +220,7 @@ def test_gtir_tuple_swap(): b = np.random.rand(N) ref = (a.copy(), b.copy()) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(a, ref[1]) @@ -250,16 +259,16 @@ def test_gtir_tuple_args(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) x_fields = (a, a, b) tuple_symbols = { - "__x_0_size_0": N, + "__x_0_0_range_1": N, "__x_0_stride_0": 1, - "__x_1_0_size_0": N, + "__x_1_0_0_range_1": N, "__x_1_0_stride_0": 1, - "__x_1_1_size_0": N, + "__x_1_1_0_range_1": N, "__x_1_1_stride_0": 1, } @@ -302,7 +311,7 @@ def test_gtir_tuple_expr(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, a * 2 + b) @@ -356,7 +365,7 @@ def test_gtir_tuple_broadcast_scalar(): c = np.random.rand() d = np.empty(N, dtype=type(a)) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) x_fields = (a, b, c) @@ -387,7 +396,7 @@ def test_gtir_zero_dim_fields(): a = np.asarray(np.random.rand()) b = np.empty(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a.item(), b, **FSYMBOLS) assert np.allclose(a, b) @@ -421,16 +430,16 @@ def test_gtir_tuple_return(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), 
np.empty_like(a), np.empty_like(a)) tuple_symbols = { - "__z_0_0_size_0": N, + "__z_0_0_0_range_1": N, "__z_0_0_stride_0": 1, - "__z_0_1_size_0": N, + "__z_0_1_0_range_1": N, "__z_0_1_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, } @@ -464,7 +473,7 @@ def test_gtir_tuple_target(): b = np.empty_like(a) ref = a.copy() - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(a, ref + 1) @@ -496,7 +505,7 @@ def test_gtir_update(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) a = np.random.rand(N) ref = a - 1.0 @@ -530,7 +539,7 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, (a + b)) @@ -559,7 +568,7 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, (a + a)) @@ -601,7 +610,7 @@ def test_gtir_sum3(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) d = np.empty_like(a) @@ -645,7 +654,7 @@ def test_gtir_cond(): b = np.random.rand(N) c = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) for s1, s2 in [(1, 2), (2, 1)]: d = np.empty_like(a) @@ -687,12 +696,12 @@ def test_gtir_cond_with_tuple_return(): b = np.random.rand(N) c = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) tuple_symbols = { - "__z_0_size_0": N, + "__z_0_0_range_1": N, "__z_0_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, 
"__z_1_stride_0": 1, } @@ -735,7 +744,7 @@ def test_gtir_cond_nested(): a = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) for s1 in [False, True]: for s2 in [False, True]: @@ -841,9 +850,9 @@ def test_gtir_cartesian_shift_left(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_size_0=N, __x_offset_stride_0=1) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_0_range_1=N, __x_offset_stride_0=1) assert np.allclose(a[OFFSET:] + DELTA, b[:-OFFSET]) @@ -936,9 +945,9 @@ def test_gtir_cartesian_shift_right(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_size_0=N, __x_offset_stride_0=1) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_0_range_1=N, __x_offset_stride_0=1) assert np.allclose(a[:-OFFSET] + DELTA, b[OFFSET:]) @@ -1075,7 +1084,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1088,17 +1097,17 @@ def test_gtir_connectivity_shift(): connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __ce_field_size_0=SIMPLE_MESH.num_cells, + __ce_field_0_range_1=SIMPLE_MESH.num_cells, __ce_field_size_1=SIMPLE_MESH.num_edges, __ce_field_stride_0=SIMPLE_MESH.num_edges, __ce_field_stride_1=1, - __ev_field_size_0=SIMPLE_MESH.num_edges, + __ev_field_0_range_1=SIMPLE_MESH.num_edges, __ev_field_size_1=SIMPLE_MESH.num_vertices, __ev_field_stride_0=SIMPLE_MESH.num_vertices, __ev_field_stride_1=1, - __c2e_offset_size_0=SIMPLE_MESH.num_cells, + __c2e_offset_0_range_1=SIMPLE_MESH.num_cells, __c2e_offset_stride_0=1, - 
__e2v_offset_size_0=SIMPLE_MESH.num_edges, + __e2v_offset_0_range_1=SIMPLE_MESH.num_edges, __e2v_offset_stride_0=1, ) assert np.allclose(ce, ref) @@ -1136,7 +1145,7 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) @@ -1158,7 +1167,7 @@ def test_gtir_connectivity_shift_chain(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __edges_out_size_0=SIMPLE_MESH.num_edges, + __edges_out_0_range_1=SIMPLE_MESH.num_edges, __edges_out_stride_0=1, ) assert np.allclose(e_out, ref) @@ -1196,7 +1205,7 @@ def test_gtir_neighbors_as_input(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) @@ -1217,7 +1226,7 @@ def test_gtir_neighbors_as_input(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_0_range_1=SIMPLE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.shape[1], __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, @@ -1254,7 +1263,7 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) @@ -1268,7 +1277,7 @@ def test_gtir_neighbors_as_output(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - 
__v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_0_range_1=SIMPLE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.max_neighbors, __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, @@ -1317,7 +1326,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1377,7 +1386,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1446,7 +1455,7 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( v2e_field, @@ -1454,7 +1463,7 @@ def test_gtir_reduce_dot_product(): v, connectivity_V2E=connectivity_V2E.ndarray, **make_mesh_symbols(SKIP_VALUE_MESH), - __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_0_range_1=SKIP_VALUE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.shape[1], __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, @@ -1503,7 +1512,7 @@ def test_gtir_reduce_with_cond_neighbors(): e = np.random.rand(SKIP_VALUE_MESH.num_edges) for use_sparse in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1531,7 +1540,7 @@ def test_gtir_reduce_with_cond_neighbors(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), - 
__v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_0_range_1=SKIP_VALUE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.shape[1], __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, @@ -1618,7 +1627,7 @@ def test_gtir_symbolic_domain(): b = np.random.rand(N) ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, ref) @@ -1666,7 +1675,7 @@ def test_gtir_let_lambda(): b = np.random.rand(N) ref = np.concatenate((b[0:1], a[1 : N - 1] * 8, b[N - 1 : N])) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, ref) @@ -1701,7 +1710,7 @@ def test_gtir_let_lambda_scalar_expression(): c = np.random.rand(N) d = np.empty_like(c) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, c, d, **FSYMBOLS) assert np.allclose(d, (a * a * b * b * c)) @@ -1750,7 +1759,7 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) @@ -1797,7 +1806,7 @@ def test_gtir_let_lambda_with_cond(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) a = np.random.rand(N) for s in [False, True]: @@ -1835,16 +1844,16 @@ def test_gtir_let_lambda_with_tuple1(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) b_ref = 
np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) tuple_symbols = { - "__z_0_size_0": N, + "__z_0_0_range_1": N, "__z_0_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, } @@ -1884,16 +1893,16 @@ def test_gtir_let_lambda_with_tuple2(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) tuple_symbols = { - "__z_0_size_0": N, + "__z_0_0_range_1": N, "__z_0_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, - "__z_2_size_0": N, + "__z_2_0_range_1": N, "__z_2_stride_0": 1, } @@ -1947,14 +1956,14 @@ def test_gtir_if_scalars(): d1 = np.random.randint(0, 1000) d2 = np.random.randint(0, 1000) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) tuple_symbols = { - "__x_0_size_0": N, + "__x_0_0_range_1": N, "__x_0_stride_0": 1, - "__x_1_0_size_0": N, + "__x_1_0_0_range_1": N, "__x_1_0_stride_0": 1, - "__x_1_1_size_0": N, + "__x_1_1_0_range_1": N, "__x_1_1_stride_0": 1, } @@ -1990,7 +1999,7 @@ def test_gtir_if_values(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) @@ -2032,7 +2041,7 @@ def test_gtir_index(): # we need to run domain inference in order to add the domain annex information to the index node. testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) ref = np.concatenate( (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :])