Skip to content

Commit

Permalink
fix strides for reshaped views of abstract vectors, cf. #160
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Jun 2, 2021
1 parent 6a8efa6 commit fe4dc89
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
25 changes: 20 additions & 5 deletions src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end
contiguous_axis(::Type{T}) -> StaticInt{N}
Returns the axis of an array of type `T` containing contiguous data.
If no axis is contiguous, it returns `StaticInt{-1}`.
If no axis is contiguous, it returns a `StaticInt{-1}`.
If unknown, it returns `nothing`.
"""
contiguous_axis(x) = contiguous_axis(typeof(x))
Expand Down Expand Up @@ -297,7 +297,7 @@ contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A}
"""
is_column_major(A) -> True/False
Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`.
Returns `True()` if elements of `A` are stored in column major order. Otherwise returns `False()`.
"""
is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A))
is_column_major(sr::Nothing, cbs) = False()
Expand All @@ -310,10 +310,11 @@ _is_column_major(sr::R, cbs::StaticInt) where {R} = False()
_is_column_major(sr::R, cbs::Union{StaticInt{0},StaticInt{-1}}) where {R} = is_increasing(sr)

"""
dense_dims(::Type{T}) -> NTuple{N,Bool}
dense_dims(::Type{<:AbstractArray{N}}) -> NTuple{N,Bool}
Returns a tuple of indicators for whether each axis is dense.
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)`
where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
"""
dense_dims(x) = dense_dims(typeof(x))
function dense_dims(::Type{T}) where {T}
Expand Down Expand Up @@ -359,7 +360,7 @@ end
if VERSION v"1.6.0-DEV.1581"
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
ddb = dense_dims(B)
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
end
end

Expand Down Expand Up @@ -464,13 +465,16 @@ julia> A = rand(3,4);
julia> ArrayInterface.strides(A)
(static(1), 3)
```
Additionally, the behavior differs from `Base.strides` for adjoint vectors:
```julia
julia> x = rand(5);
julia> ArrayInterface.strides(x')
(static(1), static(1))
```
This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
Expand All @@ -485,6 +489,17 @@ function strides(x)
return Base.strides(x)
end
end

# Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
# TODO: Should be generalized to reshaped arrays wrapping more general array types
function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
if defines_strides(A)
return size_to_strides(size(A), first(strides(parent(A))))
else
return Base.strides(A)
end
end

@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
Expand Down
49 changes: 34 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,22 @@ end
@testset "Range Interface" begin
@testset "Range Constructors" begin
@test @inferred(StaticInt(1):StaticInt(10)) == 1:10
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
@test @inferred(1:StaticInt(2):StaticInt(10)) == 1:2:10
@test @inferred(StaticInt(1):StaticInt(2):10) == 1:2:10
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
@test @inferred(1:2:StaticInt(10)) == 1:2:10
@test @inferred(1:StaticInt(2):10) == 1:2:10
@test @inferred(StaticInt(1):2:10) == 1:2:10
@test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10
@test @inferred(StaticInt(1):2:10) == 1:2:10
@test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10
@test @inferred(UInt(1):StaticInt(1):StaticInt(10)) === 1:StaticInt(10)
@test @inferred(ArrayInterface.OptionallyStaticUnitRange{Int,Int}(1:10)) == 1:10
@test @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) == 1:10

@inferred(ArrayInterface.OptionallyStaticUnitRange(1:10))

@test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10
@test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10
@test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10
@test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10
@test @inferred(ArrayInterface.OptionallyStaticStepRange(1:10)) == 1:1:10

@test_throws ArgumentError ArrayInterface.OptionallyStaticUnitRange(1:2:10)
Expand Down Expand Up @@ -331,7 +331,6 @@ using OffsetArrays
@test @inferred(ArrayInterface.defines_strides(D1))
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
@test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}}))

@test @inferred(device(A)) === ArrayInterface.CPUPointer()
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
Expand Down Expand Up @@ -372,7 +371,7 @@ using OffsetArrays
@test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing

@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
Expand Down Expand Up @@ -424,7 +423,7 @@ using OffsetArrays
@test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
@test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing


#=
@btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2))))
0.047 ns (0 allocations: 0 bytes)
Expand Down Expand Up @@ -494,11 +493,11 @@ using OffsetArrays
@test @inferred(ArrayInterface.defines_strides(C1))
@test @inferred(ArrayInterface.defines_strides(C2))
@test @inferred(ArrayInterface.defines_strides(C3))

@test @inferred(device(C1)) === ArrayInterface.CPUPointer()
@test @inferred(device(C2)) === ArrayInterface.CPUPointer()
@test @inferred(device(C3)) === ArrayInterface.CPUPointer()

@test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0)
@test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0)
@test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0)
Expand All @@ -510,7 +509,7 @@ using OffsetArrays
@test @inferred(contiguous_axis(C1)) === StaticInt(1)
@test @inferred(contiguous_axis(C2)) === StaticInt(0)
@test @inferred(contiguous_axis(C3)) === StaticInt(2)

@test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false)
@test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false)
@test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true)
Expand Down Expand Up @@ -675,7 +674,7 @@ end
colormat = reinterpret(reshape, Float64, colors)
@test @inferred(ArrayInterface.strides(colormat)) === (StaticInt(1), StaticInt(3))
@test @inferred(ArrayInterface.dense_dims(colormat)) === (True(),True())
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),)
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),)
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4:7))) === (True(),True())
@test @inferred(ArrayInterface.dense_dims(view(colormat,2:3,:))) === (True(),False())

Expand All @@ -702,7 +701,7 @@ end
@test @inferred(ArrayInterface.strides(Ac2r)) === (StaticInt(1), StaticInt(2), 10)
Ac2r_static = reinterpret(reshape, Float64, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
@test @inferred(ArrayInterface.strides(Ac2r_static)) === (StaticInt(1), StaticInt(2), StaticInt(10))

Ac2t = reinterpret(reshape, Tuple{Float64,Float64}, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
Expand All @@ -711,6 +710,26 @@ end
end
end

@testset "Reshaped views" begin
# See
# https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
# https://github.com/JuliaArrays/ArrayInterface.jl/issues/157
u_base = randn(10, 10)
u_view = view(u_base, 3, :)
u_reshaped_view1 = reshape(u_view, 1, :)
u_reshaped_view2 = reshape(u_view, 2, :)

@test @inferred(ArrayInterface.defines_strides(u_base))
@test @inferred(ArrayInterface.defines_strides(u_view))
@test @inferred(ArrayInterface.defines_strides(u_reshaped_view1))
@test @inferred(ArrayInterface.defines_strides(u_reshaped_view2))

@test @inferred(ArrayInterface.strides(u_base)) == (StaticInt(1), 10)
@test @inferred(ArrayInterface.strides(u_view)) == (10,)
@test @inferred(ArrayInterface.strides(u_reshaped_view1)) == (10, 10)
@test @inferred(ArrayInterface.strides(u_reshaped_view2)) == (10, 20)
end

@test ArrayInterface.can_avx(ArrayInterface.can_avx) == false

@testset "can_change_size" begin
Expand Down Expand Up @@ -842,6 +861,6 @@ end
@test @inferred(is_lazy_conjugate(d)) == false
e = permutedims(d)
@test @inferred(is_lazy_conjugate(e)) == false

@test @inferred(is_lazy_conjugate([1,2,3]')) == false # We don't care about conj on `<:Real`
end

0 comments on commit fe4dc89

Please sign in to comment.