Skip to content

Commit

Permalink
feat[next][dace]: support for field origin in lowering to SDFG (#1818)
Browse files Browse the repository at this point in the history
This PR adds support for GT4Py field arguments with non-zero start
index, for example:

`inp = constructors.empty(common.domain({IDim: (1, 9)}), ...)`

which was supported in baseline only for temporary fields, by means of a
data structure called `field_offsets`. This data structure is removed
for two reasons:
1. the name "offset" is a left-over from a previous design based on dace
array offset
2. offset has a different meaning in GT4Py

We introduce the GT4Py concept of field origin and use it for both
temporary fields and program arguments. The field origin corresponds to
the start of the field domain range.

This PR also changes the symbolic definition of array shape. Before, the
array shape was defined as `[data_size_0, data_size_1, ...]`, now the
size corresponds to the range extent `stop - start` as `[(data_0_range_1
- data_0_range_0), (data_1_range_1 - data_1_range_0), ...]`.

The translation stage of the dace workflow is extended with an option
`disable_field_origin_on_program_arguments` to set the field range start
symbols to constant value zero. This is needed for the dace
orchestration, because the signature of a dace-orchestrated program does
not provide the domain origin.
  • Loading branch information
edopao authored Jan 30, 2025
1 parent 050d3b3 commit f0c67e6
Show file tree
Hide file tree
Showing 15 changed files with 540 additions and 357 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,28 @@


def get_domain_indices(
dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None
dims: Sequence[gtx_common.Dimension], origin: Optional[Sequence[dace.symbolic.SymExpr]]
) -> dace_subsets.Indices:
"""
Helper function to construct the list of indices for a field domain, applying
an optional offset in each dimension as start index.
an optional origin in each dimension as start index.
Args:
dims: The field dimensions.
offsets: The range start index in each dimension.
origin: The domain start index in each dimension. If set to `None`, assume all zeros.
Returns:
A list of indices for field access in dace arrays. As this list is returned
as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before
being used in memlet subset because ranges are better supported throughout DaCe.
"""
index_variables = [dace.symbolic.SymExpr(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims]
if offsets is None:
return dace_subsets.Indices(index_variables)
else:
return dace_subsets.Indices(
[
index - offset if offset != 0 else index
for index, offset in zip(index_variables, offsets, strict=True)
]
)
index_variables = [
dace.symbolic.pystr_to_symbolic(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims
]
origin = [0] * len(index_variables) if origin is None else origin
return dace_subsets.Indices(
[index - start_index for index, start_index in zip(index_variables, origin, strict=True)]
)


@dataclasses.dataclass(frozen=True)
Expand All @@ -78,49 +75,128 @@ class FieldopData:
Args:
dc_node: DaCe access node to the data storage.
gt_type: GT4Py type definition, which includes the field domain information.
offset: List of index offsets, in each dimension, when the dimension range
does not start from zero; assume zero offset, if not set.
origin: Tuple of start indices, in each dimension, for `FieldType` data.
Pass an empty tuple for `ScalarType` data or zero-dimensional fields.
"""

dc_node: dace.nodes.AccessNode
gt_type: ts.FieldType | ts.ScalarType
offset: Optional[list[dace.symbolic.SymExpr]]
origin: tuple[dace.symbolic.SymbolicType, ...]

def __post_init__(self) -> None:
"""Implements a sanity check on the constructed data type."""
assert (
len(self.origin) == 0
if isinstance(self.gt_type, ts.ScalarType)
else len(self.origin) == len(self.gt_type.dims)
)

def map_to_parent_sdfg(
self,
sdfg_builder: gtir_sdfg.SDFGBuilder,
inner_sdfg: dace.SDFG,
outer_sdfg: dace.SDFG,
outer_sdfg_state: dace.SDFGState,
symbol_mapping: dict[str, dace.symbolic.SymbolicType],
) -> FieldopData:
"""
Make the data descriptor which 'self' refers to, and which is located inside
a NestedSDFG, available in its parent SDFG.
def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData:
"""Create a copy of this data descriptor with a different access node."""
assert data_node != self.dc_node
return FieldopData(data_node, self.gt_type, self.offset)
Thus, it turns 'self' into a non-transient array and creates a new data
descriptor inside the parent SDFG, with same shape and strides.
"""
inner_desc = self.dc_node.desc(inner_sdfg)
assert inner_desc.transient
inner_desc.transient = False

if isinstance(self.gt_type, ts.ScalarType):
outer, outer_desc = sdfg_builder.add_temp_scalar(outer_sdfg, inner_desc.dtype)
outer_origin = []
else:
outer, outer_desc = sdfg_builder.add_temp_array_like(outer_sdfg, inner_desc)
# We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping.
dace.symbolic.safe_replace(
symbol_mapping,
lambda m: dace.sdfg.replace_properties_dict(outer_desc, m),
)
# Same applies to the symbols used as field origin (the domain range start)
outer_origin = [
gtx_dace_utils.safe_replace_symbolic(val, symbol_mapping) for val in self.origin
]

outer_node = outer_sdfg_state.add_access(outer)
return FieldopData(outer_node, self.gt_type, tuple(outer_origin))

def get_local_view(
self, domain: FieldopDomain
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
"""Helper method to access a field in local view, given the compute domain of a field operator."""
if isinstance(self.gt_type, ts.ScalarType):
return gtir_dataflow.MemletExpr(
dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0])
dc_node=self.dc_node,
gt_dtype=self.gt_type,
subset=dace_subsets.Range.from_string("0"),
)

if isinstance(self.gt_type, ts.FieldType):
domain_dims = [dim for dim, _, _ in domain]
domain_indices = get_domain_indices(domain_dims)
domain_indices = get_domain_indices(domain_dims, origin=None)
it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = {
dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE)
for dim, index in zip(domain_dims, domain_indices)
}
field_domain = [
(dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i])
field_origin = [
(dim, dace.symbolic.SymExpr(0) if self.origin is None else self.origin[i])
for i, dim in enumerate(self.gt_type.dims)
]
# The property below is ensured by calling `make_field()` to construct `FieldopData`.
# The `make_field` constructor ensures that any local dimension, if present, is converted
# to `ListType` element type, while the field domain consists of all global dimensions.
assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims)
return gtir_dataflow.IteratorExpr(
self.dc_node, self.gt_type.dtype, field_domain, it_indices
self.dc_node, self.gt_type.dtype, field_origin, it_indices
)

raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.")

def get_symbol_mapping(
self, dataname: str, sdfg: dace.SDFG
) -> dict[str, dace.symbolic.SymExpr]:
"""
Helper method to create the symbol mapping for array storage in a nested SDFG.
Args:
dataname: Name of the data container insiode the nested SDFG.
sdfg: The parent SDFG where the `FieldopData` object lives.
Returns:
Mapping from symbols in nested SDFG to the corresponding symbolic values
in the parent SDFG. This includes the range start and stop symbols (used
to calculate the array shape as range 'stop - start') and the strides.
"""
if isinstance(self.gt_type, ts.ScalarType):
return {}
ndims = len(self.gt_type.dims)
outer_desc = self.dc_node.desc(sdfg)
assert isinstance(outer_desc, dace.data.Array)
# origin and size of the local dimension, in case of a field with `ListType` data,
# are assumed to be compiled-time values (not symbolic), therefore the start and
# stop range symbols of the inner field only extend over the global dimensions
return (
{gtx_dace_utils.range_start_symbol(dataname, i): (self.origin[i]) for i in range(ndims)}
| {
gtx_dace_utils.range_stop_symbol(dataname, i): (
self.origin[i] + outer_desc.shape[i]
)
for i in range(ndims)
}
| {
gtx_dace_utils.field_stride_symbol_name(dataname, i): stride
for i, stride in enumerate(outer_desc.strides)
}
)


FieldopDomain: TypeAlias = list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
Expand All @@ -141,6 +217,33 @@ def get_local_view(
"""Data type used for field indexing."""


def get_arg_symbol_mapping(
dataname: str, arg: FieldopResult, sdfg: dace.SDFG
) -> dict[str, dace.symbolic.SymExpr]:
"""
Helper method to build the mapping from inner to outer SDFG of all symbols
used for storage of a field or a tuple of fields.
Args:
dataname: The storage name inside the nested SDFG.
arg: The argument field in the parent SDFG.
sdfg: The parent SDFG where the argument field lives.
Returns:
A mapping from inner symbol names to values or symbolic definitions
in the parent SDFG.
"""
if isinstance(arg, FieldopData):
return arg.get_symbol_mapping(dataname, sdfg)

symbol_mapping: dict[str, dace.symbolic.SymExpr] = {}
for i, elem in enumerate(arg):
dataname_elem = f"{dataname}_{i}"
symbol_mapping |= get_arg_symbol_mapping(dataname_elem, elem, sdfg)

return symbol_mapping


def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType:
"""
Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`.
Expand Down Expand Up @@ -239,7 +342,7 @@ def get_field_layout(
Returns:
A tuple of three lists containing:
- the domain dimensions
- the domain offset in each dimension
- the domain origin, that is the start indices in all dimensions
- the domain size in each dimension
"""
domain_dims, domain_lbs, domain_ubs = zip(*domain)
Expand Down Expand Up @@ -278,18 +381,16 @@ def _create_field_operator_impl(
dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)

# the memory layout of the output field follows the field operator compute domain
domain_dims, domain_offset, domain_shape = get_field_layout(domain)
domain_indices = get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)
field_dims, field_origin, field_shape = get_field_layout(domain)
field_indices = get_domain_indices(field_dims, field_origin)
field_subset = dace_subsets.Range.from_indices(field_indices)

if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
if output_edge.result.gt_dtype != output_type.dtype:
raise TypeError(
f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_shape = domain_shape
field_subset = domain_subset
else:
assert isinstance(output_type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
Expand All @@ -301,8 +402,8 @@ def _create_field_operator_impl(
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)
field_shape = [*field_shape, dataflow_output_desc.shape[0]]
field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype)
Expand All @@ -312,9 +413,7 @@ def _create_field_operator_impl(
output_edge.connect(map_exit, field_node, field_subset)

return FieldopData(
field_node,
ts.FieldType(domain_dims, output_edge.result.gt_dtype),
offset=(domain_offset if set(domain_offset) != {0} else None),
field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin)
)


Expand Down Expand Up @@ -535,7 +634,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData:
outer, _ = sdfg_builder.add_temp_array_like(sdfg, inner_desc)
outer_node = state.add_access(outer)

return inner_data.make_copy(outer_node)
return FieldopData(outer_node, inner_data.gt_type, inner_data.origin)

result_temps = gtx_utils.tree_map(construct_output)(true_br_args)

Expand Down Expand Up @@ -696,7 +795,7 @@ def translate_literal(
data_type = node.type
data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type)

return FieldopData(data_node, data_type, offset=None)
return FieldopData(data_node, data_type, origin=())


def translate_make_tuple(
Expand Down Expand Up @@ -818,7 +917,7 @@ def translate_scalar_expr(
dace.Memlet(data=temp_name, subset="0"),
)

return FieldopData(temp_node, node.type, offset=None)
return FieldopData(temp_node, node.type, origin=())


def translate_symbol_ref(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)"))

input_memlets: dict[str, MemletExpr | ValueExpr] = {}
nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None

# define scalar or symbol for the condition value inside the nested SDFG
if isinstance(condition_value, SymbolExpr):
Expand Down Expand Up @@ -845,12 +846,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp

outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))}

# all free symbols are mapped to the symbols available in parent SDFG
nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols}
if isinstance(condition_value, SymbolExpr):
nsdfg_symbols_mapping["__cond"] = condition_value.value
nsdfg_node = self.state.add_nested_sdfg(
nsdfg,
self.sdfg,
inputs=set(input_memlets.keys()),
outputs=outputs,
symbol_mapping=None, # implicitly map all free symbols to the symbols available in parent SDFG
symbol_mapping=nsdfg_symbols_mapping,
)

for inner, input_expr in input_memlets.items():
Expand Down Expand Up @@ -1504,7 +1509,7 @@ def _make_unstructured_shift(
shifted_indices[neighbor_dim] = MemletExpr(
dc_node=offset_table_node,
gt_dtype=it.gt_dtype,
subset=dace_subsets.Indices([origin_index.value, offset_expr.value]),
subset=dace_subsets.Range.from_string(f"{origin_index.value}, {offset_expr.value}"),
)
else:
# dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,16 @@
}


def builtin_cast(*args: Any) -> str:
val, target_type = args
def builtin_cast(val: str, target_type: str) -> str:
assert target_type in builtins.TYPE_BUILTINS
return MATH_BUILTINS_MAPPING[target_type].format(val)


def builtin_if(*args: Any) -> str:
cond, true_val, false_val = args
def builtin_if(cond: str, true_val: str, false_val: str) -> str:
return f"{true_val} if {cond} else {false_val}"


def builtin_tuple_get(*args: Any) -> str:
index, tuple_name = args
def builtin_tuple_get(index: str, tuple_name: str) -> str:
return f"{tuple_name}_{index}"


Expand All @@ -99,7 +96,7 @@ def make_const_list(arg: str) -> str:
return arg


GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = {
GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = {
"cast_": builtin_cast,
"if_": builtin_if,
"make_const_list": make_const_list,
Expand Down
Loading

0 comments on commit f0c67e6

Please sign in to comment.