From 81d8c0c11b7516ec1d262b36736ef46465610ccf Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 16 Jul 2021 13:41:41 -0400 Subject: [PATCH] `hvncat`: Ensure output ndims are >= the ndims of input arrays (#41201) (cherry picked from commit a2f5fe59d72736bfda7b3004f1c4bb58918fa94c) --- base/abstractarray.jl | 127 ++++++++++++++++++++++-------------------- test/abstractarray.jl | 19 +++++++ 2 files changed, 86 insertions(+), 60 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 1fdb441f952c8..f5eb075241dc6 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -2156,44 +2156,6 @@ _typed_hvncat(::Type, ::Val{0}, ::AbstractArray...) = _typed_hvncat_0d_only_one( _typed_hvncat_0d_only_one() = throw(ArgumentError("a 0-dimensional array may only contain exactly one element")) -function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N} - all(>(0), dims) || - throw(ArgumentError("`dims` argument must contain positive integers")) - A = Array{T, N}(undef, dims...) - lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations - lengthx = length(xs) # Cuts from 3 allocations to 1. - if lengtha != lengthx - throw(ArgumentError("argument count does not match specified shape (expected $lengtha, got $lengthx)")) - end - hvncat_fill!(A, row_first, xs) - return A -end - -function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple) - # putting these in separate functions leads to unnecessary allocations - if row_first - nr, nc = size(A, 1), size(A, 2) - nrc = nr * nc - na = prod(size(A)[3:end]) - k = 1 - for d ∈ 1:na - dd = nrc * (d - 1) - for i ∈ 1:nr - Ai = dd + i - for j ∈ 1:nc - A[Ai] = xs[k] - k += 1 - Ai += nr - end - end - end - else - for k ∈ eachindex(xs) - A[k] = xs[k] - end - end -end - _typed_hvncat(T::Type, dim::Int, ::Bool, xs...) = _typed_hvncat(T, Val(dim), xs...) # catches from _hvncat type promoters function _typed_hvncat(::Type{T}, ::Val{N}) where {T, N} @@ -2219,20 +2181,18 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N} throw(ArgumentError("concatenation dimension must be nonnegative")) for a ∈ as ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) || - return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...) + return _typed_hvncat(T, (ntuple(x -> 1, Val(N - 1))..., length(as), 1), false, as...) # the extra 1 is to avoid an infinite cycle end - nd = max(N, ndims(as[1])) + nd = N Ndim = 0 for i ∈ eachindex(as) - a = as[i] - Ndim += size(a, N) - nd = max(nd, ndims(a)) - for d ∈ 1:N-1 - size(a, d) == size(as[1], d) || - throw(ArgumentError("all dimensions of element $i other than $N must be of length 1")) + Ndim += cat_size(as[i], N) + nd = max(nd, cat_ndims(as[i])) + for d ∈ 1:N - 1 + cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i")) end end @@ -2255,16 +2215,15 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N} nd = N Ndim = 0 for i ∈ eachindex(as) - a = as[i] - Ndim += cat_size(a, N) - nd = max(nd, cat_ndims(a)) + Ndim += cat_size(as[i], N) + nd = max(nd, cat_ndims(as[i])) for d ∈ 1:N-1 - cat_size(a, d) == 1 || + cat_size(as[i], d) == 1 || throw(ArgumentError("all dimensions of element $i other than $N must be of length 1")) end end - A = Array{T, nd}(undef, ntuple(x -> 1, N - 1)..., Ndim, ntuple(x -> 1, nd - N)...) + A = Array{T, nd}(undef, ntuple(x -> 1, Val(N - 1))..., Ndim, ntuple(x -> 1, nd - N)...) k = 1 for a ∈ as @@ -2280,7 +2239,6 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N} return A end - # 0-dimensional cases for balanced and unbalanced hvncat method _typed_hvncat(T::Type, ::Tuple{}, ::Bool, x...) = _typed_hvncat(T, Val(0), x...) @@ -2305,7 +2263,51 @@ function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T, end end -function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) where {T, N} +function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N} + all(>(0), dims) || + throw(ArgumentError("`dims` argument must contain positive integers")) + A = Array{T, N}(undef, dims...) + lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations + lengthx = length(xs) # Cuts from 3 allocations to 1. + if lengtha != lengthx + throw(ArgumentError("argument count does not match specified shape (expected $lengtha, got $lengthx)")) + end + hvncat_fill!(A, row_first, xs) + return A +end + +function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple) + # putting these in separate functions leads to unnecessary allocations + if row_first + nr, nc = size(A, 1), size(A, 2) + nrc = nr * nc + na = prod(size(A)[3:end]) + k = 1 + for d ∈ 1:na + dd = nrc * (d - 1) + for i ∈ 1:nr + Ai = dd + i + for j ∈ 1:nc + A[Ai] = xs[k] + k += 1 + Ai += nr + end + end + end + else + for k ∈ eachindex(xs) + A[k] = xs[k] + end + end +end + +function _typed_hvncat(T::Type, dims::NTuple{N, Int}, row_first::Bool, as...) where {N} + # function barrier after calculating the max is necessary for high performance + nd = max(maximum(cat_ndims(a) for a ∈ as), N) + return _typed_hvncat_dims(T, (dims..., ntuple(x -> 1, nd - N)...), row_first, as) +end + +function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as::Tuple) where {T, N} length(as) > 0 || throw(ArgumentError("must have at least one element")) all(>(0), dims) || @@ -2314,16 +2316,14 @@ function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) d1 = row_first ? 2 : 1 d2 = row_first ? 1 : 2 - # discover dimensions - nd = max(N, cat_ndims(as[1])) - outdims = zeros(Int, nd) + outdims = zeros(Int, N) # discover number of rows or columns for i ∈ 1:dims[d1] outdims[d1] += cat_size(as[i], d1) end - currentdims = zeros(Int, nd) + currentdims = zeros(Int, N) blockcount = 0 elementcount = 0 for i ∈ eachindex(as) @@ -2331,11 +2331,11 @@ function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) currentdims[d1] += cat_size(as[i], d1) if currentdims[d1] == outdims[d1] currentdims[d1] = 0 - for d ∈ (d2, 3:nd...) + for d ∈ (d2, 3:N...) currentdims[d] += cat_size(as[i], d) if outdims[d] == 0 # unfixed dimension blockcount += 1 - if blockcount == (d > length(dims) ? 1 : dims[d]) # last expected member of dimension + if blockcount == dims[d] outdims[d] = currentdims[d] currentdims[d] = 0 blockcount = 0 @@ -2378,7 +2378,13 @@ function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...) return _typed_hvncat_1d(T, shape[1][1], Val(row_first), xs...) end -function _typed_hvncat(::Type{T}, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {T, N} +function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N} + # function barrier after calculating the max is necessary for high performance + nd = max(maximum(cat_ndims(a) for a ∈ as), N) + return _typed_hvncat_shape(T, (shape..., ntuple(x -> shape[end], nd - N)...), row_first, as) +end + +function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::Tuple) where {T, N} length(as) > 0 || throw(ArgumentError("must have at least one element")) all(>(0), tuple((shape...)...)) || @@ -2386,6 +2392,7 @@ function _typed_hvncat(::Type{T}, shape::NTuple{N, Tuple}, row_first::Bool, as.. d1 = row_first ? 2 : 1 d2 = row_first ? 1 : 2 + shapev = collect(shape) # saves allocations later all(!isempty, shapev) || throw(ArgumentError("each level of `shape` argument must have at least one value")) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 05f93805953dd..b2b53d33db8ec 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1390,6 +1390,25 @@ using Base: typed_hvncat @test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2) end + # output dimensions are maximum of input dimensions and concatenation dimension + begin + v1 = fill(1, 1, 1) + v2 = fill(1, 1, 1, 1, 1) + v3 = fill(1, 1, 2, 1, 1) + @test [v1 ;;; v2] == [1 ;;; 1 ;;;;] + @test [v2 ;;; v1] == [1 ;;; 1 ;;;;] + @test [v3 ;;; v1 v1] == [1 1 ;;; 1 1 ;;;;] + @test [v1 v1 ;;; v3] == [1 1 ;;; 1 1 ;;;;] + @test [v2 v1 ;;; v1 v1] == [1 1 ;;; 1 1 ;;;;] + @test [v1 v1 ;;; v1 v2] == [1 1 ;;; 1 1 ;;;;] + @test [v2 ;;; 1] == [1 ;;; 1 ;;;;] + @test [1 ;;; v2] == [1 ;;; 1 ;;;;] + @test [v3 ;;; 1 v1] == [1 1 ;;; 1 1 ;;;;] + @test [v1 1 ;;; v3] == [1 1 ;;; 1 1 ;;;;] + @test [v2 1 ;;; v1 v1] == [1 1 ;;; 1 1 ;;;;] + @test [v1 1 ;;; v1 v2] == [1 1 ;;; 1 1 ;;;;] + end + # dims form for v ∈ ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1])) # reject dimension < 0