From ada0cf2665def3f75d03d8d9af6227357f55d414 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Sun, 30 May 2021 15:00:06 +0200 Subject: [PATCH] fix contiguous_batch_size for reshaped views --- src/stridelayout.jl | 1 + test/runtests.jl | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 768bcf279..e20892bca 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -269,6 +269,7 @@ function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R<:Tuple} return nothing end end +_contiguous_batch_size(::StaticInt{-1}, ::R) where {R<:Tuple} = -One() contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = Zero() contiguous_batch_size(::Type{BitArray{N}}) where {N} = Zero() diff --git a/test/runtests.jl b/test/runtests.jl index 92ee3ef16..5588e3004 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -398,6 +398,12 @@ using OffsetArrays @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0) + let u_base = randn(10, 10) + u_view = view(u_base, 3, :) + u_reshaped_view = reshape(u_view, 1, size(u_base, 2)) + @test @inferred(contiguous_batch_size(u_view)) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_batch_size(u_reshaped_view)) === ArrayInterface.StaticInt(-1) + end @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3) @test @inferred(stride_rank(A)) == (1,2,3)