Skip to content

Commit

Permalink
make cat(As..., dims=Val((1,2,...)) work (JuliaLang#44211)
Browse files Browse the repository at this point in the history
  • Loading branch information
thchr authored and staticfloat committed Mar 2, 2022
1 parent eed15e9 commit de786c6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
10 changes: 6 additions & 4 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1712,13 +1712,15 @@ end
_cs(d, a, b) = (a == b ? a : throw(DimensionMismatch(
"mismatch in dimension $d (expected $a got $b)")))

function dims2cat(::Val{n}) where {n}
n <= 0 && throw(ArgumentError("cat dimension must be a positive integer, but got $n"))
ntuple(i -> (i == n), Val(n))
function dims2cat(::Val{dims}) where dims
if any((0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
end
ntuple(in(dims), maximum(dims))
end

function dims2cat(dims)
if any(dims .<= 0)
if any((0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
end
ntuple(in(dims), maximum(dims))
Expand Down
1 change: 1 addition & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ function test_cat(::Type{TestAbstractArray})
@test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2)
cat3v(As) = cat(As...; dims=Val(3))
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
@test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4)
end

function test_ind2sub(::Type{TestAbstractArray})
Expand Down

0 comments on commit de786c6

Please sign in to comment.