Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Embedded support for skip value connectivities #1441

Merged
merged 24 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ class ConnectivityKind(enum.Flag):


@extended_runtime_checkable
# type: ignore[misc] # DimT should be covariant, but break in another place
# type: ignore[misc] # DimT should be covariant, but breaks in another place
class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]):
@property
@abc.abstractmethod
Expand All @@ -749,6 +749,10 @@ def kind(self) -> ConnectivityKind:
@abc.abstractmethod
def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ...

@property
@abc.abstractmethod
def skip_value(self) -> Optional[core_defs.IntegralScalar]: ...

# Operators
def __abs__(self) -> Never:
raise TypeError("'ConnectivityField' does not support this operation.")
Expand Down Expand Up @@ -840,6 +844,7 @@ def _connectivity(
*,
domain: Optional[DomainLike] = None,
dtype: Optional[core_defs.DType] = None,
skip_value: Optional[core_defs.IntegralScalar] = None,
) -> ConnectivityField:
raise NotImplementedError

Expand Down Expand Up @@ -918,6 +923,10 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]:
def codomain(self) -> DimT:
return self.dimension

@property
def skip_value(self) -> None:
return None

@functools.cached_property
def kind(self) -> ConnectivityKind:
return ConnectivityKind(0)
Expand Down Expand Up @@ -1083,4 +1092,4 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call
#: Numeric value used to represent missing values in connectivities.
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
SKIP_VALUE: Final[int] = -1
_DEFAULT_SKIP_VALUE: Final[int] = -1
6 changes: 5 additions & 1 deletion src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def as_connectivity(
*,
allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None,
device: Optional[core_defs.Device] = None,
skip_value: Optional[core_defs.IntegralScalar] = None,
# copy=False, TODO
) -> common.ConnectivityField:
"""
Expand All @@ -330,6 +331,9 @@ def as_connectivity(
Raises:
ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape.
"""
assert (
skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE
) # TODO(havogt): not yet configurable
if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain):
domain = cast(Sequence[common.Dimension], domain)
if len(domain) != data.ndim:
Expand Down Expand Up @@ -359,7 +363,7 @@ def as_connectivity(
# TODO(havogt): consider adding MutableNDArrayObject
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index]
connectivity_field = common._connectivity(
buffer.ndarray, codomain=codomain, domain=actual_domain
buffer.ndarray, codomain=codomain, domain=actual_domain, skip_value=skip_value
)
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)

Expand Down
151 changes: 102 additions & 49 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import embedded as itir_embedded


try:
Expand Down Expand Up @@ -170,9 +171,12 @@ def remap(
if not common.is_connectivity_field(connectivity):
assert isinstance(connectivity, fbuiltins.FieldOffset)
connectivity = connectivity.as_connectivity_field()

assert common.is_connectivity_field(connectivity)

# Current implementation relies on skip_value == -1:
# if we assume the indexed array has at least one element, we wrap around without out of bounds
assert connectivity.skip_value is None or connectivity.skip_value == -1

# Compute the new domain
dim = connectivity.codomain
dim_idx = self.domain.dim_index(dim)
Expand Down Expand Up @@ -315,6 +319,7 @@ class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__
NdArrayField[common.DimsT, core_defs.IntegralScalar],
):
_codomain: common.DimT
_skip_value: Optional[core_defs.IntegralScalar]
Comment on lines 321 to +322
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope of this PR but just to not forget: it doesn't make sense to use a property for plain attribute access just because a property was defined in a parent class. We could add a simple utility to eve to disable parent properties by overwriting them with a non-data descriptor like this:

def disable_property(default=...) -> functools.cached_property:
    return functools.cached_property(lambda _: default)

then this should work:

Suggested change
_codomain: common.DimT
_skip_value: Optional[core_defs.IntegralScalar]
codomain: common.DimT = disable_property()
skip_value: Optional[core_defs.IntegralScalar] = disable_property()

Inside a dataclass, it would also work to define a field with a default like this:

    codomain: common.DimT = dataclasses.field(default=...)

although this would change the semantics by making codomain optional, where disable_property wouldn't affect the behavior of the dataclass at all.


@functools.cached_property
def _cache(self) -> dict:
Expand All @@ -329,6 +334,10 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig
def codomain(self) -> common.DimT:
return self._codomain

