diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 4f277219a3..6479f74bfd 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -51,12 +51,23 @@ rmap(fn::F, X::NamedTuple{names}) where {F, names} = rmap(fn::F, X, Y) where {F} = fn(X, Y) rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () -rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () -rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y) where {F} = () +rmap(fn::F, X, Y::Tuple{}) where {F} = () rmap(fn::F, X::Tuple, Y::Tuple) where {F} = (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) + +rmap(fn::F, X, Y::Tuple) where {F} = + (rmap(fn, X, first(Y)), rmap(fn, X, Base.tail(Y))...) + +rmap(fn::F, X::Tuple, Y) where {F} = + (rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...) + rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) +rmap(fn::F, X::NamedTuple{names}, Y) where {F, names} = + NamedTuple{names}(rmap(fn, Tuple(X), Y)) +rmap(fn::F, X, Y::NamedTuple{names}) where {F, names} = + NamedTuple{names}(rmap(fn, X, Tuple(Y))) rmin(X, Y) = rmap(min, X, Y) diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index 5e5e2e3b91..643c2a972c 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -86,12 +86,15 @@ end end end -@testset "NamedTuples and vectors" begin +@testset "NamedTuples and axis tensors" begin FT = Float64 nt = (; a = FT(1), b = FT(2)) uv = Geometry.UVVector(FT(1), FT(2)) - @test_broken rz = RecursiveApply.rmap(*, nt, uv) - @test_broken typeof(rz) == - NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} - @test_broken @inferred RecursiveApply.rmap(*, nt, uv) + @test rz = RecursiveApply.rmap(*, nt, uv) + @test typeof(rz) == NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} + @test @inferred RecursiveApply.rmap(*, nt, uv) + @test rz.a.u == 1 + @test rz.a.v == 2 + @test rz.b.u == 1 + @test rz.b.v == 4 end