diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 4f277219a3..6479f74bfd 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -51,12 +51,23 @@ rmap(fn::F, X::NamedTuple{names}) where {F, names} = rmap(fn::F, X, Y) where {F} = fn(X, Y) rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () -rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () -rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y) where {F} = () +rmap(fn::F, X, Y::Tuple{}) where {F} = () rmap(fn::F, X::Tuple, Y::Tuple) where {F} = (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) + +rmap(fn::F, X, Y::Tuple) where {F} = + (rmap(fn, X, first(Y)), rmap(fn, X, Base.tail(Y))...) + +rmap(fn::F, X::Tuple, Y) where {F} = + (rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...) + rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) +rmap(fn::F, X::NamedTuple{names}, Y) where {F, names} = + NamedTuple{names}(rmap(fn, Tuple(X), Y)) +rmap(fn::F, X, Y::NamedTuple{names}) where {F, names} = + NamedTuple{names}(rmap(fn, X, Tuple(Y))) rmin(X, Y) = rmap(min, X, Y) diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index e77c824c73..643c2a972c 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -2,6 +2,7 @@ using JET using Test using ClimaCore.RecursiveApply +using ClimaCore.Geometry @static if @isdefined(var"@test_opt") # v1.7 and higher @testset "RecursiveApply optimization test" begin @@ -84,3 +85,16 @@ end @inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt) end end + +@testset "NamedTuples and axis tensors" begin + FT = Float64 + nt = (; a = FT(1), b = FT(2)) + uv = Geometry.UVVector(FT(1), FT(2)) + @test rz = RecursiveApply.rmap(*, nt, uv) + @test typeof(rz) == NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} + @test @inferred RecursiveApply.rmap(*, nt, uv) + @test rz.a.u == 1 + @test rz.a.v == 2 + @test rz.b.u == 1 + @test rz.b.v == 4 +end