Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve cat inferrability #45028

Merged
merged 1 commit into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1716,23 +1716,16 @@ end
_cs(d, a, b) = (a == b ? a : throw(DimensionMismatch(
"mismatch in dimension $d (expected $a got $b)")))

function dims2cat(::Val{dims}) where dims
if any(≤(0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
end
ntuple(in(dims), maximum(dims))
end

dims2cat(::Val{dims}) where dims = dims2cat(dims)
function dims2cat(dims)
if any(≤(0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
end
ntuple(in(dims), maximum(dims))
end

_cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
_cat(dims, X...) = _cat_t(dims, promote_eltypeof(X...), X...)

@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
catdims = dims2cat(dims)
shape = cat_size_shape(catdims, X...)
Expand All @@ -1742,6 +1735,9 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
end
return __cat(A, shape, catdims, X...)
end
# this version of `cat_t` is not very kind for inference and so its usage should be avoided,
# nevertheless it is here just for compat after https://github.com/JuliaLang/julia/pull/45028
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)

# Why isn't this called `__cat!`?
__cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
Expand Down Expand Up @@ -1880,8 +1876,8 @@ julia> reduce(hcat, vs)
"""
hcat(X...) = cat(X...; dims=Val(2))

typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1))
typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2))
typed_vcat(::Type{T}, X...) where T = _cat_t(Val(1), T, X...)
typed_hcat(::Type{T}, X...) where T = _cat_t(Val(2), T, X...)

"""
cat(A...; dims)
Expand Down Expand Up @@ -1917,7 +1913,8 @@ julia> cat(true, trues(2,2), trues(4)', dims=(1,2))
```
"""
@inline cat(A...; dims) = _cat(dims, A...)
_cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(T, A...; dims=catdims)
# `@constprop :aggressive` allows `catdims` to be propagated as constant improving return type inference
@constprop :aggressive _cat(catdims, A::AbstractArray{T}...) where {T} = _cat_t(catdims, T, A...)

# The specializations for 1 and 2 inputs are important
# especially when running with --inline=no, see #11158
Expand All @@ -1928,12 +1925,12 @@ hcat(A::AbstractArray) = cat(A; dims=Val(2))
hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2))
hcat(A::AbstractArray...) = cat(A...; dims=Val(2))

typed_vcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(1))
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(1))
typed_vcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(1))
typed_hcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(2))
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(2))
typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2))
typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B)
typed_vcat(T::Type, A::AbstractArray...) = _cat_t(Val(1), T, A...)
typed_hcat(T::Type, A::AbstractArray) = _cat_t(Val(2), T, A)
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(2), T, A, B)
typed_hcat(T::Type, A::AbstractArray...) = _cat_t(Val(2), T, A...)

# 2d horizontal and vertical concatenation

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,14 +414,14 @@ const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpo

promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix

Base._cat(dims, xs::_DenseConcatGroup...) = Base.cat_t(promote_eltype(xs...), xs...; dims=dims)
Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...)
vcat(A::Vector...) = Base.typed_vcat(promote_eltype(A...), A...)
vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...)
hcat(A::Vector...) = Base.typed_hcat(promote_eltype(A...), A...)
hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
# For performance, specially handle the case where the matrices/vectors have homogeneous eltype
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs...; dims=dims)
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...)
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)
Expand Down
4 changes: 4 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,10 @@ function test_cat(::Type{TestAbstractArray})
cat3v(As) = cat(As...; dims=Val(3))
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
@test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4)

r = rand(Float32, 56, 56, 64, 1);
f(r) = cat(r, r, dims=(3,))
@inferred f(r);
end

function test_ind2sub(::Type{TestAbstractArray})
Expand Down
15 changes: 3 additions & 12 deletions test/ambiguous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,11 @@ using LinearAlgebra, SparseArrays, SuiteSparse
# not using isempty so this prints more information when it fails
@testset "detect_ambiguities" begin
let ambig = Set{Any}(((m1.sig, m2.sig) for (m1, m2) in detect_ambiguities(Core, Base; recursive=true, ambiguous_bottom=false, allowed_undefineds)))
@test isempty(ambig)
expect = []
good = true
while !isempty(ambig)
sigs = pop!(ambig)
i = findfirst(==(sigs), expect)
if i === nothing
println(stderr, "push!(expect, (", sigs[1], ", ", sigs[2], "))")
good = false
continue
end
deleteat!(expect, i)
for (sig1, sig2) in ambig
@test sig1 === sig2 # print this ambiguity
good = false
end
@test isempty(expect)
@test good
end

Expand Down