Skip to content

Commit

Permalink
hvncat: Ensure output ndims are >= the ndims of input arrays (#41201)
Browse files Browse the repository at this point in the history
(cherry picked from commit a2f5fe5)
  • Loading branch information
BioTurboNick authored and KristofferC committed Jul 20, 2021
1 parent 87af621 commit 81d8c0c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 60 deletions.
127 changes: 67 additions & 60 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2156,44 +2156,6 @@ _typed_hvncat(::Type, ::Val{0}, ::AbstractArray...) = _typed_hvncat_0d_only_one(
_typed_hvncat_0d_only_one() =
throw(ArgumentError("a 0-dimensional array may only contain exactly one element"))

function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N}
all(>(0), dims) ||
throw(ArgumentError("`dims` argument must contain positive integers"))
A = Array{T, N}(undef, dims...)
lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations
lengthx = length(xs) # Cuts from 3 allocations to 1.
if lengtha != lengthx
throw(ArgumentError("argument count does not match specified shape (expected $lengtha, got $lengthx)"))
end
hvncat_fill!(A, row_first, xs)
return A
end

function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple)
# putting these in separate functions leads to unnecessary allocations
if row_first
nr, nc = size(A, 1), size(A, 2)
nrc = nr * nc
na = prod(size(A)[3:end])
k = 1
for d 1:na
dd = nrc * (d - 1)
for i 1:nr
Ai = dd + i
for j 1:nc
A[Ai] = xs[k]
k += 1
Ai += nr
end
end
end
else
for k eachindex(xs)
A[k] = xs[k]
end
end
end

_typed_hvncat(T::Type, dim::Int, ::Bool, xs...) = _typed_hvncat(T, Val(dim), xs...) # catches from _hvncat type promoters

function _typed_hvncat(::Type{T}, ::Val{N}) where {T, N}
Expand All @@ -2219,20 +2181,18 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
throw(ArgumentError("concatenation dimension must be nonnegative"))
for a as
ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) ||
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...)
return _typed_hvncat(T, (ntuple(x -> 1, Val(N - 1))..., length(as), 1), false, as...)
# the extra 1 is to avoid an infinite cycle
end

nd = max(N, ndims(as[1]))
nd = N

Ndim = 0
for i eachindex(as)
a = as[i]
Ndim += size(a, N)
nd = max(nd, ndims(a))
for d 1:N-1
size(a, d) == size(as[1], d) ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
Ndim += cat_size(as[i], N)
nd = max(nd, cat_ndims(as[i]))
for d 1:N - 1
cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i"))
end
end

Expand All @@ -2255,16 +2215,15 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
nd = N
Ndim = 0
for i eachindex(as)
a = as[i]
Ndim += cat_size(a, N)
nd = max(nd, cat_ndims(a))
Ndim += cat_size(as[i], N)
nd = max(nd, cat_ndims(as[i]))
for d 1:N-1
cat_size(a, d) == 1 ||
cat_size(as[i], d) == 1 ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
end
end

A = Array{T, nd}(undef, ntuple(x -> 1, N - 1)..., Ndim, ntuple(x -> 1, nd - N)...)
A = Array{T, nd}(undef, ntuple(x -> 1, Val(N - 1))..., Ndim, ntuple(x -> 1, nd - N)...)

k = 1
for a as
Expand All @@ -2280,7 +2239,6 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
return A
end


# 0-dimensional cases for balanced and unbalanced hvncat method

_typed_hvncat(T::Type, ::Tuple{}, ::Bool, x...) = _typed_hvncat(T, Val(0), x...)
Expand All @@ -2305,7 +2263,51 @@ function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T,
end
end

function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) where {T, N}
function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N}
all(>(0), dims) ||
throw(ArgumentError("`dims` argument must contain positive integers"))
A = Array{T, N}(undef, dims...)
lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations
lengthx = length(xs) # Cuts from 3 allocations to 1.
if lengtha != lengthx
throw(ArgumentError("argument count does not match specified shape (expected $lengtha, got $lengthx)"))
end
hvncat_fill!(A, row_first, xs)
return A
end

