From 8db1833c377da272282372bffe300deed437b2ec Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Tue, 12 Sep 2023 09:20:11 -0700 Subject: [PATCH 1/2] Add broken test for 1453 --- test/RecursiveApply/recursive_apply.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index e77c824c73..5e5e2e3b91 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -2,6 +2,7 @@ using JET using Test using ClimaCore.RecursiveApply +using ClimaCore.Geometry @static if @isdefined(var"@test_opt") # v1.7 and higher @testset "RecursiveApply optimization test" begin @@ -84,3 +85,13 @@ end @inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt) end end + +@testset "NamedTuples and vectors" begin + FT = Float64 + nt = (; a = FT(1), b = FT(2)) + uv = Geometry.UVVector(FT(1), FT(2)) + @test_broken rz = RecursiveApply.rmap(*, nt, uv) + @test_broken typeof(rz) == + NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} + @test_broken @inferred RecursiveApply.rmap(*, nt, uv) +end From e79060296dce46102f7a627a709b6045853ea6e7 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Tue, 12 Sep 2023 09:21:34 -0700 Subject: [PATCH 2/2] Fix RecursiveApply on NamedTuples with AxisTensors --- src/RecursiveApply/RecursiveApply.jl | 15 +++++++++++++-- test/RecursiveApply/recursive_apply.jl | 13 ++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 4f277219a3..6479f74bfd 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -51,12 +51,23 @@ rmap(fn::F, X::NamedTuple{names}) where {F, names} = rmap(fn::F, X, Y) where {F} = fn(X, Y) rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () -rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () -rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y) where {F} = () +rmap(fn::F, X, Y::Tuple{}) where {F} = () rmap(fn::F, X::Tuple, Y::Tuple) where {F} = (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) + +rmap(fn::F, X, Y::Tuple) where {F} = + (rmap(fn, X, first(Y)), rmap(fn, X, Base.tail(Y))...) + +rmap(fn::F, X::Tuple, Y) where {F} = + (rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...) + rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) +rmap(fn::F, X::NamedTuple{names}, Y) where {F, names} = + NamedTuple{names}(rmap(fn, Tuple(X), Y)) +rmap(fn::F, X, Y::NamedTuple{names}) where {F, names} = + NamedTuple{names}(rmap(fn, X, Tuple(Y))) rmin(X, Y) = rmap(min, X, Y) diff --git a/test/RecursiveApply/recursive_apply.jl b/test/RecursiveApply/recursive_apply.jl index 5e5e2e3b91..643c2a972c 100644 --- a/test/RecursiveApply/recursive_apply.jl +++ b/test/RecursiveApply/recursive_apply.jl @@ -86,12 +86,15 @@ end end end -@testset "NamedTuples and vectors" begin +@testset "NamedTuples and axis tensors" begin FT = Float64 nt = (; a = FT(1), b = FT(2)) uv = Geometry.UVVector(FT(1), FT(2)) - @test_broken rz = RecursiveApply.rmap(*, nt, uv) - @test_broken typeof(rz) == - NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} - @test_broken @inferred RecursiveApply.rmap(*, nt, uv) + @test rz = RecursiveApply.rmap(*, nt, uv) + @test typeof(rz) == NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}} + @test @inferred RecursiveApply.rmap(*, nt, uv) + @test rz.a.u == 1 + @test rz.a.v == 2 + @test rz.b.u == 1 + @test rz.b.v == 4 end