Skip to content

Commit

Permalink
Type-stabilizing array concatenations (#19387)
Browse files Browse the repository at this point in the history
* Add a type stable array concatenation method

* Add eltype for Dates.Period

* Simplify BitArray concatenation

* Type stable concatenation of sparse arrays
  • Loading branch information
pabloferz authored and stevengj committed Dec 1, 2016
1 parent 2872a2e commit 684dc9c
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 141 deletions.
149 changes: 84 additions & 65 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1095,55 +1095,76 @@ end

## cat: general case

# General-case `cat`: promote over the element types of array arguments and the
# types of non-array arguments, then delegate to `cat_t`.
function cat(catdims, X...)
    argtypes = map(x -> isa(x, AbstractArray) ? eltype(x) : typeof(x), X)
    return cat_t(catdims, promote_type(argtypes...), X...)
end

function cat_t(catdims, typeC::Type, X...)
catdims = collect(catdims)
nargs = length(X)
ndimsX = Int[isa(a,AbstractArray) ? ndims(a) : 0 for a in X]
ndimsC = max(maximum(ndimsX), maximum(catdims))
catsizes = zeros(Int,(nargs,length(catdims)))
dims2cat = zeros(Int,ndimsC)
for k = 1:length(catdims)
dims2cat[catdims[k]]=k
end

dimsC = Int[d <= ndimsX[1] ? size(X[1],d) : 1 for d=1:ndimsC]
for k = 1:length(catdims)
catsizes[1,k] = dimsC[catdims[k]]
end
for i = 2:nargs
for d = 1:ndimsC
currentdim = (d <= ndimsX[i] ? size(X[i],d) : 1)
if dims2cat[d] != 0
dimsC[d] += currentdim
catsizes[i,dims2cat[d]] = currentdim
elseif dimsC[d] != currentdim
throw(DimensionMismatch(string("mismatch in dimension ",d,
" (expected ",dimsC[d],
" got ",currentdim,")")))
# helper functions

# Size of one concatenation argument. Arrays report their own size; any other
# value is treated as a block of size `(1,)` (size 1 along every dimension `d`).
cat_size(A::AbstractArray) = size(A)
cat_size(A::AbstractArray, d) = size(A, d)
cat_size(::Any) = (1,)
cat_size(::Any, ::Any) = 1

# Index range of one concatenation argument along dimension `d`; non-array
# arguments occupy the single index `OneTo(1)`.
cat_indices(A::AbstractArray, d) = indices(A, d)
cat_indices(::Any, ::Any) = OneTo(1)

# Allocate the destination of a concatenation with element type `T` and the given
# `shape`: modelled on `A` via `similar` when `A` is an array, a plain `Array{T}`
# otherwise.
cat_similar(A::AbstractArray, T, shape) = similar(A, T, shape)
cat_similar(::Any, T, shape) = Array{T}(shape)

# Fold the shapes of all arguments into the shape of the result. With a single
# shape left there is nothing to merge; otherwise the first two shapes are merged
# by `_cshp` and the recursion continues with the remaining ones.
cat_shape(dims, shape::Tuple) = shape
@inline function cat_shape(dims, shape::Tuple, nshape::Tuple, shapes::Tuple...)
    merged = _cshp(dims, (), shape, nshape)
    return cat_shape(dims, merged, shapes...)
end

# _cshp(dims, out, shape, nshape) merges the accumulated `shape` with the shape
# `nshape` of the next argument, one dimension per recursive step.
#   dims   - tuple of Bools, `true` for each dimension that is concatenated
#   out    - dimensions of the result that are already settled
#   shape  - remaining dimensions of the shape accumulated so far
#   nshape - remaining dimensions of the next argument
# Returns the merged shape as a tuple; throws `DimensionMismatch` (via `_cs`) when
# a dimension that is not concatenated differs between the two shapes.

# Everything consumed: the settled dimensions are the answer.
_cshp(::Tuple{}, out, ::Tuple{}, ::Tuple{}) = out
# Only the next argument has dimensions left: they are appended as they are.
_cshp(::Tuple{}, out, ::Tuple{}, nshape) = (out..., nshape...)
# Both shapes are exhausted but `dims` is not: pad with one 1 per remaining entry.
_cshp(dims, out, ::Tuple{}, ::Tuple{}) = (out..., map(b -> 1, dims)...)
@inline function _cshp(dims, out, shape, ::Tuple{})
    # The next argument has no dimensions left, so its size here is 1. A
    # concatenated dimension grows by one; any other dimension must already be 1.
    # Going through `_cs` reports a mismatch here instead of silently keeping
    # `shape[1]`, matching the check done by the exhausted-`dims` method below.
    next = _cs(length(out) + 1, dims[1], shape[1], 1)
    _cshp(tail(dims), (out..., next), tail(shape), ())
end
@inline _cshp(dims, out, ::Tuple{}, nshape) =
    _cshp(tail(dims), (out..., nshape[1]), (), tail(nshape))
@inline function _cshp(::Tuple{}, out, shape, ::Tuple{})
    # Past the last entry of `dims` nothing is concatenated: the accumulated size
    # must equal the implicit size 1 of the next argument (called for the check only).
    _cs(length(out) + 1, false, shape[1], 1)
    _cshp((), (out..., 1), tail(shape), ())
end
@inline function _cshp(::Tuple{}, out, shape, nshape)
    next = _cs(length(out) + 1, false, shape[1], nshape[1])
    _cshp((), (out..., next), tail(shape), tail(nshape))
end
@inline function _cshp(dims, out, shape, nshape)
    next = _cs(length(out) + 1, dims[1], shape[1], nshape[1])
    _cshp(tail(dims), (out..., next), tail(shape), tail(nshape))
end

# Combine the sizes `a` and `b` of dimension `d`: their sum when the dimension is
# concatenated, otherwise `a` after checking that both agree.
function _cs(d, concat, a, b)
    concat && return a + b
    a == b && return a
    throw(DimensionMismatch(string("mismatch in dimension ", d, " (expected ", a, " got ", b, ")")))
end

# Turn a concatenation-dimension specification into a tuple of Bools whose i-th
# entry says whether dimension i is concatenated. `Val{n}` yields an n-tuple with
# only the last entry set; any other `dims` yields a tuple of length
# `maximum(dims)` marking every dimension contained in `dims`.
dims2cat{n}(::Type{Val{n}}) = ntuple(i -> (i == n), Val{n})
function dims2cat(dims)
    ndim = maximum(dims)
    return ntuple(i -> in(i, dims), ndim)
end

# Concatenate `X...` along `dims`, with the element type promoted over all arguments.
function cat(dims, X...)
    T = promote_eltype(X...)
    return cat_t(dims, T, X...)
end

# Concatenate `X...` along `dims` into a freshly allocated array with element type `T`.
# The destination is modelled on the first argument (see `cat_similar`) and then
# filled by `_cat`.
function cat_t(dims, T::Type, X...)
    mask = dims2cat(dims)
    outshape = cat_shape(mask, (), map(cat_size, X)...)
    dest = cat_similar(X[1], T, outshape)
    # With more than one concatenation dimension the arguments only cover part of
    # the destination, so numeric destinations are zeroed before copying.
    needs_zeroing = T <: Number && countnz(mask) > 1
    if needs_zeroing
        fill!(dest, zero(T))
    end
    return _cat(dest, outshape, mask, X...)
end

# Copy every argument of `X` into its slot of the destination `A`.
#   A       - destination array, indexed with one range per dimension
#   shape   - tuple with the size of the result; its length is the number of dimensions
#   catdims - Bools flagging the concatenated dimensions
# Returns `A`.
function _cat(A, shape, catdims, X...)
N = length(shape)
# running start position along each dimension; only advanced for concatenated ones
offsets = zeros(Int, N)
inds = Vector{UnitRange{Int}}(N)
# pad `catdims` with `false` up to N entries
concat = copy!(zeros(Bool, N), catdims)
for x in X
for i = 1:N
if concat[i]
# shift the argument's own indices by what earlier arguments already occupy
inds[i] = offsets[i] + cat_indices(x, i)
offsets[i] += cat_size(x, i)
else
# non-concatenated dimension: the argument spans the full extent
inds[i] = 1:shape[i]
end
end
A[inds...] = x
end

# NOTE(review): the statements from here through `return C` use names this
# function never defines (`C`, `typeC`, `dimsC`, `nargs`, `catsizes`, `dims2cat`
# as an array). They look like removed-side lines of the diff rendered inline
# (the tail of the previous `cat_t`); confirm before treating them as part of `_cat`.
C = similar(isa(X[1],AbstractArray) ? X[1] : [X[1]], typeC, tuple(dimsC...))
if length(catdims)>1
fill!(C,0)
end

offsets = zeros(Int,length(catdims))
for i=1:nargs
cat_one = [ dims2cat[d] == 0 ? (1:dimsC[d]) : (offsets[dims2cat[d]]+(1:catsizes[i,dims2cat[d]]))
for d=1:ndimsC ]
C[cat_one...] = X[i]
for k = 1:length(catdims)
offsets[k] += catsizes[i,k]
end
end
return C
return A
end

"""
Expand Down Expand Up @@ -1179,7 +1200,7 @@ julia> vcat(c...)
4 5 6
```
"""
vcat(X...) = cat(1, X...)
vcat(X...) = cat(Val{1}, X...)
"""
hcat(A...)
Expand Down Expand Up @@ -1220,30 +1241,28 @@ julia> hcat(c...)
3 6
```
"""
hcat(X...) = cat(2, X...)
hcat(X...) = cat(Val{2}, X...)

typed_vcat(T::Type, X...) = cat_t(1, T, X...)
typed_hcat(T::Type, X...) = cat_t(2, T, X...)
typed_vcat(T::Type, X...) = cat_t(Val{1}, T, X...)
typed_hcat(T::Type, X...) = cat_t(Val{2}, T, X...)

cat{T}(catdims, A::AbstractArray{T}...) = cat_t(catdims, T, A...)

cat(catdims, A::AbstractArray...) = cat_t(catdims, promote_eltype(A...), A...)

# Array-only methods of vcat/hcat and their typed variants. Each one forwards to
# `cat` (or to `cat_t` with the explicit element type `T`) with a fixed
# concatenation dimension: 1 for the vcat family, 2 for the hcat family.
# NOTE(review): every signature below appears twice, first passing the dimension
# as an integer and then as `Val{1}` / `Val{2}`; this looks like both sides of the
# diff rendered together, in which case only the `Val` forms are current — confirm.
# The specializations for 1 and 2 inputs are important
# especially when running with --inline=no, see #11158
vcat(A::AbstractArray) = cat(1, A)
vcat(A::AbstractArray, B::AbstractArray) = cat(1, A, B)
vcat(A::AbstractArray...) = cat(1, A...)
hcat(A::AbstractArray) = cat(2, A)
hcat(A::AbstractArray, B::AbstractArray) = cat(2, A, B)
hcat(A::AbstractArray...) = cat(2, A...)

typed_vcat(T::Type, A::AbstractArray) = cat_t(1, T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(1, T, A, B)
typed_vcat(T::Type, A::AbstractArray...) = cat_t(1, T, A...)
typed_hcat(T::Type, A::AbstractArray) = cat_t(2, T, A)
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(2, T, A, B)
typed_hcat(T::Type, A::AbstractArray...) = cat_t(2, T, A...)
vcat(A::AbstractArray) = cat(Val{1}, A)
vcat(A::AbstractArray, B::AbstractArray) = cat(Val{1}, A, B)
vcat(A::AbstractArray...) = cat(Val{1}, A...)
hcat(A::AbstractArray) = cat(Val{2}, A)
hcat(A::AbstractArray, B::AbstractArray) = cat(Val{2}, A, B)
hcat(A::AbstractArray...) = cat(Val{2}, A...)

typed_vcat(T::Type, A::AbstractArray) = cat_t(Val{1}, T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val{1}, T, A, B)
typed_vcat(T::Type, A::AbstractArray...) = cat_t(Val{1}, T, A...)
typed_hcat(T::Type, A::AbstractArray) = cat_t(Val{2}, T, A)
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val{2}, T, A, B)
typed_hcat(T::Type, A::AbstractArray...) = cat_t(Val{2}, T, A...)

# 2d horizontal and vertical concatenation

Expand Down
2 changes: 2 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ function vcat{T}(arrays::Vector{T}...)
return arr
end

# Concatenate integers along dimension `n`: the result has size 1 in the first
# `n-1` dimensions and `length(x)` in dimension `n`.
function cat(n::Integer, x::Integer...)
    leading = ntuple(i -> 1, n - 1)
    return reshape([x...], (leading..., length(x)))
end


## find ##

Expand Down
70 changes: 5 additions & 65 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2215,72 +2215,12 @@ function vcat(A::BitMatrix...)
return B
end

# Concatenate integers along dimension `catdim`; all leading dimensions have size 1.
function cat(catdim::Integer, X::Integer...)
    leading = ones(Int, catdim - 1)
    return reshape([X...], (leading..., length(X)))
end

# general case, specialized for BitArrays and Integers
function cat(catdim::Integer, X::Union{BitArray, Integer}...)
nargs = length(X)
# using integers results in conversion to Array{Int}
# (except in the all-Bool case)
has_integer = false
for a in X
if isa(a, Integer)
has_integer = true; break
end
end
dimsX = map((a->isa(a,BitArray) ? size(a) : (1,)), X)
ndimsX = map((a->isa(a,BitArray) ? ndims(a) : 1), X)
d_max = maximum(ndimsX)

if catdim > d_max + 1
for i = 1:nargs
dimsX[1] == dimsX[i] ||
throw(DimensionMismatch("all inputs must have same dimensions when concatenating along a higher dimension"))
end
elseif nargs >= 2
for d = 1:d_max
d == catdim && continue
len = d <= ndimsX[1] ? dimsX[1][d] : 1
for i = 2:nargs
len == (d <= ndimsX[i] ? dimsX[i][d] : 1) || throw(DimensionMismatch("mismatch in dimension $d"))
end
end
end

cat_ranges = ntuple(i->(catdim <= ndimsX[i] ? dimsX[i][catdim] : 1), nargs)

function compute_dims(d)
if d == catdim
catdim <= d_max && return sum(cat_ranges)
return nargs
else
d <= ndimsX[1] && return dimsX[1][d]
return 1
end
end

ndimsC = max(catdim, d_max)
dimsC = ntuple(compute_dims, ndimsC)::Tuple{Vararg{Int}}
typeC = promote_type(map(x->isa(x,BitArray) ? eltype(x) : typeof(x), X)...)
if !has_integer || typeC == Bool
C = BitArray(dimsC)
else
C = Array{typeC}(dimsC)
end

range = 1
for k = 1:nargs
nextrange = range + cat_ranges[k]
cat_one = ntuple(i->(i != catdim ? (1:dimsC[i]) : (range:nextrange-1)),
ndimsC)
# note: when C and X are BitArrays, this calls
# the special assign with ranges
C[cat_one...] = X[k]
range = nextrange
end
return C
function cat(dims::Integer, X::Union{BitArray, Bool}...)
catdims = dims2cat(dims)
shape = cat_shape(catdims, (), map(cat_size, X)...)
A = falses(shape)
return _cat(A, shape, catdims, X...)
end

# hvcat -> use fallbacks in abstractarray.jl
Expand Down
1 change: 1 addition & 0 deletions base/dates/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ Base.typemin(::Union{DateTime,Type{DateTime}}) = DateTime(-146138511,1,1,0,0,0)
Base.typemax(::Union{Date,Type{Date}}) = Date(252522163911149,12,31)
Base.typemin(::Union{Date,Type{Date}}) = Date(-252522163911150,1,1)
# Date-DateTime promotion, isless, ==
# The eltype of a Period type is the type itself.
# NOTE(review): presumably added so that `promote_eltype` handles Period arguments
# during concatenation — confirm against the callers.
Base.eltype{T<:Period}(::Type{T}) = T
Base.promote_rule(::Type{Date},x::Type{DateTime}) = DateTime
Base.isless(x::Date,y::Date) = isless(value(x),value(y))
Base.isless(x::DateTime,y::DateTime) = isless(value(x),value(y))
Expand Down
12 changes: 4 additions & 8 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3219,12 +3219,8 @@ function vcat(X::SparseMatrixCSC...)
end
end

Tv = eltype(X[1].nzval)
Ti = eltype(X[1].rowval)
for i = 2:length(X)
Tv = promote_type(Tv, eltype(X[i].nzval))
Ti = promote_type(Ti, eltype(X[i].rowval))
end
Tv = promote_eltype(X...)
Ti = promote_eltype(map(x->x.rowval, X)...)

nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
Expand Down Expand Up @@ -3276,8 +3272,8 @@ function hcat(X::SparseMatrixCSC...)
end
n = sum(nX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)
Tv = promote_eltype(X...)
Ti = promote_eltype(map(x->x.rowval, X)...)

colptr = Array{Ti}(n + 1)
nnzX = Int[ nnz(x) for x in X ]
Expand Down
6 changes: 3 additions & 3 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -872,16 +872,16 @@ typealias _TypedDenseConcatGroup{T} Union{Vector{T}, Matrix{T}, _Annotated_Typed

# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays
# Generic `cat` for the sparse concatenation group: every argument is converted to
# a SparseMatrixCSC (via `sparse` when it is not sparse already) and the pieces are
# handed to `Base.cat_t` with the element type promoted over the original inputs.
function cat(catdims, Xin::_SparseConcatGroup...)
# NOTE(review): `X` is assigned twice in a row (typed comprehension, then `map`);
# the second assignment overwrites the first. This looks like both sides of the
# diff rendered together — confirm which line is current.
X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin]
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
T = promote_eltype(Xin...)
Base.cat_t(catdims, T, X...)
end
# `hcat` for the sparse concatenation group: convert every argument to a
# SparseMatrixCSC, then dispatch to `hcat` on the converted arguments.
function hcat(Xin::_SparseConcatGroup...)
# NOTE(review): `X` is assigned twice in a row; the `map` form overwrites the
# typed comprehension. Likely old and new diff lines shown together — confirm.
X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin]
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
hcat(X...)
end
# `vcat` for the sparse concatenation group: convert every argument to a
# SparseMatrixCSC, then dispatch to `vcat` on the converted arguments.
function vcat(Xin::_SparseConcatGroup...)
# NOTE(review): `X` is assigned twice in a row; the `map` form overwrites the
# typed comprehension. Likely old and new diff lines shown together — confirm.
X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin]
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
vcat(X...)
end
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
Expand Down
4 changes: 4 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,10 @@ function test_cat(::Type{TestAbstractArray})

# 18395
@test isa(Any["a" 5; 2//3 1.0][2,1], Rational{Int})

# 13665, 19038
@test @inferred(hcat([1.0 2.0], 3))::Array{Float64,2} == [1.0 2.0 3.0]
@test @inferred(vcat([1.0, 2.0], 3))::Array{Float64,1} == [1.0, 2.0, 3.0]
end

function test_ind2sub(::Type{TestAbstractArray})
Expand Down
3 changes: 3 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1659,3 +1659,6 @@ let X = sparse([1 -1; -1 1])
@test Y / 1 == Y
end
end

# 19304
@inferred hcat(sparse(rand(2,1)), eye(2,2))

0 comments on commit 684dc9c

Please sign in to comment.