diff --git a/NEWS.md b/NEWS.md index b842012bfc33b..c8f7189060cf8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -70,6 +70,7 @@ New library functions * The new `isfull(c::Channel)` function can be used to check if `put!(c, some_value)` will block. ([#53159]) * `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]). * `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]). +* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims` New library features -------------------- diff --git a/base/abstractarraymath.jl b/base/abstractarraymath.jl index a9efc2b87bee4..0f028a0f66729 100644 --- a/base/abstractarraymath.jl +++ b/base/abstractarraymath.jl @@ -93,6 +93,70 @@ function _dropdims(A::AbstractArray, dims::Dims) end _dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),)) + +""" + insertdims(A; dims) + +Inverse of [`dropdims`](@ref); return an array with new singleton dimensions +at every dimension in `dims`. + +Repeated dimensions are forbidden and the largest entry in `dims` must be +less than or equal than `ndims(A) + length(dims)`. + +The result shares the same underlying data as `A`, such that the +result is mutable if and only if `A` is mutable, and setting elements of one +alters the values of the other. + +See also: [`dropdims`](@ref), [`reshape`](@ref), [`vec`](@ref). +# Examples +```jldoctest +julia> x = [1 2 3; 4 5 6] +2×3 Matrix{Int64}: + 1 2 3 + 4 5 6 + +julia> insertdims(x, dims=3) +2×3×1 Array{Int64, 3}: +[:, :, 1] = + 1 2 3 + 4 5 6 + +julia> insertdims(x, dims=(1,2,5)) == reshape(x, 1, 1, 2, 3, 1) +true + +julia> dropdims(insertdims(x, dims=(1,2,5)), dims=(1,2,5)) +2×3 Matrix{Int64}: + 1 2 3 + 4 5 6 +``` + +!!! compat "Julia 1.12" + Requires Julia 1.12 or later. +""" +insertdims(A; dims) = _insertdims(A, dims) +function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M} + for i in eachindex(dims) + 1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1.")) + dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added")) + for j = 1:i-1 + dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique")) + end + end + + # acc is a tuple, where the first entry is the final shape + # the second entry off acc is a counter for the axes of A + inds= Base._foldoneto((acc, i) -> + i ∈ dims + ? ((acc[1]..., Base.OneTo(1)), acc[2]) + : ((acc[1]..., axes(A, acc[2])), acc[2] + 1), + ((), 1), Val(N+M)) + new_shape = inds[1] + return reshape(A, new_shape) +end +_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),)) + + + ## Unary operators ## """ diff --git a/base/exports.jl b/base/exports.jl index 1f0ccdf6b8c36..dbe12f933e597 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -407,6 +407,7 @@ export indexin, argmax, argmin, + insertdims, invperm, invpermute!, isassigned, diff --git a/doc/src/base/arrays.md b/doc/src/base/arrays.md index ccfb23715c476..66fe5c78f1ee6 100644 --- a/doc/src/base/arrays.md +++ b/doc/src/base/arrays.md @@ -138,6 +138,7 @@ Base.parentindices Base.selectdim Base.reinterpret Base.reshape +Base.insertdims Base.dropdims Base.vec Base.SubArray diff --git a/test/arrayops.jl b/test/arrayops.jl index f4bb2dc7372f8..1b81a3e315727 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -308,6 +308,35 @@ end @test_throws ArgumentError dropdims(a, dims=4) @test_throws ArgumentError dropdims(a, dims=6) + + a = rand(8, 7) + @test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7)) + @test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7)) + @test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7)) + @test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1)) + @test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1)) + + @test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3)) + @test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3)) + @test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3)) + @test_throws UndefKeywordError insertdims(a) + @test_throws ArgumentError insertdims(a, dims=0) + @test_throws ArgumentError insertdims(a, dims=(1, 2, 1)) + @test_throws ArgumentError insertdims(a, dims=4) + @test_throws ArgumentError insertdims(a, dims=6) + + # insertdims and dropdims are inverses + b = rand(1,1,1,5,1,1,7) + for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)] + @test dropdims(insertdims(a; dims); dims) == a + @test insertdims(dropdims(b; dims); dims) == b + end + sz = (5,8,7) A = reshape(1:prod(sz),sz...) @test A[2:6] == [2:6;]