Skip to content

Commit

Permalink
Fix stride(A, i) for 0-dim inputs (#44090)
Browse files Browse the repository at this point in the history
Fixes #44087
  • Loading branch information
N5N3 authored Feb 16, 2022
1 parent 3e5cd3c commit 2d1ea3c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
8 changes: 7 additions & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,13 @@ julia> stride(A,3)
function stride(A::AbstractArray, k::Integer)
st = strides(A)
k ndims(A) && return st[k]
return sum(st .* size(A))
ndims(A) == 0 && return 1
sz = size(A)
s = st[1] * sz[1]
for i in 2:ndims(A)
s += st[i] * sz[i]
end
return s
end

@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)
Expand Down
2 changes: 2 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ StridedMatrix{T} = StridedArray{T,2}
StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}}

strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
stride(A::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, k::Integer) =
k ndims(A) ? strides(A)[k] : length(A)

function strides(a::ReshapedReinterpretArray)
ap = parent(a)
Expand Down
13 changes: 13 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1584,6 +1584,19 @@ end
end
end

@testset "stride for 0 dims array #44087" begin
struct Fill44087 <: AbstractArray{Int,0}
a::Int
end
# `stride` shouldn't work if `strides` is not defined.
@test_throws MethodError stride(Fill44087(1), 1)
# It is intentionally to only check the return type. (The value is somehow arbitrary)
@test stride(fill(1), 1) isa Int
@test stride(reinterpret(Float64, fill(Int64(1))), 1) isa Int
@test stride(reinterpret(reshape, Float64, fill(Int64(1))), 1) isa Int
@test stride(Base.ReshapedArray(fill(1), (), ()), 1) isa Int
end

@testset "to_indices inference (issue #42001 #44059)" begin
@test (@inferred to_indices([], ntuple(Returns(CartesianIndex(1)), 32))) == ntuple(Returns(1), 32)
@test (@inferred to_indices([], ntuple(Returns(CartesianIndices(1:1)), 32))) == ntuple(Returns(Base.OneTo(1)), 32)
Expand Down

0 comments on commit 2d1ea3c

Please sign in to comment.