Skip to content

Commit

Permalink
Improve shape inference in cat
Browse files Browse the repository at this point in the history
`cat` is frequently called with poor inference, since one only has
to concatenate a couple of different container types before inference
punts on the result type. While this does not make the return type
in mixed container types fully inferrable, it does improve the analysis
of the shape.
  • Loading branch information
timholy committed Sep 1, 2020
1 parent 0e082ee commit ceab9ea
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1437,15 +1437,15 @@ vcat(V::AbstractVector{T}...) where {T} = typed_vcat(T, V...)
AbstractVecOrTuple{T} = Union{AbstractVector{<:T}, Tuple{Vararg{T}}}

function _typed_vcat(::Type{T}, V::AbstractVecOrTuple{AbstractVector}) where T
n::Int = 0
n = 0
for Vk in V
n += length(Vk)
n += Int(length(Vk))::Int
end
a = similar(V[1], T, n)
pos = 1
for k=1:length(V)
for k=1:Int(length(V))::Int
Vk = V[k]
p1 = pos+length(Vk)-1
p1 = pos + Int(length(Vk))::Int - 1
a[pos:p1] = Vk
pos = p1+1
end
Expand Down Expand Up @@ -1507,7 +1507,7 @@ function _typed_vcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
pos = 1
for k=1:nargs
Ak = A[k]
p1 = pos+size(Ak,1)-1
p1 = pos+size(Ak,1)::Int-1
B[pos:p1, :] = Ak
pos = p1+1
end
Expand Down Expand Up @@ -1585,17 +1585,18 @@ end
_cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)

@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
@inline function _cat_t(dims, T::Type, X...)
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
catdims = dims2cat(dims)
shape = cat_shape(catdims, map(cat_size, X))
shape = cat_shape(catdims, map(cat_size, X)::Tuple{Vararg{Union{Int,Dims}}})::Dims
A = cat_similar(X[1], T, shape)
if count(!iszero, catdims) > 1
if count(!iszero, catdims)::Int > 1
fill!(A, zero(T))
end
return __cat(A, shape, catdims, X...)
end

function __cat(A, shape::NTuple{N}, catdims, X...) where N
function __cat(A, shape::NTuple{M,Int}, catdims, X...) where M
N = M::Int
offsets = zeros(Int, N)
inds = Vector{UnitRange{Int}}(undef, N)
concat = copyto!(zeros(Bool, N), catdims)
Expand Down Expand Up @@ -1702,8 +1703,8 @@ julia> hcat(x, [1; 2; 3])
"""
hcat(X...) = cat(X...; dims=Val(2))

typed_vcat(T::Type, X...) = cat_t(T, X...; dims=Val(1))
typed_hcat(T::Type, X...) = cat_t(T, X...; dims=Val(2))
typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1))
typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2))

"""
cat(A...; dims=dims)
Expand Down

0 comments on commit ceab9ea

Please sign in to comment.