Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Base.length methods for ViewAxis and ComponentAxis types #294

Merged
merged 9 commits into from
Feb 10, 2025
6 changes: 6 additions & 0 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ ViewAxis{Inds,IdxMap,Ax}() where {Inds,IdxMap,Ax} = ViewAxis(Inds, Ax())
ViewAxis(Inds, IdxMap) = ViewAxis(Inds, Axis(IdxMap))
ViewAxis(Inds) = Inds

Base.length(ax::ViewAxis{Inds}) where Inds = length(Inds)
# Fix https://github.com/Deltares/Ribasim/issues/2028
Base.getindex(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx::Integer) where {Inds,IdxMap} = Inds[idx]
Base.iterate(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}) where {Inds,IdxMap} = iterate(Inds)
Base.iterate(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx) where {Inds,IdxMap} = iterate(Inds, idx)

const View = ViewAxis
const NullOrFlatView{Inds,IdxMap} = ViewAxis{Inds,IdxMap,<:NullorFlatAxis}

Expand Down
2 changes: 2 additions & 0 deletions src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ const NullComponentIndex{Idx} = ComponentIndex{Idx, NullAxis}

Base.:(==)(ci1::ComponentIndex, ci2::ComponentIndex) = ci1.idx == ci2.idx && ci1.ax == ci2.ax

Base.length(ci::ComponentIndex) = length(ci.idx)


"""
KeepIndex(idx)
Expand Down
51 changes: 51 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ end
x = ComponentArray(b=1, a=2)
@test merge(NamedTuple(), x) == NamedTuple(x)
@test kw_fun(; x...) == 2

@test length(ViewAxis(2:7, ShapedAxis((2,3)))) == 6
end

@testset "Get" begin
Expand Down Expand Up @@ -385,6 +387,12 @@ end
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))
ax2 = getaxes(ca2)[1]
@test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))

@test length(ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())) == 1
@test length(ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))) == 2
@test length(ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))) == 4
@test length(ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))) == 3
@test length(ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))) == 7
end

@testset "KeepIndex" begin
Expand Down Expand Up @@ -843,6 +851,49 @@ end
@test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :])
@test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :])
end

# Test fix https://github.com/Deltares/Ribasim/issues/2028
a = range(0.0, 1.0, length=0) |> collect
b = range(0.0, 1.0; length=2) |> collect
c = range(0.0, 1.0, length=3) |> collect
d = range(0.0, 1.0; length=0) |> collect
u = ComponentVector(a=a, b=b, c=c, d=d)

function get_state_index(
idx::Int,
::ComponentVector{A, B, <:Tuple{<:Axis{NT}}},
component_name::Symbol
) where {A, B, NT}
for (comp, range) in pairs(NT)
if comp == component_name
return range[idx]
end
end
return nothing
end

@test_throws BoundsError get_state_index(1, u, :a)
@test_throws BoundsError get_state_index(2, u, :a)
@test get_state_index(1, u, :b) == 1
@test get_state_index(2, u, :b) == 2
@test get_state_index(1, u, :c) == 3
@test get_state_index(2, u, :c) == 4
@test get_state_index(3, u, :c) == 5
@test_throws BoundsError get_state_index(1, u, :d)
@test_throws BoundsError get_state_index(2, u, :d)

# Must be a better way to make sure we can `Base.iterate` the `ViewAxis{UnitRange, Shaped1DAxis}`.
nt = ComponentArrays.indexmap(getaxes(u)[1])
for (i, idx) in enumerate(nt.a)
end
for (i, idx) in enumerate(nt.b)
@test idx == i
end
for (i, idx) in enumerate(nt.c)
@test idx == i + 2
end
for (i, idx) in enumerate(nt.d)
end
end

@testset "axpy! / axpby!" begin
Expand Down
Loading