Skip to content

Commit

Permalink
Fix indexing of DimStacks with nested DimArrays (#892)
Browse files Browse the repository at this point in the history
* indexing fix

* add a test
  • Loading branch information
tiemvanderdeure authored Jan 8, 2025
1 parent de6ca6e commit 052a602
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
17 changes: 7 additions & 10 deletions src/stack/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ for f in (:getindex, :view, :dotview)
@propagate_inbounds function $_dim_f(s::AbstractDimStack)
map(Base.$f, data(s))
end
Base.@assume_effects :foldable @propagate_inbounds function $_dim_f(s::AbstractDimStack{K}, d1::Dimension, ds::Dimension...) where K
Base.@assume_effects :foldable @propagate_inbounds function $_dim_f(
s::AbstractDimStack{K, NT}, d1::Dimension, ds::Dimension...
) where {K, NT <: NamedTuple{K, T}} where T
D = (d1, ds...)
extradims = otherdims(D, dims(s))
length(extradims) > 0 && Dimensions._extradimswarn(extradims)
Expand All @@ -130,26 +132,21 @@ for f in (:getindex, :view, :dotview)
end
newlayers = unrolled_map(f, values(s))
# Decide to rewrap as an AbstractDimStack, or return a scalar
if _any_dimarray(newlayers)
if newlayers isa T
# All scalars, return as-is
NamedTuple{K}(newlayers)
else
# TODO rethink this for many-layered stacks
# Some scalars, re-wrap them as zero dimensional arrays
non_scalar_layers = unrolled_map(values(s), newlayers) do l, nl
nl isa AbstractDimArray ? nl : rebuild(l, fill(nl), ())
end
rebuildsliced(Base.$f, s, NamedTuple{K}(non_scalar_layers), (dims2indices(dims(s), D)))
else
# All scalars, return as-is
NamedTuple{K}(newlayers)
end
end
end
end

@generated function _any_dimarray(v::Union{NamedTuple,Tuple})
any(T -> T <: AbstractDimArray, v.types)
end



#### setindex ####
@propagate_inbounds Base.setindex!(s::AbstractDimStack, xs, I...; kw...) =
Expand Down
10 changes: 10 additions & 0 deletions test/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,13 @@ end
@test rand(mixed) isa @NamedTuple{one::Float64, two::Float32, extradim::Float64}
@test rand(MersenneTwister(), mixed) isa @NamedTuple{one::Float64, two::Float32, extradim::Float64}
end

# https://github.com/rafaqz/DimensionalData.jl/issues/891
@testset "DimStack of nested DimArrays" begin
nested_da = DimArray([da1, da2], Z(1:2))
ds = DimStack((a = nested_da, b = nested_da))
@test ds[1] == (a = da1, b = da1)
@test ds[Z = 1] == (a = da1, b = da1)
@test ds[Z = 1:2] == ds

end

0 comments on commit 052a602

Please sign in to comment.