Skip to content

Commit

Permalink
Simplify BitArray concatenation
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Nov 23, 2016
1 parent a261a76 commit 318cf38
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 72 deletions.
11 changes: 4 additions & 7 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1138,17 +1138,15 @@ cat(dims, X...) = cat_t(dims, promote_eltype(X...), X...)
function cat_t(dims, T::Type, X...)
sifter = dims2sift(dims)
shape = cat_shape(sifter, (), map(cat_size, X)...)
return _cat(T, shape, sifter, X...)
end

function _cat(T::Type, shape, sifter, X...)
N = length(shape)
A = cat_similar(X[1], T, shape)

if countnz(sifter) > 1 && T <: Number
fill!(A, zero(T))
end
return _cat(A, shape, sifter, X...)
end

function _cat(A, shape, sifter, X...)
N = length(shape)
offsets = zeros(Int, N)
inds = Vector{UnitRange{Int}}(N)
concat = copy!(zeros(Bool, N), sifter)
Expand All @@ -1163,7 +1161,6 @@ function _cat(T::Type, shape, sifter, X...)
end
A[inds...] = x
end

return A
end

Expand Down
2 changes: 2 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ function vcat{T}(arrays::Vector{T}...)
return arr
end

cat(n::Integer, x::Integer...) = reshape([x...], (ntuple(x->1, n-1)..., length(x)))


## find ##

Expand Down
70 changes: 5 additions & 65 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2215,72 +2215,12 @@ function vcat(A::BitMatrix...)
return B
end

function cat(catdim::Integer, X::Integer...)
reshape([X...], (ones(Int,catdim-1)..., length(X)))
end

# general case, specialized for BitArrays and Integers
function cat(catdim::Integer, X::Union{BitArray, Integer}...)
nargs = length(X)
# using integers results in conversion to Array{Int}
# (except in the all-Bool case)
has_integer = false
for a in X
if isa(a, Integer)
has_integer = true; break
end
end
dimsX = map((a->isa(a,BitArray) ? size(a) : (1,)), X)
ndimsX = map((a->isa(a,BitArray) ? ndims(a) : 1), X)
d_max = maximum(ndimsX)

if catdim > d_max + 1
for i = 1:nargs
dimsX[1] == dimsX[i] ||
throw(DimensionMismatch("all inputs must have same dimensions when concatenating along a higher dimension"))
end
elseif nargs >= 2
for d = 1:d_max
d == catdim && continue
len = d <= ndimsX[1] ? dimsX[1][d] : 1
for i = 2:nargs
len == (d <= ndimsX[i] ? dimsX[i][d] : 1) || throw(DimensionMismatch("mismatch in dimension $d"))
end
end
end

cat_ranges = ntuple(i->(catdim <= ndimsX[i] ? dimsX[i][catdim] : 1), nargs)

function compute_dims(d)
if d == catdim
catdim <= d_max && return sum(cat_ranges)
return nargs
else
d <= ndimsX[1] && return dimsX[1][d]
return 1
end
end

ndimsC = max(catdim, d_max)
dimsC = ntuple(compute_dims, ndimsC)::Tuple{Vararg{Int}}
typeC = promote_type(map(x->isa(x,BitArray) ? eltype(x) : typeof(x), X)...)
if !has_integer || typeC == Bool
C = BitArray(dimsC)
else
C = Array{typeC}(dimsC)
end

range = 1
for k = 1:nargs
nextrange = range + cat_ranges[k]
cat_one = ntuple(i->(i != catdim ? (1:dimsC[i]) : (range:nextrange-1)),
ndimsC)
# note: when C and X are BitArrays, this calls
# the special assign with ranges
C[cat_one...] = X[k]
range = nextrange
end
return C
function cat(dims::Integer, X::Union{BitArray, Bool}...)
sifter = dims2sift(dims)
shape = cat_shape(sifter, (), map(cat_size, X)...)
A = falses(shape)
return _cat(A, shape, sifter, X...)
end

# hvcat -> use fallbacks in abstractarray.jl
Expand Down

0 comments on commit 318cf38

Please sign in to comment.