Skip to content

Commit

Permalink
Merge pull request JuliaArrays#156 from ranocha/hr/fix_contiguous_bat…
Browse files Browse the repository at this point in the history
…ch_size

fix `contiguous_batch_size` for reshaped views
  • Loading branch information
chriselrod authored May 31, 2021
2 parents 9789fa4 + ada0cf2 commit 6a8efa6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R<:Tuple}
return nothing
end
end
_contiguous_batch_size(::StaticInt{-1}, ::R) where {R<:Tuple} = -One()

contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = Zero()
contiguous_batch_size(::Type{BitArray{N}}) where {N} = Zero()
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,12 @@ using OffsetArrays
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StaticInt(-1)
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1)
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0)
let u_base = randn(10, 10)
u_view = view(u_base, 3, :)
u_reshaped_view = reshape(u_view, 1, size(u_base, 2))
@test @inferred(contiguous_batch_size(u_view)) === ArrayInterface.StaticInt(-1)
@test @inferred(contiguous_batch_size(u_reshaped_view)) === ArrayInterface.StaticInt(-1)
end

@test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3)
@test @inferred(stride_rank(A)) == (1,2,3)
Expand Down

0 comments on commit 6a8efa6

Please sign in to comment.