From fe4dc89f9fb86a877b137d6509845ff887669d25 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Wed, 2 Jun 2021 13:11:46 +0200 Subject: [PATCH] fix strides for reshaped views of abstract vectors, cf. #160 --- src/stridelayout.jl | 25 ++++++++++++++++++----- test/runtests.jl | 49 +++++++++++++++++++++++++++++++-------------- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index e20892bca..64f6c4c82 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -88,7 +88,7 @@ end contiguous_axis(::Type{T}) -> StaticInt{N} Returns the axis of an array of type `T` containing contiguous data. -If no axis is contiguous, it returns `StaticInt{-1}`. +If no axis is contiguous, it returns a `StaticInt{-1}`. If unknown, it returns `nothing`. """ contiguous_axis(x) = contiguous_axis(typeof(x)) @@ -297,7 +297,7 @@ contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A} """ is_column_major(A) -> True/False -Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`. +Returns `True()` if elements of `A` are stored in column major order. Otherwise returns `False()`. """ is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A)) is_column_major(sr::Nothing, cbs) = False() @@ -310,10 +310,11 @@ _is_column_major(sr::R, cbs::StaticInt) where {R} = False() _is_column_major(sr::R, cbs::Union{StaticInt{0},StaticInt{-1}}) where {R} = is_increasing(sr) """ - dense_dims(::Type{T}) -> NTuple{N,Bool} + dense_dims(::Type{<:AbstractArray{N}}) -> NTuple{N,Bool} Returns a tuple of indicators for whether each axis is dense. -An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`. +An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` +where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`. """ dense_dims(x) = dense_dims(typeof(x)) function dense_dims(::Type{T}) where {T} @@ -359,7 +360,7 @@ end if VERSION ≥ v"1.6.0-DEV.1581" @inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}} ddb = dense_dims(B) - IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb)) + IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb)) end end @@ -464,13 +465,16 @@ julia> A = rand(3,4); julia> ArrayInterface.strides(A) (static(1), 3) +``` Additionally, the behavior differs from `Base.strides` for adjoint vectors: +```julia julia> x = rand(5); julia> ArrayInterface.strides(x') (static(1), static(1)) +``` This is to support the pattern of using just the first stride for linear indexing, `x[i]`, while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`. @@ -485,6 +489,17 @@ function strides(x) return Base.strides(x) end end + +# Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160 +# TODO: Should be generalized to reshaped arrays wrapping more general array types +function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector} + if defines_strides(A) + return size_to_strides(size(A), first(strides(parent(A)))) + else + return Base.strides(A) + end +end + @inline bmap(f::F, t::Tuple{}, x::Number) where {F} = () @inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), ) @inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...) diff --git a/test/runtests.jl b/test/runtests.jl index 5588e3004..0b169996d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -207,22 +207,22 @@ end @testset "Range Interface" begin @testset "Range Constructors" begin @test @inferred(StaticInt(1):StaticInt(10)) == 1:10 - @test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10 + @test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10 @test @inferred(1:StaticInt(2):StaticInt(10)) == 1:2:10 @test @inferred(StaticInt(1):StaticInt(2):10) == 1:2:10 - @test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10 + @test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10 @test @inferred(1:2:StaticInt(10)) == 1:2:10 @test @inferred(1:StaticInt(2):10) == 1:2:10 - @test @inferred(StaticInt(1):2:10) == 1:2:10 - @test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10 + @test @inferred(StaticInt(1):2:10) == 1:2:10 + @test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10 @test @inferred(UInt(1):StaticInt(1):StaticInt(10)) === 1:StaticInt(10) @test @inferred(ArrayInterface.OptionallyStaticUnitRange{Int,Int}(1:10)) == 1:10 @test @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) == 1:10 @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) - @test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10 - @test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10 + @test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10 + @test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10 @test @inferred(ArrayInterface.OptionallyStaticStepRange(1:10)) == 1:1:10 @test_throws ArgumentError ArrayInterface.OptionallyStaticUnitRange(1:2:10) @@ -331,7 +331,6 @@ using OffsetArrays @test @inferred(ArrayInterface.defines_strides(D1)) @test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1))) @test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}})) - @test @inferred(device(A)) === ArrayInterface.CPUPointer() @test @inferred(device(B)) === ArrayInterface.CPUIndex() @test @inferred(device(-1:19)) === ArrayInterface.CPUIndex() @@ -372,7 +371,7 @@ using OffsetArrays @test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing @test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing @test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing - + @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false) @@ -424,7 +423,7 @@ using OffsetArrays @test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing @test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing - + #= @btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2)))) 0.047 ns (0 allocations: 0 bytes) @@ -494,11 +493,11 @@ using OffsetArrays @test @inferred(ArrayInterface.defines_strides(C1)) @test @inferred(ArrayInterface.defines_strides(C2)) @test @inferred(ArrayInterface.defines_strides(C3)) - + @test @inferred(device(C1)) === ArrayInterface.CPUPointer() @test @inferred(device(C2)) === ArrayInterface.CPUPointer() @test @inferred(device(C3)) === ArrayInterface.CPUPointer() - + @test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0) @test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0) @test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0) @@ -510,7 +509,7 @@ using OffsetArrays @test @inferred(contiguous_axis(C1)) === StaticInt(1) @test @inferred(contiguous_axis(C2)) === StaticInt(0) @test @inferred(contiguous_axis(C3)) === StaticInt(2) - + @test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true) @@ -675,7 +674,7 @@ end colormat = reinterpret(reshape, Float64, colors) @test @inferred(ArrayInterface.strides(colormat)) === (StaticInt(1), StaticInt(3)) @test @inferred(ArrayInterface.dense_dims(colormat)) === (True(),True()) - @test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),) + @test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),) @test @inferred(ArrayInterface.dense_dims(view(colormat,:,4:7))) === (True(),True()) @test @inferred(ArrayInterface.dense_dims(view(colormat,2:3,:))) === (True(),False()) @@ -702,7 +701,7 @@ end @test @inferred(ArrayInterface.strides(Ac2r)) === (StaticInt(1), StaticInt(2), 10) Ac2r_static = reinterpret(reshape, Float64, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6)); @test @inferred(ArrayInterface.strides(Ac2r_static)) === (StaticInt(1), StaticInt(2), StaticInt(10)) - + Ac2t = reinterpret(reshape, Tuple{Float64,Float64}, view(rand(ComplexF64, 5, 7), 2:4, 3:6)); @test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5) Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6)); @@ -711,6 +710,26 @@ end end end +@testset "Reshaped views" begin + # See + # https://github.com/JuliaArrays/ArrayInterface.jl/issues/160 + # https://github.com/JuliaArrays/ArrayInterface.jl/issues/157 + u_base = randn(10, 10) + u_view = view(u_base, 3, :) + u_reshaped_view1 = reshape(u_view, 1, :) + u_reshaped_view2 = reshape(u_view, 2, :) + + @test @inferred(ArrayInterface.defines_strides(u_base)) + @test @inferred(ArrayInterface.defines_strides(u_view)) + @test @inferred(ArrayInterface.defines_strides(u_reshaped_view1)) + @test @inferred(ArrayInterface.defines_strides(u_reshaped_view2)) + + @test @inferred(ArrayInterface.strides(u_base)) == (StaticInt(1), 10) + @test @inferred(ArrayInterface.strides(u_view)) == (10,) + @test @inferred(ArrayInterface.strides(u_reshaped_view1)) == (10, 10) + @test @inferred(ArrayInterface.strides(u_reshaped_view2)) == (10, 20) +end + @test ArrayInterface.can_avx(ArrayInterface.can_avx) == false @testset "can_change_size" begin @@ -842,6 +861,6 @@ end @test @inferred(is_lazy_conjugate(d)) == false e = permutedims(d) @test @inferred(is_lazy_conjugate(e)) == false - + @test @inferred(is_lazy_conjugate([1,2,3]')) == false # We don't care about conj on `<:Real` end