diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 4f277219a3..b87340799f 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -46,18 +46,44 @@ rmap(fn::F, X) where {F} = fn(X) rmap(fn::F, X::Tuple{}) where {F} = () rmap(fn::F, X::Tuple) where {F} = (rmap(fn, first(X)), rmap(fn, Base.tail(X))...) -rmap(fn::F, X::NamedTuple{names}) where {F, names} = - NamedTuple{names}(rmap(fn, Tuple(X))) +rmap(fn::F, X::NamedTuple) where {F} = + NamedTuple{nt_names(X)}(rmap(fn, Tuple(X))) rmap(fn::F, X, Y) where {F} = fn(X, Y) rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () -rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () -rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y) where {F} = () +rmap(fn::F, X, Y::Tuple{}) where {F} = () rmap(fn::F, X::Tuple, Y::Tuple) where {F} = (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) -rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = - NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) +rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = + (rmap(fn, first(X)), rmap(fn, Base.tail(X))...) + +rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = + (rmap(fn, first(Y)), rmap(fn, Base.tail(Y))...) +rmap(fn::F, X, Y::Tuple) where {F} = + (rmap(fn, X, first(Y)), rmap(fn, X, Base.tail(Y))...) +rmap(fn::F, X::Tuple, Y) where {F} = + (rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...) + +function rmap(fn::F, X::NamedTuple, Y::NamedTuple) where {F} + @assert nt_names(X) === nt_names(Y) + return NamedTuple{nt_names(X)}(rmap(fn, Tuple(X), Tuple(Y))) +end +rmap(fn::F, X::NamedTuple, Y) where {F} = + NamedTuple{nt_names(X)}(rmap(fn, Tuple(X), Y)) +rmap(fn::F, X::NamedTuple, Y::Tuple) where {F} = + NamedTuple{nt_names(X)}(rmap(fn, Tuple(X), Y)) +rmap(fn::F, X::NamedTuple, Y::Tuple{}) where {F} = + NamedTuple{nt_names(X)}(rmap(fn, Tuple(X))) +rmap(fn::F, X, Y::NamedTuple) where {F} = + NamedTuple{nt_names(Y)}(rmap(fn, X, Tuple(Y))) +rmap(fn::F, X::Tuple, Y::NamedTuple) where {F} = + NamedTuple{nt_names(Y)}(rmap(fn, X, Tuple(Y))) +rmap(fn::F, X::Tuple{}, Y::NamedTuple) where {F} = + NamedTuple{nt_names(Y)}(rmap(fn, Tuple(Y))) + +nt_names(::NamedTuple{names}) where {names} = names rmin(X, Y) = rmap(min, X, Y) rmax(X, Y) = rmap(max, X, Y) diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index e77c824c73..1ccc7cf49f 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -2,6 +2,7 @@ using JET using Test using ClimaCore.RecursiveApply +using ClimaCore.Geometry @static if @isdefined(var"@test_opt") # v1.7 and higher @testset "RecursiveApply optimization test" begin @@ -84,3 +85,16 @@ end @inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt) end end + +@testset "NamedTuples and axis tensors" begin + FT = Float64 + nt = (; a = FT(1), b = FT(2)) + uv = Geometry.UVVector(FT(1), FT(2)) + rz = RecursiveApply.rmap(*, nt, uv) + @test typeof(rz) == NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} + @inferred RecursiveApply.rmap(*, nt, uv) + @test rz.a.u == 1 + @test rz.a.v == 2 + @test rz.b.u == 2 + @test rz.b.v == 4 +end diff --git a/test/aqua.jl b/test/aqua.jl index e97d004427..a62f64c068 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -22,7 +22,7 @@ using Aqua for method_ambiguity in ambs @show method_ambiguity end - @test length(ambs) ≤ 16 + @test length(ambs) ≤ 15 end @testset "Aqua tests (additional)" begin