From e344afd2508afd727bca9075d3265d63995b7a33 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:34:19 +0100 Subject: [PATCH] refactor[next]: NamedRange/NamedIndex tuple to NamedTuple (#1490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change NamedRange and NamedIndex from being a plain tuple to a NamedTuple for cleaner element access. --------- Co-authored-by: Hannes Vogt Co-authored-by: Enrique González Paredes --- src/gt4py/next/common.py | 122 +++++++++--------- src/gt4py/next/embedded/common.py | 42 +++--- src/gt4py/next/embedded/nd_array_field.py | 34 ++--- src/gt4py/next/embedded/operators.py | 14 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/embedded.py | 27 ++-- .../unit_tests/embedded_tests/test_common.py | 38 +++--- .../embedded_tests/test_nd_array_field.py | 59 +++++---- .../iterator_tests/test_embedded_internals.py | 4 +- tests/next_tests/unit_tests/test_common.py | 100 +++++++------- 10 files changed, 229 insertions(+), 213 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 7f4d5b7b97..2936e4163a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -34,6 +34,7 @@ ClassVar, Final, Generic, + NamedTuple, Never, Optional, ParamSpec, @@ -84,7 +85,7 @@ def __str__(self) -> str: return f"{self.value}[{self.kind}]" def __call__(self, val: int) -> NamedIndex: - return self, val + return NamedIndex(self, val) class Infinity(enum.Enum): @@ -248,9 +249,16 @@ def __str__(self) -> str: FiniteUnitRange: TypeAlias = UnitRange[int, int] +_Rng = TypeVar( + "_Rng", + FiniteUnitRange, + UnitRange[Infinity, int], + UnitRange[int, Infinity], + UnitRange[Infinity, Infinity], +) RangeLike: TypeAlias = ( - UnitRange + _Rng | range | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] | core_defs.IntegralScalar @@ -282,10 +290,26 @@ def unit_range(r: RangeLike) -> UnitRange: raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") +class NamedRange(NamedTuple, Generic[_Rng]): + dim: Dimension + unit_range: _Rng + + def __str__(self) -> str: + return f"{self.dim}={self.unit_range}" + + IntIndex: TypeAlias = int | core_defs.IntegralScalar -NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple -NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple -FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple + + +class NamedIndex(NamedTuple): + dim: Dimension + value: IntIndex + + def __str__(self) -> str: + return f"{self.dim}={self.value}" + + +FiniteNamedRange: TypeAlias = NamedRange[FiniteUnitRange] RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice @@ -304,41 +328,22 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]: return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: - return ( - isinstance(v, tuple) - and len(v) == 2 - and isinstance(v[0], Dimension) - and isinstance(v[1], UnitRange) - ) - - def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: - return UnitRange.is_finite(v[1]) + return UnitRange.is_finite(v.unit_range) -def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: - return ( - isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) +def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[slice]: + return isinstance(obj, slice) and ( + isinstance(obj.start, NamedIndex) and isinstance(obj.stop, NamedIndex) ) -def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]: - return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop)) - - def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: - return ( - is_int_index(v) - or is_named_range(v) - or is_named_index(v) - or isinstance(v, slice) - or v is Ellipsis - ) + return is_int_index(v) or isinstance(v, (NamedRange, NamedIndex, slice)) or v is Ellipsis def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: - return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) + return isinstance(v, Sequence) and all(isinstance(e, (NamedRange, NamedIndex)) for e in v) def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: @@ -356,20 +361,13 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: - return (v[0], unit_range(v[1])) - - -_Rng = TypeVar( - "_Rng", - UnitRange[int, int], - UnitRange[Infinity, int], - UnitRange[int, Infinity], - UnitRange[Infinity, Infinity], -) + if isinstance(v, NamedRange): + return v + return NamedRange(v[0], unit_range(v[1])) @dataclasses.dataclass(frozen=True, init=False) -class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): +class Domain(Sequence[NamedRange[_Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" dims: tuple[Dimension, ...] @@ -377,7 +375,7 @@ class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): def __init__( self, - *args: tuple[Dimension, _Rng], + *args: NamedRange[_Rng], dims: Optional[Sequence[Dimension]] = None, ranges: Optional[Sequence[_Rng]] = None, ) -> None: @@ -406,7 +404,7 @@ def __init__( object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) else: - if not all(is_named_range(arg) for arg in args): + if not all(isinstance(arg, NamedRange) for arg in args): raise ValueError( f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." ) @@ -437,17 +435,17 @@ def is_empty(self) -> bool: return any(rng.is_empty() for rng in self.ranges) @overload - def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... + def __getitem__(self, index: int) -> NamedRange: ... @overload def __getitem__(self, index: slice) -> Self: ... @overload - def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ... + def __getitem__(self, index: Dimension) -> NamedRange: ... def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: if isinstance(index, int): - return self.dims[index], self.ranges[index] + return NamedRange(dim=self.dims[index], unit_range=self.ranges[index]) elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] @@ -455,7 +453,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) - return self.dims[index_pos], self.ranges[index_pos] + return NamedRange(dim=self.dims[index_pos], unit_range=self.ranges[index_pos]) except ValueError as ex: raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex else: @@ -470,10 +468,12 @@ def __and__(self, other: Domain) -> Domain: >>> I = Dimension("I") >>> J = Dimension("J") - >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + >>> Domain(NamedRange(I, UnitRange(-1, 3))) & Domain(NamedRange(I, UnitRange(1, 6))) Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) - >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + >>> Domain(NamedRange(I, UnitRange(-1, 3)), NamedRange(J, UnitRange(2, 4))) & Domain( + ... NamedRange(I, UnitRange(1, 6)) + ... ) Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) @@ -487,7 +487,7 @@ def __and__(self, other: Domain) -> Domain: return Domain(dims=broadcast_dims, ranges=intersected_ranges) def __str__(self) -> str: - return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + return f"Domain({', '.join(f'{e}' for e in self)})" def dim_index(self, dim: Dimension) -> Optional[int]: return self.dims.index(dim) if dim in self.dims else None @@ -503,7 +503,7 @@ def insert(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return self.replace(index, *named_ranges) def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: - assert all(is_named_range(nr) for nr in named_ranges) + assert all(isinstance(nr, NamedRange) for nr in named_ranges) if isinstance(index, Dimension): dim_index = self.dim_index(index) if dim_index is None: @@ -515,9 +515,10 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: ) if index < 0: index += len(self.dims) - new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) - dims = self.dims[:index] + new_dims + self.dims[index + 1 :] - ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :] + new_dims = (arg.dim for arg in named_ranges) if len(named_ranges) > 0 else () + new_ranges = (arg.unit_range for arg in named_ranges) if len(named_ranges) > 0 else () + dims = self.dims[:index] + tuple(new_dims) + self.dims[index + 1 :] + ranges = self.ranges[:index] + tuple(new_ranges) + self.ranges[index + 1 :] return Domain(dims=dims, ranges=ranges) @@ -559,10 +560,7 @@ def domain(domain_like: DomainLike) -> Domain: if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()): return Domain( dims=tuple(domain_like.keys()), - ranges=tuple( - UnitRange(0, s) # type: ignore[arg-type] # type of `s` is checked in condition - for s in domain_like.values() - ), + ranges=tuple(UnitRange(0, s) for s in domain_like.values()), ) return Domain( dims=tuple(domain_like.keys()), @@ -949,15 +947,15 @@ def from_offset( def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: if not isinstance(image_range, UnitRange): - if image_range[0] != self.codomain: + if image_range.dim != self.codomain: raise ValueError( - f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." + f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range[1] + image_range = image_range.unit_range assert isinstance(image_range, UnitRange) - return ((self.codomain, image_range - self.offset),) + return (named_range((self.codomain, image_range - self.offset)),) def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: raise NotImplementedError() diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 6642d9a055..cdb0d3a5fd 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -51,7 +51,7 @@ def _relative_sub_domain( if isinstance(idx, slice): try: sliced = _slice_range(rng, idx) - named_ranges.append((dim, sliced)) + named_ranges.append(common.NamedRange(dim, sliced)) except IndexError as ex: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=idx, dim=dim @@ -76,14 +76,14 @@ def _absolute_sub_domain( for i, (dim, rng) in enumerate(domain): if (pos := _find_index_of_dim(dim, index)) is not None: named_idx = index[pos] - idx = named_idx[1] + _, idx = named_idx if isinstance(idx, common.UnitRange): if not idx <= rng: raise embedded_exceptions.IndexOutOfBounds( domain=domain, indices=index, index=named_idx, dim=dim ) - named_ranges.append((dim, idx)) + named_ranges.append(common.NamedRange(dim, idx)) else: # not in new domain assert common.is_int_index(idx) @@ -93,7 +93,7 @@ def _absolute_sub_domain( ) else: # dimension not mentioned in slice - named_ranges.append((dim, domain.ranges[i])) + named_ranges.append(common.NamedRange(dim, domain.ranges[i])) return common.Domain(*named_ranges) @@ -137,13 +137,13 @@ def restrict_to_intersection( """ ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) intersection_without_ignore_dims = domain_intersection(*[ - common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple]) + common.Domain(*[nr for nr in domain if nr.dim not in ignore_dims_tuple]) for domain in domains ]) return tuple( common.Domain(*[ - (d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1]) - for d, r in domain + (nr if nr.dim in ignore_dims_tuple else intersection_without_ignore_dims[nr.dim]) + for nr in domain ]) for domain in domains ) @@ -151,9 +151,9 @@ def restrict_to_intersection( def iterate_domain( domain: common.Domain, -) -> Iterator[tuple[tuple[common.Dimension, int]]]: - for i in itertools.product(*[list(r) for r in domain.ranges]): - yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]` +) -> Iterator[tuple[common.NamedIndex]]: + for idx in itertools.product(*(list(r) for r in domain.ranges)): + yield tuple(common.NamedIndex(d, i) for d, i in zip(domain.dims, idx)) # type: ignore[misc] # trust me, `idx` is `tuple[int, ...]` def _expand_ellipsis( @@ -169,7 +169,7 @@ def _expand_ellipsis( def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: if slice_obj == slice(None): - return common.UnitRange(input_range.start, input_range.stop) + return input_range start = ( input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop @@ -209,20 +209,16 @@ def _named_slice_to_named_range( ) -> common.NamedRange | common.NamedSlice: assert hasattr(idx, "start") and hasattr(idx, "stop") if common.is_named_slice(idx): - idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = ( - idx.start[0], # type: ignore[attr-defined] - idx.start[1], # type: ignore[attr-defined] - idx.stop[0], # type: ignore[attr-defined] - idx.stop[1], # type: ignore[attr-defined] - ) - if idx_start_0 != idx_stop_0: + start_dim, start_value = idx.start + stop_dim, stop_value = idx.stop + if start_dim != stop_dim: raise IndexError( - f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." + f"Dimensions slicing mismatch between '{start_dim.value}' and '{stop_dim.value}'." ) - assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int) - return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1)) - if common.is_named_index(idx.start) and idx.stop is None: + assert isinstance(start_value, int) and isinstance(stop_value, int) + return common.NamedRange(start_dim, common.UnitRange(start_value, stop_value)) + if isinstance(idx.start, common.NamedIndex) and idx.stop is None: raise IndexError(f"Upper bound needs to be specified for {idx}.") - if common.is_named_index(idx.stop) and idx.start is None: + if isinstance(idx.stop, common.NamedIndex) and idx.start is None: raise IndexError(f"Lower bound needs to be specified for {idx}.") return idx diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e884c61f36..e7f34bb2a2 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -24,7 +24,7 @@ from numpy import typing as npt from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar +from gt4py.eve.extended_typing import Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common from gt4py.next.embedded import ( common as embedded_common, @@ -122,7 +122,7 @@ def shape(self) -> tuple[int, ...]: @property def __gt_origin__(self) -> tuple[int, ...]: assert common.Domain.is_finite(self._domain) - return tuple(-r.start for _, r in self._domain) + return tuple(-r.start for r in self._domain.ranges) @property def ndarray(self) -> core_defs.NDArrayObject: @@ -137,7 +137,7 @@ def asnumpy(self) -> np.ndarray: def as_scalar(self) -> core_defs.ScalarT: if self.domain.ndim != 0: raise ValueError( - "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." ) return self.ndarray.item() @@ -196,7 +196,7 @@ def remap( if dim_idx is None: raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") - current_range: common.UnitRange = self.domain[dim_idx][1] + current_range: common.UnitRange = self.domain[dim_idx].unit_range new_ranges = connectivity.inverse_image(current_range) new_domain = self.domain.replace(dim_idx, *new_ranges) @@ -413,12 +413,12 @@ def inverse_image( if not isinstance( image_range, common.UnitRange ): # TODO(havogt): cleanup duplication with CartesianConnectivity - if image_range[0] != self.codomain: + if image_range.dim != self.codomain: raise ValueError( - f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." + f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range[1] + image_range = image_range.unit_range assert isinstance(image_range, common.UnitRange) @@ -603,14 +603,14 @@ def _intersect_fields( def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: if not domains: return common.Domain() - dim_start = domains[0][dim][1].start + dim_start = domains[0][dim].unit_range.start dim_stop = dim_start for domain in domains: - if not domain[dim][1].start == dim_stop: + if not domain[dim].unit_range.start == dim_stop: return None else: - dim_stop = domain[dim][1].stop - return domains[0].replace(dim, (dim, common.UnitRange(dim_start, dim_stop))) + dim_stop = domain[dim].unit_range.stop + return domains[0].replace(dim, common.NamedRange(dim, common.UnitRange(dim_start, dim_stop))) def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: @@ -688,7 +688,7 @@ def _concat_where( if transformed: return _concat(*transformed, dim=mask_dim) else: - result_domain = common.Domain((mask_dim, common.UnitRange(0, 0))) + result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) result_array = xp.empty(result_domain.shape) return cls_.from_array(result_array, domain=result_domain) @@ -722,7 +722,7 @@ def _builtin_op( axis.value ] # assumes offset and local dimension have same name assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) - new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis @@ -820,10 +820,10 @@ def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) for dim in new_dimensions: if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - named_ranges.append((dim, field.domain[pos][1])) + named_ranges.append(common.NamedRange(dim, field.domain[pos].unit_range)) else: domain_slice.append(None) # np.newaxis - named_ranges.append((dim, common.UnitRange.infinite())) + named_ranges.append(common.NamedRange(dim, common.UnitRange.infinite())) return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -849,7 +849,7 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA def _get_slices_from_domain_slice( domain: common.Domain, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex], ) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. @@ -869,7 +869,7 @@ def _get_slices_from_domain_slice( for pos_old, (dim, _) in enumerate(domain): if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: - index_or_range = domain_slice[pos][1] + _, index_or_range = domain_slice[pos] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: slice_indices.append(slice(None)) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index c5f5fd0503..b88083e7c2 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -51,14 +51,14 @@ def __call__( # type: ignore[override] | tuple[common.Field[Any, core_defs.ScalarT] | tuple, ...] ): scan_range = embedded_context.closure_column_range.get() - assert self.axis == scan_range[0] - scan_axis = scan_range[0] + assert self.axis == scan_range.dim + scan_axis = scan_range.dim all_args = [*args, *kwargs.values()] domain_intersection = _intersect_scan_args(*all_args) - non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) + non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr.dim != scan_axis]) out_domain = common.Domain(*[ - scan_range if nr[0] == scan_axis else nr for nr in domain_intersection + scan_range if nr.dim == scan_axis else nr for nr in domain_intersection ]) if scan_axis not in out_domain.dims: # even if the scan dimension is not in the input, we can scan over it @@ -69,8 +69,8 @@ def __call__( # type: ignore[override] def scan_loop(hpos: Sequence[common.NamedIndex]) -> None: acc: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] = self.init - for k in scan_range[1] if self.forward else reversed(scan_range[1]): - pos = (*hpos, (scan_axis, k)) + for k in scan_range.unit_range if self.forward else reversed(scan_range.unit_range): + pos = (*hpos, common.NamedIndex(scan_axis, k)) new_args = [_tuple_at(pos, arg) for arg in args] new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} acc = self.fun(acc, *new_args, **new_kwargs) # type: ignore[arg-type] # need to express that the first argument is the same type as the return @@ -134,7 +134,7 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: - vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] + vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] assert len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 5c0da54ab8..34562ffdcb 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -344,7 +344,7 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: ) or common.is_connectivity_field(offset_definition): unrestricted_connectivity = self.as_connectivity_field() assert unrestricted_connectivity.domain.ndim > 1 - named_index = (self.target[-1], offset) + named_index = common.NamedIndex(self.target[-1], offset) connectivity = unrestricted_connectivity[named_index] else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b7cf36187d..c9552e7138 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -208,7 +208,9 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 column_range: common.NamedRange = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) + self.data = ( + data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data) + ) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -749,7 +751,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().unit_range assert column_range is not None col: list[ @@ -826,7 +828,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.unit_range)) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -867,7 +869,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().unit_range # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -909,16 +911,16 @@ def _translate_named_indices( domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): if isinstance(v, range): - domain_slice.append((d, common.UnitRange(v.start, v.stop))) + domain_slice.append(common.NamedRange(d, common.UnitRange(v.start, v.stop))) elif isinstance(v, list): assert len(v) == 1 # only 1 sparse dimension is supported assert common.is_int_index( v[0] ) # derefing a concrete element in a sparse field, not a slice - domain_slice.append((d, v[0])) + domain_slice.append(common.NamedIndex(d, v[0])) else: assert common.is_int_index(v) - domain_slice.append((d, v)) + domain_slice.append(common.NamedIndex(d, v)) return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: @@ -1060,7 +1062,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: if self._cur_index is None: - return common.Domain((self._dimension, common.UnitRange.infinite())) + return common.Domain(common.NamedRange(self._dimension, common.UnitRange.infinite())) else: return common.Domain() @@ -1092,11 +1094,12 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> Self: - if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off + if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item): + assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) - return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work + return self.__class__(self._dimension, r) # TODO set a domain... raise NotImplementedError() @@ -1492,7 +1495,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().unit_range if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1545,7 +1548,7 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = ( + column_range = common.NamedRange( column_axis, common.UnitRange(column.col_range.start, column.col_range.stop), ) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 9765273f94..111622ac42 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -17,7 +17,7 @@ import pytest from gt4py.next import common -from gt4py.next.common import UnitRange +from gt4py.next.common import UnitRange, NamedIndex, NamedRange from gt4py.next.embedded import exceptions as embedded_exceptions from gt4py.next.embedded.common import ( _slice_range, @@ -53,12 +53,12 @@ def test_slice_range(rng, slce, expected): [ ([(I, (2, 5))], 1, []), ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]), - ([(I, (2, 5))], (I, 2), []), - ([(I, (2, 5))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (2, 5))], NamedIndex(I, 2), []), + ([(I, (2, 5))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]), ([(I, (-2, 3))], 1, []), ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]), - ([(I, (-2, 3))], (I, 1), []), - ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], NamedIndex(I, 1), []), + ([(I, (-2, 3))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]), ([(I, (-2, 3))], -5, []), ([(I, (-2, 3))], -6, IndexError), ([(I, (-2, 3))], slice(-7, -6), IndexError), @@ -67,10 +67,10 @@ def test_slice_range(rng, slce, expected): ([(I, (-2, 3))], 5, IndexError), ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]), ([(I, (-2, 3))], slice(5, 6), IndexError), - ([(I, (-2, 3))], (I, -3), IndexError), - ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), - ([(I, (-2, 3))], (I, 3), IndexError), - ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), + ([(I, (-2, 3))], NamedIndex(I, -3), IndexError), + ([(I, (-2, 3))], NamedRange(I, UnitRange(-3, -2)), IndexError), + ([(I, (-2, 3))], NamedIndex(I, 3), IndexError), + ([(I, (-2, 3))], NamedRange(I, UnitRange(3, 4)), IndexError), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], 2, @@ -83,32 +83,32 @@ def test_slice_range(rng, slce, expected): ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, 2), + NamedIndex(I, 2), [(J, (3, 6)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, UnitRange(2, 3)), + NamedRange(I, UnitRange(2, 3)), [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, 3), + NamedIndex(J, 3), [(I, (2, 5)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, UnitRange(4, 5)), + NamedRange(J, UnitRange(4, 5)), [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, 3), (I, 2)), + (NamedIndex(J, 3), NamedIndex(I, 2)), [(K, (4, 7))], ), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, UnitRange(4, 5)), (I, 2)), + (NamedRange(J, UnitRange(4, 5)), NamedIndex(I, 2)), [(J, (4, 5)), (K, (4, 7))], ), ( @@ -147,8 +147,8 @@ def test_sub_domain(domain, index, expected): def test_iterate_domain(): domain = common.domain({I: 2, J: 3}) ref = [] - for i in domain[I][1]: - for j in domain[J][1]: + for i in domain[I].unit_range: + for j in domain[J].unit_range: ref.append(((I, i), (J, j))) testee = list(iterate_domain(domain)) @@ -159,10 +159,10 @@ def test_iterate_domain(): @pytest.mark.parametrize( "slices, expected", [ - [slice(I(3), I(4)), ((I, common.UnitRange(3, 4)),)], + [slice(I(3), I(4)), (NamedRange(I, common.UnitRange(3, 4)),)], [ (slice(J(3), J(6)), slice(I(3), I(5))), - ((J, common.UnitRange(3, 6)), (I, common.UnitRange(3, 5))), + (NamedRange(J, common.UnitRange(3, 6)), NamedRange(I, common.UnitRange(3, 5))), ], [slice(I(1), J(7)), IndexError], [ diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 7c932533a6..7171bb5ecc 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -21,7 +21,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange +from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -430,7 +430,7 @@ def test_field_broadcast(new_dims, field, expected_domain): @pytest.mark.parametrize( "domain_slice", [ - ((D0, UnitRange(0, 10)),), + (NamedRange(D0, UnitRange(0, 10)),), common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), ], ) @@ -446,7 +446,7 @@ def test_get_slices_with_named_index(): field_domain = common.Domain( dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) - named_index = ((D0, UnitRange(0, 10)), (D1, 2), (D2, 3)) + named_index = (NamedRange(D0, UnitRange(0, 10)), (D1, 2), (D2, 3)) slices = _get_slices_from_domain_slice(field_domain, named_index) assert slices == (slice(0, 10, None), 2, 3) @@ -465,34 +465,34 @@ def test_get_slices_invalid_type(): [ ( ( - (D0, UnitRange(7, 9)), - (D1, UnitRange(8, 10)), + NamedRange(D0, UnitRange(7, 9)), + NamedRange(D1, UnitRange(8, 10)), ), (D0, D1, D2), (2, 2, 15), ), ( ( - (D0, UnitRange(7, 9)), - (D2, UnitRange(12, 20)), + NamedRange(D0, UnitRange(7, 9)), + NamedRange(D2, UnitRange(12, 20)), ), (D0, D1, D2), (2, 10, 8), ), (common.Domain(dims=(D0,), ranges=(UnitRange(7, 9),)), (D0, D1, D2), (2, 10, 15)), - (((D0, 8),), (D1, D2), (10, 15)), - (((D1, 9),), (D0, D2), (5, 15)), - (((D2, 11),), (D0, D1), (5, 10)), + ((NamedIndex(D0, 8),), (D1, D2), (10, 15)), + ((NamedIndex(D1, 9),), (D0, D2), (5, 15)), + ((NamedIndex(D2, 11),), (D0, D1), (5, 10)), ( ( - (D0, 8), - (D1, UnitRange(8, 10)), + NamedIndex(D0, 8), + NamedRange(D1, UnitRange(8, 10)), ), (D1, D2), (2, 15), ), - ((D0, 5), (D1, D2), (10, 15)), - ((D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), + (NamedIndex(D0, 5), (D1, D2), (10, 15)), + (NamedRange(D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), ], ) def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): @@ -513,7 +513,10 @@ def test_absolute_indexing_dim_sliced(): ) field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field_1 = field[D1(8) : D1(10), D0(5) : D0(9)] - expected = field[(D0, UnitRange(5, 9)), (D1, UnitRange(8, 10))] + expected = field[ + NamedRange(dim=D0, unit_range=UnitRange(5, 9)), + NamedRange(dim=D1, unit_range=UnitRange(8, 10)), + ] assert common.is_field(indexed_field_1) assert indexed_field_1 == expected @@ -525,7 +528,7 @@ def test_absolute_indexing_dim_sliced_single_slice(): ) field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field_1 = field[D2(11)] - indexed_field_2 = field[(D2, 11)] + indexed_field_2 = field[NamedIndex(D2, 11)] assert common.is_field(indexed_field_1) assert indexed_field_1 == indexed_field_2 @@ -554,7 +557,7 @@ def test_absolute_indexing_value_return(): domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(10, 20), UnitRange(5, 15))) field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((D0, 12), (D1, 6)) + named_index = (NamedIndex(D0, 12), NamedIndex(D1, 6)) assert common.is_field(field) value = field[named_index] @@ -568,23 +571,27 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 4))), + Domain(NamedRange(D0, UnitRange(5, 10)), NamedRange(D1, UnitRange(2, 4))), + ), + ( + (slice(None, 5),), + (5, 10), + Domain(NamedRange(D0, UnitRange(5, 10)), NamedRange(D1, UnitRange(2, 12))), ), - ((slice(None, 5),), (5, 10), Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 12)))), ( (Ellipsis, 1), (10,), - Domain((D0, UnitRange(5, 15))), + Domain(NamedRange(D0, UnitRange(5, 15))), ), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((D0, UnitRange(7, 8)), (D1, UnitRange(7, 9))), + Domain(NamedRange(D0, UnitRange(7, 8)), NamedRange(D1, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((D0, UnitRange(6, 7))), + Domain(NamedRange(D0, UnitRange(6, 7))), ), ], ) @@ -693,7 +700,9 @@ def test_field_unsupported_index(index): ((1, slice(None)), np.ones((10,)) * 42.0), ( (1, slice(None)), - common._field(np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(0, 10)))), + common._field( + np.ones((10,)) * 42.0, domain=common.Domain(NamedRange(D1, UnitRange(0, 10))) + ), ), ], ) @@ -718,7 +727,7 @@ def test_setitem_wrong_domain(): ) value_incompatible = common._field( - np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(-5, 5))) + np.ones((10,)) * 42.0, domain=common.Domain(NamedRange(D1, UnitRange(-5, 5))) ) with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): @@ -751,7 +760,7 @@ def test_connectivity_field_inverse_image(): # Test codomain with pytest.raises(ValueError, match="does not match the codomain dimension"): - e2v_conn.inverse_image((E, UnitRange(1, 2))) + e2v_conn.inverse_image(NamedRange(E, UnitRange(1, 2))) def test_connectivity_field_inverse_image_2d_domain(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index 9238cd4f7a..ec6e613529 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -62,7 +62,9 @@ def test_func(data_a: int, data_b: int): embedded.column_range_cvar.set(range(2, 999)) _run_within_context( lambda: test_func(2, 3), - column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)), + column_range=common.NamedRange( + common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3) + ), ) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index ce940131c3..1aeb51cb30 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -24,6 +24,7 @@ UnitRange, domain, named_range, + NamedRange, promote_dims, unit_range, ) @@ -37,7 +38,11 @@ @pytest.fixture def a_domain(): - return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) + return Domain( + NamedRange(IDim, UnitRange(0, 10)), + NamedRange(JDim, UnitRange(5, 15)), + NamedRange(KDim, UnitRange(20, 30)), + ) @pytest.fixture(params=[Infinity.POSITIVE, Infinity.NEGATIVE]) @@ -261,10 +266,10 @@ def test_domain_length(a_domain): "empty_domain, expected", [ (Domain(), False), - (Domain((IDim, UnitRange(0, 10))), False), - (Domain((IDim, UnitRange(0, 0))), True), - (Domain((IDim, UnitRange(0, 0)), (JDim, UnitRange(0, 1))), True), - (Domain((IDim, UnitRange(0, 1)), (JDim, UnitRange(0, 0))), True), + (Domain(NamedRange(IDim, UnitRange(0, 10))), False), + (Domain(NamedRange(IDim, UnitRange(0, 0))), True), + (Domain(NamedRange(IDim, UnitRange(0, 0)), NamedRange(JDim, UnitRange(0, 1))), True), + (Domain(NamedRange(IDim, UnitRange(0, 1)), NamedRange(JDim, UnitRange(0, 0))), True), ], ) def test_empty_domain(empty_domain, expected): @@ -436,89 +441,92 @@ def test_domain_pop(): # Valid index and named ranges ( 0, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("X"), UnitRange(100, 110)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ), ( 1, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("X"), UnitRange(100, 110)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ), ( -1, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), ), ), ( Dimension("J"), - [(Dimension("X"), UnitRange(100, 110)), (Dimension("Z"), UnitRange(100, 110))], + [ + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("Z"), UnitRange(100, 110)), + ], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("X"), UnitRange(100, 110)), - (Dimension("Z"), UnitRange(100, 110)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("X"), UnitRange(100, 110)), + NamedRange(Dimension("Z"), UnitRange(100, 110)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ), # Invalid indices ( 3, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), IndexError, ), ( -4, - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), IndexError, ), ( Dimension("Foo"), - [(Dimension("X"), UnitRange(100, 110))], + [NamedRange(Dimension("X"), UnitRange(100, 110))], Domain( - (Dimension("I"), UnitRange(0, 10)), - (Dimension("J"), UnitRange(0, 10)), - (Dimension("K"), UnitRange(0, 10)), + NamedRange(Dimension("I"), UnitRange(0, 10)), + NamedRange(Dimension("J"), UnitRange(0, 10)), + NamedRange(Dimension("K"), UnitRange(0, 10)), ), ValueError, ),