From bb4d8b3b83188fa053434376286c786245ba5b2b Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Tue, 16 Jul 2024 13:12:39 -0400 Subject: [PATCH] Define rcompare and rprint_diff for FieldVectors --- src/Fields/fieldvector.jl | 77 +++++++++++++++++++ test/Fields/unit_field.jl | 15 ++++ .../utils_field_multi_broadcast_fusion.jl | 27 +------ 3 files changed, 96 insertions(+), 23 deletions(-) diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index f8b1966578..d9853e5a65 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -361,3 +361,80 @@ import ClimaComms ClimaComms.array_type(x::FieldVector) = promote_type( UnrolledFunctions.unrolled_map(ClimaComms.array_type, _values(x))..., ) + +function __rprint_diff( + io::IO, + x::T, + y::T; + pc, + xname, + yname, +) where {T <: Union{FieldVector, Field, DataLayouts.AbstractData, NamedTuple}} + for pn in propertynames(x) + pc_full = (pc..., ".", pn) + xi = getproperty(x, pn) + yi = getproperty(y, pn) + __rprint_diff(io, xi, yi; pc = pc_full, xname, yname) + end +end; + +function __rprint_diff(io::IO, xi, yi; pc, xname, yname) # assume we can compute difference here + if !(xi == yi) + xs = xname * string(join(pc)) + ys = yname * string(join(pc)) + println(io, "==================== Difference found:") + println(io, "$xs: ", xi) + println(io, "$ys: ", yi) + println(io, "($xs .- $ys): ", (xi .- yi)) + end + return nothing +end + +""" + rprint_diff(io::IO, ::T, ::T) where {T <: FieldVector} + rprint_diff(::T, ::T) where {T <: FieldVector} + +Recursively print differences in given `FieldVector`. +""" +_rprint_diff(io::IO, x::T, y::T, xname, yname) where {T <: FieldVector} = + __rprint_diff(io, x, y; pc = (), xname, yname) +_rprint_diff(x::T, y::T, xname, yname) where {T <: FieldVector} = + _rprint_diff(stdout, x, y, xname, yname) + +""" + @rprint_diff(::T, ::T) where {T <: FieldVector} + +Recursively print differences in given `FieldVector`. +""" +macro rprint_diff(x, y) + return :(_rprint_diff( + stdout, + $(esc(x)), + $(esc(y)), + $(string(x)), + $(string(y)), + )) +end + + +# Recursively compare contents of similar fieldvectors +_rcompare(pass, x::T, y::T) where {T <: Field} = + pass && _rcompare(pass, field_values(x), field_values(y)) +_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} = + pass && (parent(x) == parent(y)) +_rcompare(pass, x::T, y::T) where {T} = pass && (x == y) + +function _rcompare(pass, x::T, y::T) where {T <: FieldVector} + for pn in propertynames(x) + pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn)) + end + return pass +end + +""" + rcompare(x::T, y::T) where {T <: FieldVector} + +Recursively compare given fieldvectors via `==`. +Returns `true` if `x == y` recursively. +""" +rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y) diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index 6cf7a599c9..475dd583ef 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -305,6 +305,21 @@ end Y.k.z = 3.0 @test Y.k.z === 3.0 + + @test Fields.rcompare(Y, Y) + Ydc = deepcopy(Y) + Ydc.k.z += 1 + @test !Fields.rcompare(Ydc, Y) + # Fields.@rprint_diff(Ydc, Y) + s = sprint( + Fields._rprint_diff, + Ydc, + Y, + "Ydc", + "Y"; + context = IOContext(stdout), + ) + @test occursin("==================== Difference found:", s) end # https://github.com/CliMA/ClimaCore.jl/issues/1465 diff --git a/test/Fields/utils_field_multi_broadcast_fusion.jl b/test/Fields/utils_field_multi_broadcast_fusion.jl index 58b55f99ae..8238da96b2 100644 --- a/test/Fields/utils_field_multi_broadcast_fusion.jl +++ b/test/Fields/utils_field_multi_broadcast_fusion.jl @@ -83,27 +83,6 @@ function benchmark_kernel!(f!, X, Y, device) show(stdout, MIME("text/plain"), trial) end -function show_diff(A, B) - for pn in propertynames(A) - Ai = getproperty(A, pn) - Bi = getproperty(B, pn) - println("==================== Comparing $pn") - @show Ai - @show Bi - @show abs.(Ai .- Bi) - end -end - -function compare(A, B) - pass = true - for pn in propertynames(A) - pass = - pass && - all(parent(getproperty(A, pn)) .== parent(getproperty(B, pn))) - end - pass || show_diff(A, B) - return pass -end function test_kernel!(; fused!, unfused!, X, Y) for pn in propertynames(X) rand_field!(getproperty(X, pn)) @@ -122,8 +101,10 @@ function test_kernel!(; fused!, unfused!, X, Y) unfused!(X_unfused, Y_unfused) fused!(X_fused, Y_fused) @testset "Test correctness of $(nameof(typeof(fused!)))" begin - @test compare(X_fused, X_unfused) - @test compare(Y_fused, Y_unfused) + Fields.@rprint_diff(X_fused, X_unfused) + Fields.@rprint_diff(Y_fused, Y_unfused) + @test Fields.rcompare(X_fused, X_unfused) + @test Fields.rcompare(Y_fused, Y_unfused) end end