Skip to content

Commit

Permalink
Merge #1454
Browse files Browse the repository at this point in the history
1454: Fix and add test for 1453 r=charleskawczynski a=charleskawczynski

This PR fixes and adds a test for #1453.

Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
bors[bot] and charleskawczynski authored Sep 16, 2023
2 parents 334af1f + e790602 commit 7d10e0f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,23 @@ rmap(fn::F, X::NamedTuple{names}) where {F, names} =

rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y) where {F} = ()
rmap(fn::F, X, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple) where {F} =
(rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...)

rmap(fn::F, X, Y::Tuple) where {F} =
(rmap(fn, X, first(Y)), rmap(fn, X, Base.tail(Y))...)

rmap(fn::F, X::Tuple, Y) where {F} =
(rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...)

rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))
rmap(fn::F, X::NamedTuple{names}, Y) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Y))
rmap(fn::F, X, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, X, Tuple(Y)))


rmin(X, Y) = rmap(min, X, Y)
Expand Down
14 changes: 14 additions & 0 deletions test/RecursiveApply/recursive_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using JET
using Test

using ClimaCore.RecursiveApply
using ClimaCore.Geometry

@static if @isdefined(var"@test_opt") # v1.7 and higher
@testset "RecursiveApply optimization test" begin
Expand Down Expand Up @@ -84,3 +85,16 @@ end
@inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt)
end
end

@testset "NamedTuples and axis tensors" begin
FT = Float64
nt = (; a = FT(1), b = FT(2))
uv = Geometry.UVVector(FT(1), FT(2))
@test rz = RecursiveApply.rmap(*, nt, uv)
@test typeof(rz) == NamedTuple{(:a, :b), Tuple{UVVector{FT}, UVVector{FT}}}
@test @inferred RecursiveApply.rmap(*, nt, uv)
@test rz.a.u == 1
@test rz.a.v == 2
@test rz.b.u == 1
@test rz.b.v == 4
end

0 comments on commit 7d10e0f

Please sign in to comment.