function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple)
# putting these in separate functions leads to unnecessary allocations
if row_first
nr, nc = size(A, 1), size(A, 2)
nrc = nr * nc
na = prod(size(A)[3:end])
k = 1
for d 1:na
dd = nrc * (d - 1)
for i 1:nr
Ai = dd + i
for j 1:nc
A[Ai] = xs[k]
k += 1
Ai += nr
end
end
end
else
for k eachindex(xs)
A[k] = xs[k]
end
end
end

function _typed_hvncat(T::Type, dims::NTuple{N, Int}, row_first::Bool, as...) where {N}
# function barrier after calculating the max is necessary for high performance
nd = max(maximum(cat_ndims(a) for a as), N)
return _typed_hvncat_dims(T, (dims..., ntuple(x -> 1, nd - N)...), row_first, as)
end

function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as::Tuple) where {T, N}
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
all(>(0), dims) ||
Expand All @@ -2314,28 +2316,26 @@ function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...)
d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2

# discover dimensions
nd = max(N, cat_ndims(as[1]))
outdims = zeros(Int, nd)
outdims = zeros(Int, N)

# discover number of rows or columns
for i 1:dims[d1]
outdims[d1] += cat_size(as[i], d1)
end

currentdims = zeros(Int, nd)
currentdims = zeros(Int, N)
blockcount = 0
elementcount = 0
for i eachindex(as)
elementcount += cat_length(as[i])
currentdims[d1] += cat_size(as[i], d1)
if currentdims[d1] == outdims[d1]
currentdims[d1] = 0
for d (d2, 3:nd...)
for d (d2, 3:N...)
currentdims[d] += cat_size(as[i], d)
if outdims[d] == 0 # unfixed dimension
blockcount += 1
if blockcount == (d > length(dims) ? 1 : dims[d]) # last expected member of dimension
if blockcount == dims[d]
outdims[d] = currentdims[d]
currentdims[d] = 0
blockcount = 0
Expand Down Expand Up @@ -2378,14 +2378,21 @@ function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...)
return _typed_hvncat_1d(T, shape[1][1], Val(row_first), xs...)
end

function _typed_hvncat(::Type{T}, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {T, N}
function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N}
# function barrier after calculating the max is necessary for high performance
nd = max(maximum(cat_ndims(a) for a as), N)
return _typed_hvncat_shape(T, (shape..., ntuple(x -> shape[end], nd - N)...), row_first, as)
end

function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::Tuple) where {T, N}
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
all(>(0), tuple((shape...)...)) ||
throw(ArgumentError("`shape` argument must consist of positive integers"))

d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2

shapev = collect(shape) # saves allocations later
all(!isempty, shapev) ||
throw(ArgumentError("each level of `shape` argument must have at least one value"))
Expand Down
19 changes: 19 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,25 @@ using Base: typed_hvncat
@test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2)
end

# output dimensions are maximum of input dimensions and concatenation dimension
begin
v1 = fill(1, 1, 1)
v2 = fill(1, 1, 1, 1, 1)
v3 = fill(1, 1, 2, 1, 1)
@test [v1 ;;; v2] == [1 ;;; 1 ;;;;]
@test [v2 ;;; v1] == [1 ;;; 1 ;;;;]
@test [v3 ;;; v1 v1] == [1 1 ;;; 1 1 ;;;;]
@test [v1 v1 ;;; v3] == [1 1 ;;; 1 1 ;;;;]
@test [v2 v1 ;;; v1 v1] == [1 1 ;;; 1 1 ;;;;]
@test [v1 v1 ;;; v1 v2] == [1 1 ;;; 1 1 ;;;;]
@test [v2 ;;; 1] == [1 ;;; 1 ;;;;]
@test [1 ;;; v2] == [1 ;;; 1 ;;;;]
@test [v3 ;;; 1 v1] == [1 1 ;;; 1 1 ;;;;]
@test [v1 1 ;;; v3] == [1 1 ;;; 1 1 ;;;;]
@test [v2 1 ;;; v1 v1] == [1 1 ;;; 1 1 ;;;;]
@test [v1 1 ;;; v1 v2] == [1 1 ;;; 1 1 ;;;;]
end

# dims form
for v ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1]))
# reject dimension < 0
Expand Down

0 comments on commit 81d8c0c

Please sign in to comment.