@property
def skip_value(self) -> Optional[core_defs.IntegralScalar]:
return self._skip_value

@functools.cached_property
def kind(self) -> common.ConnectivityKind:
kind = common.ConnectivityKind.MODIFY_STRUCTURE
Expand All @@ -349,6 +358,7 @@ def from_array( # type: ignore[override]
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
skip_value: Optional[core_defs.IntegralScalar] = None,
) -> NdArrayConnectivityField:
domain = common.domain(domain)
xp = cls.array_ns
Expand All @@ -367,7 +377,12 @@ def from_array( # type: ignore[override]

assert isinstance(codomain, common.Dimension)

return cls(domain, array, codomain)
return cls(
domain,
array,
codomain,
_skip_value=skip_value,
)

def inverse_image(
self, image_range: common.UnitRange | common.NamedRange
Expand All @@ -390,47 +405,16 @@ def inverse_image(
assert isinstance(image_range, common.UnitRange)

assert common.UnitRange.is_finite(image_range)
restricted_mask = (self._ndarray >= image_range.start) & (
self._ndarray < image_range.stop
)
# indices of non-zero elements in each dimension
nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask)

new_dims = []
non_contiguous_dims = []

for i, dim_nnz_indices in enumerate(nnz):
# Check if the indices are contiguous
first_data_index = dim_nnz_indices[0]
assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES)
last_data_index = dim_nnz_indices[-1]
assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES)
indices, counts = xp.unique(dim_nnz_indices, return_counts=True)
dim_range = self._domain[i]

if len(xp.unique(counts)) == 1 and (
len(indices) == last_data_index - first_data_index + 1
):
idx_offset = dim_range[1].start
start = idx_offset + first_data_index
assert common.is_int_index(start)
stop = idx_offset + last_data_index + 1
assert common.is_int_index(stop)
new_dims.append(
common.named_range(
(
dim_range[0],
(start, stop),
)
)
)
else:
non_contiguous_dims.append(dim_range[0])

if non_contiguous_dims:
raise ValueError(
f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'."
)
relative_ranges = _hypercube(self._ndarray, image_range, xp, self.skip_value)

if relative_ranges is None:
raise ValueError("Restriction generates non-contiguous dimensions.")

new_dims = [
common.named_range((d, rr + ar.start))
for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges)
]

self._cache[cache_key] = new_dims

Expand All @@ -444,14 +428,49 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field:
xp = cls.array_ns
new_domain, buffer_slice = self._slice(index)
new_buffer = xp.asarray(self.ndarray[buffer_slice])
restricted_connectivity = cls(new_domain, new_buffer, self.codomain)
restricted_connectivity = cls(new_domain, new_buffer, self.codomain, self.skip_value)
self._cache[cache_key] = restricted_connectivity

return restricted_connectivity

__getitem__ = restrict


def _hypercube(
havogt marked this conversation as resolved.
Show resolved Hide resolved
index_array: core_defs.NDArrayObject,
image_range: common.UnitRange,
xp: ModuleType,
skip_value: Optional[core_defs.IntegralScalar] = None,
) -> Optional[list[common.UnitRange]]:
"""
Return the hypercube that contains all indices in `index_array` that are within `image_range`, or `None` if no such hypercube exists.

If `skip_value` is given, the selected values are ignored. It returns the smallest hypercube.
A bigger hypercube could be constructed by adding lines that contain only `skip_value`s.
Example:
index_array = 0 1 -1
3 4 -1
-1 -1 -1
skip_value = -1
would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3].
"""
select_mask = (index_array >= image_range.start) & (index_array < image_range.stop)

nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select_mask)

slices = tuple(
slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz
)
hcube = select_mask[tuple(slices)]
if skip_value is not None:
ignore_mask = index_array == skip_value
hcube |= ignore_mask[tuple(slices)]
if not xp.all(hcube):
return None

return [common.UnitRange(s.start, s.stop) for s in slices]


# -- Specialized implementations for builtin operations on array fields --

