diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 50b83dff86e6b..9c3cb23865dff 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1712,13 +1712,15 @@ end _cs(d, a, b) = (a == b ? a : throw(DimensionMismatch( "mismatch in dimension $d (expected $a got $b)"))) -function dims2cat(::Val{n}) where {n} - n <= 0 && throw(ArgumentError("cat dimension must be a positive integer, but got $n")) - ntuple(i -> (i == n), Val(n)) +function dims2cat(::Val{dims}) where dims + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + ntuple(in(dims), maximum(dims)) end function dims2cat(dims) - if any(dims .<= 0) + if any(≤(0), dims) throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) end ntuple(in(dims), maximum(dims)) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 060f1ffa8b8cb..d650cf67ebf11 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -732,6 +732,7 @@ function test_cat(::Type{TestAbstractArray}) @test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2) cat3v(As) = cat(As...; dims=Val(3)) @test @inferred(cat3v(As)) == zeros(2, 2, 2) + @test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4) end function test_ind2sub(::Type{TestAbstractArray})