From fecfde5289bd799b8c4ae0d3805830304ec4a88d Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 28 Nov 2021 15:39:48 -0500 Subject: [PATCH] Support for new `is_splat_index` trait In addition to the faster new and faster `ArrayInterface.to_indices` this now accounts for trailing index arguments that map to multiple dimensions (e.g., `A[.., CartesianIndex(2, 2)]`) --- Project.toml | 4 ++-- src/EllipsisNotation.jl | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 1372da9..1222cf4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,11 @@ name = "EllipsisNotation" uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" authors = ["Chris Rackauckas "] -version = "1.1.1" +version = "1.1.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" [compat] -julia = "1.5" ArrayInterface = "3" +julia = "1.5" diff --git a/src/EllipsisNotation.jl b/src/EllipsisNotation.jl index adc4e32..739b317 100644 --- a/src/EllipsisNotation.jl +++ b/src/EllipsisNotation.jl @@ -43,13 +43,11 @@ true module EllipsisNotation using ArrayInterface -using ArrayInterface: indices - import Base: to_indices, tail struct Ellipsis end -const .. = Ellipsis() +const .. = Ellipsis() @inline function to_indices(A, inds::NTuple{M, Any}, I::Tuple{Ellipsis, Vararg{Any, N}}) where {M,N} # Align the remaining indices to the tail of the `inds` @@ -57,10 +55,9 @@ const .. = Ellipsis() to_indices(A, inds, (colons..., tail(I)...)) end -Base.@propagate_inbounds function ArrayInterface.to_indices(A, inds::Tuple{Vararg{Any,M}}, I::Tuple{Ellipsis,Vararg{Any, N}}) where {M,N} - return ArrayInterface.to_indices(A, inds, (ntuple(i -> indices(inds[i]), Val(M-N))..., tail(I)...)) -end -ArrayInterface.to_indices(A, inds::Tuple{}, I::Tuple{Ellipsis}) = () +ArrayInterface.is_splat_index(::Type{Ellipsis}) = ArrayInterface.static(true) +ArrayInterface.ndims_index(::Type{Ellipsis}) = ArrayInterface.static(1) +ArrayInterface.to_index(x, ::Ellipsis) = ntuple(i -> ArrayInterface.indices(x, i), Val(ndims(x))) export ..