Skip to content

Commit

Permalink
Add insertdims method which is inverse to dropdims (JuliaLang#45793)
Browse files Browse the repository at this point in the history
Example:
```julia
julia> a = [1 2; 3 4]
2×2 Matrix{Int64}:
 1  2
 3  4

julia> b = insertdims(a, dims=(1,4))
1×2×2×1 Array{Int64, 4}:
[:, :, 1, 1] =
 1  3

[:, :, 2, 1] =
 2  4

julia> b[1,1,1,1] = 5; a
2×2 Matrix{Int64}:
 5  2
 3  4

julia> b = insertdims(a, dims=(1,2))
1×1×2×2 Array{Int64, 4}:
[:, :, 1, 1] =
 5

[:, :, 2, 1] =
 3

[:, :, 1, 2] =
 2

[:, :, 2, 2] =
 4

julia> b = insertdims(a, dims=(1,3))
1×2×1×2 Array{Int64, 4}:
[:, :, 1, 1] =
 5  3

[:, :, 1, 2] =
 2  4
```

---------

Co-authored-by: Neven Sajko <[email protected]>
Co-authored-by: Lilith Orion Hafner <[email protected]>
Co-authored-by: Mark Kittisopikul <[email protected]>
  • Loading branch information
4 people authored and lazarusA committed Aug 17, 2024
1 parent f5b2601 commit 58ff074
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 0 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ New library functions
* The new `isfull(c::Channel)` function can be used to check if `put!(c, some_value)` will block. ([#53159])
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`

New library features
--------------------
Expand Down
64 changes: 64 additions & 0 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,70 @@ function _dropdims(A::AbstractArray, dims::Dims)
end
_dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))


"""
insertdims(A; dims)
Inverse of [`dropdims`](@ref); return an array with new singleton dimensions
at every dimension in `dims`.
Repeated dimensions are forbidden and the largest entry in `dims` must be
less than or equal than `ndims(A) + length(dims)`.
The result shares the same underlying data as `A`, such that the
result is mutable if and only if `A` is mutable, and setting elements of one
alters the values of the other.
See also: [`dropdims`](@ref), [`reshape`](@ref), [`vec`](@ref).
# Examples
```jldoctest
julia> x = [1 2 3; 4 5 6]
2×3 Matrix{Int64}:
1 2 3
4 5 6
julia> insertdims(x, dims=3)
2×3×1 Array{Int64, 3}:
[:, :, 1] =
1 2 3
4 5 6
julia> insertdims(x, dims=(1,2,5)) == reshape(x, 1, 1, 2, 3, 1)
true
julia> dropdims(insertdims(x, dims=(1,2,5)), dims=(1,2,5))
2×3 Matrix{Int64}:
1 2 3
4 5 6
```
!!! compat "Julia 1.12"
Requires Julia 1.12 or later.
"""
insertdims(A; dims) = _insertdims(A, dims)
function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M}
for i in eachindex(dims)
1 dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
dims[i] N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
for j = 1:i-1
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
end
end

# acc is a tuple, where the first entry is the final shape
# the second entry off acc is a counter for the axes of A
inds= Base._foldoneto((acc, i) ->
i dims
? ((acc[1]..., Base.OneTo(1)), acc[2])
: ((acc[1]..., axes(A, acc[2])), acc[2] + 1),
((), 1), Val(N+M))
new_shape = inds[1]
return reshape(A, new_shape)
end
_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))



## Unary operators ##

"""
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ export
indexin,
argmax,
argmin,
insertdims,
invperm,
invpermute!,
isassigned,
Expand Down
1 change: 1 addition & 0 deletions doc/src/base/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ Base.parentindices
Base.selectdim
Base.reinterpret
Base.reshape
Base.insertdims
Base.dropdims
Base.vec
Base.SubArray
Expand Down
29 changes: 29 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,35 @@ end
@test_throws ArgumentError dropdims(a, dims=4)
@test_throws ArgumentError dropdims(a, dims=6)


a = rand(8, 7)
@test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7))
@test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7))
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7))
@test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1))

@test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3))
@test_throws UndefKeywordError insertdims(a)
@test_throws ArgumentError insertdims(a, dims=0)
@test_throws ArgumentError insertdims(a, dims=(1, 2, 1))
@test_throws ArgumentError insertdims(a, dims=4)
@test_throws ArgumentError insertdims(a, dims=6)

# insertdims and dropdims are inverses
b = rand(1,1,1,5,1,1,7)
for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)]
@test dropdims(insertdims(a; dims); dims) == a
@test insertdims(dropdims(b; dims); dims) == b
end

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
@test A[2:6] == [2:6;]
Expand Down

0 comments on commit 58ff074

Please sign in to comment.