NdArrayField.register_builtin_func(
Expand Down Expand Up @@ -483,31 +502,65 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field:
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[
def _make_reduction(
builtin_name: str, array_builtin_name: str, initial_value_op: Callable
) -> Callable[
...,
NdArrayField[common.DimsT, core_defs.ScalarT],
]:
def _builtin_op(
field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension
) -> NdArrayField[common.DimsT, core_defs.ScalarT]:
xp = field.array_ns

if not axis.kind == common.DimensionKind.LOCAL:
raise ValueError("Can only reduce local dimensions.")
if axis not in field.domain.dims:
raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.")
if len([d for d in field.domain.dims if d.kind is common.DimensionKind.LOCAL]) > 1:
raise NotImplementedError(
"Reducing a field with more than one local dimension is not supported."
)
reduce_dim_index = field.domain.dims.index(axis)
current_offset_provider = embedded_context.offset_provider.get(None)
assert current_offset_provider is not None
offset_definition = current_offset_provider[
axis.value
] # assumes offset and local dimension have same name
egparedes marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider)
new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis])

broadcast_slice = tuple(
slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis
for d in field.domain.dims
)
masked_array = xp.where(
xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE,
field.ndarray,
initial_value_op(field),
)

return field.__class__.from_array(
getattr(field.array_ns, array_builtin_name)(field.ndarray, axis=reduce_dim_index),
getattr(xp, array_builtin_name)(
masked_array,
axis=reduce_dim_index,
),
domain=new_domain,
)

_builtin_op.__name__ = builtin_name
return _builtin_op


NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum"))
NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max"))
NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min"))
NdArrayField.register_builtin_func(
fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum", lambda x: x.dtype.scalar_type(0))
)
NdArrayField.register_builtin_func(
fbuiltins.max_over, _make_reduction("max_over", "max", lambda x: x.array_ns.min(x._ndarray))
)
NdArrayField.register_builtin_func(
fbuiltins.min_over, _make_reduction("min_over", "min", lambda x: x.array_ns.max(x._ndarray))
)


# -- Concrete array implementations --
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def __post_init__(self):
f"The following closure variables are undefined: {', '.join(undefined_symbols)}."
)

@property
egparedes marked this conversation as resolved.
Show resolved Hide resolved
def __name__(self) -> str:
return self.definition.__name__

@functools.cached_property
def __gt_allocator__(
self,
Expand Down Expand Up @@ -603,6 +607,10 @@ def from_function(
operator_attributes=operator_attributes,
)

@property
egparedes marked this conversation as resolved.
Show resolved Hide resolved
def __name__(self) -> str:
return self.definition.__name__

def __gt_type__(self) -> ts.CallableType:
type_ = self.foast_node.type
assert isinstance(type_, ts.CallableType)
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,14 @@ def as_connectivity_field(self):
if common.is_connectivity_field(offset_definition):
connectivity = offset_definition
elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider):
assert not offset_definition.has_skip_values
connectivity = gtx.as_connectivity(
domain=self.target,
codomain=self.source,
data=offset_definition.table,
dtype=offset_definition.index_type,
skip_value=(
common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None
),
)
else:
raise NotImplementedError()
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def _construct_itir_domain_arg(
node_domain: Optional[past.Expr],
slices: Optional[list[past.Slice]] = None,
) -> itir.FunCall:
domain_args = []

assert isinstance(out_field.type, ts.TypeSpec)
out_field_types = type_info.primitive_constituents(out_field.type).to_list()
Expand All @@ -246,6 +245,8 @@ def _construct_itir_domain_arg(
" caught in type deduction already."
)

domain_args = []
domain_args_kind = []
for dim_i, dim in enumerate(out_dims):
# an expression for the size of a dimension
dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i))
Expand All @@ -271,11 +272,17 @@ def _construct_itir_domain_arg(
args=[itir.AxisLiteral(value=dim.value), lower, upper],
)
)
domain_args_kind.append(dim.kind)

if self.grid_type == GridType.CARTESIAN:
domain_builtin = "cartesian_domain"
elif self.grid_type == GridType.UNSTRUCTURED:
domain_builtin = "unstructured_domain"
# for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical)
if domain_args_kind[0] == DimensionKind.VERTICAL:
assert len(domain_args) == 2
assert domain_args_kind[1] == DimensionKind.HORIZONTAL
domain_args[0], domain_args[1] = domain_args[1], domain_args[0]
else:
raise AssertionError()

Expand Down
Loading
Loading