Skip to content

Commit

Permalink
Fix RecursiveApply on NamedTuples with AxisTensors
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 12, 2023
1 parent 8db1833 commit e790602
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
15 changes: 13 additions & 2 deletions src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,23 @@ rmap(fn::F, X::NamedTuple{names}) where {F, names} =

rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y) where {F} = ()
rmap(fn::F, X, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple) where {F} =
(rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...)

rmap(fn::F, X, Y::Tuple) where {F} =
(rmap(fn, X, first(Y)), rmap(fn, X, Base.tail(Y))...)

rmap(fn::F, X::Tuple, Y) where {F} =
(rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...)

rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))
rmap(fn::F, X::NamedTuple{names}, Y) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Y))
rmap(fn::F, X, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, X, Tuple(Y)))


rmin(X, Y) = rmap(min, X, Y)
Expand Down
13 changes: 8 additions & 5 deletions test/RecursiveApply/recursive_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,15 @@ end
end
end

@testset "NamedTuples and vectors" begin
@testset "NamedTuples and axis tensors" begin
FT = Float64
nt = (; a = FT(1), b = FT(2))
uv = Geometry.UVVector(FT(1), FT(2))
@test_broken rz = RecursiveApply.rmap(*, nt, uv)
@test_broken typeof(rz) ==
NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}}
@test_broken @inferred RecursiveApply.rmap(*, nt, uv)
@test rz = RecursiveApply.rmap(*, nt, uv)
@test typeof(rz) == NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}}
@test @inferred RecursiveApply.rmap(*, nt, uv)
@test rz.a.u == 1
@test rz.a.v == 2
@test rz.b.u == 1
@test rz.b.v == 4
end

0 comments on commit e790602

Please sign in to comment.