From de786c60fc81b079367897b058ff38c3c023bf52 Mon Sep 17 00:00:00 2001 From: Thomas Christensen Date: Thu, 17 Feb 2022 16:31:22 -0500 Subject: [PATCH] make `cat(As..., dims=Val((1,2,...))` work (#44211) --- base/abstractarray.jl | 10 ++++++---- test/abstractarray.jl | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 50b83dff86e6b7..9c3cb23865dffb 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1712,13 +1712,15 @@ end _cs(d, a, b) = (a == b ? a : throw(DimensionMismatch( "mismatch in dimension $d (expected $a got $b)"))) -function dims2cat(::Val{n}) where {n} - n <= 0 && throw(ArgumentError("cat dimension must be a positive integer, but got $n")) - ntuple(i -> (i == n), Val(n)) +function dims2cat(::Val{dims}) where dims + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + ntuple(in(dims), maximum(dims)) end function dims2cat(dims) - if any(dims .<= 0) + if any(≤(0), dims) throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) end ntuple(in(dims), maximum(dims)) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 060f1ffa8b8cbe..d650cf67ebf113 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -732,6 +732,7 @@ function test_cat(::Type{TestAbstractArray}) @test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2) cat3v(As) = cat(As...; dims=Val(3)) @test @inferred(cat3v(As)) == zeros(2, 2, 2) + @test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4) end function test_ind2sub(::Type{TestAbstractArray})