Skip to content

Commit

Permalink
Define rcompare and rprint_diff for FieldVectors
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 16, 2024
1 parent 331571c commit bb4d8b3
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 23 deletions.
77 changes: 77 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,80 @@ import ClimaComms
ClimaComms.array_type(x::FieldVector) = promote_type(
UnrolledFunctions.unrolled_map(ClimaComms.array_type, _values(x))...,
)

function __rprint_diff(
io::IO,
x::T,
y::T;
pc,
xname,
yname,
) where {T <: Union{FieldVector, Field, DataLayouts.AbstractData, NamedTuple}}
for pn in propertynames(x)
pc_full = (pc..., ".", pn)
xi = getproperty(x, pn)
yi = getproperty(y, pn)
__rprint_diff(io, xi, yi; pc = pc_full, xname, yname)
end
end;

function __rprint_diff(io::IO, xi, yi; pc, xname, yname) # assume we can compute difference here
if !(xi == yi)
xs = xname * string(join(pc))
ys = yname * string(join(pc))
println(io, "==================== Difference found:")
println(io, "$xs: ", xi)
println(io, "$ys: ", yi)
println(io, "($xs .- $ys): ", (xi .- yi))
end
return nothing
end

"""
rprint_diff(io::IO, ::T, ::T) where {T <: FieldVector}
rprint_diff(::T, ::T) where {T <: FieldVector}
Recursively print differences in given `FieldVector`.
"""
_rprint_diff(io::IO, x::T, y::T, xname, yname) where {T <: FieldVector} =
__rprint_diff(io, x, y; pc = (), xname, yname)
_rprint_diff(x::T, y::T, xname, yname) where {T <: FieldVector} =
_rprint_diff(stdout, x, y, xname, yname)

"""
@rprint_diff(::T, ::T) where {T <: FieldVector}
Recursively print differences in given `FieldVector`.
"""
macro rprint_diff(x, y)
return :(_rprint_diff(
stdout,
$(esc(x)),
$(esc(y)),
$(string(x)),
$(string(y)),
))
end


# Recursively compare contents of similar fieldvectors
_rcompare(pass, x::T, y::T) where {T <: Field} =
pass && _rcompare(pass, field_values(x), field_values(y))
_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
pass && (parent(x) == parent(y))
_rcompare(pass, x::T, y::T) where {T} = pass && (x == y)

function _rcompare(pass, x::T, y::T) where {T <: FieldVector}
for pn in propertynames(x)
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn))
end
return pass
end

"""
rcompare(x::T, y::T) where {T <: FieldVector}
Recursively compare given fieldvectors via `==`.
Returns `true` if `x == y` recursively.
"""
rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y)
15 changes: 15 additions & 0 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,21 @@ end

Y.k.z = 3.0
@test Y.k.z === 3.0

@test Fields.rcompare(Y, Y)
Ydc = deepcopy(Y)
Ydc.k.z += 1
@test !Fields.rcompare(Ydc, Y)
# Fields.@rprint_diff(Ydc, Y)
s = sprint(
Fields._rprint_diff,
Ydc,
Y,
"Ydc",
"Y";
context = IOContext(stdout),
)
@test occursin("==================== Difference found:", s)
end

# https://github.com/CliMA/ClimaCore.jl/issues/1465
Expand Down
27 changes: 4 additions & 23 deletions test/Fields/utils_field_multi_broadcast_fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,6 @@ function benchmark_kernel!(f!, X, Y, device)
show(stdout, MIME("text/plain"), trial)
end

function show_diff(A, B)
for pn in propertynames(A)
Ai = getproperty(A, pn)
Bi = getproperty(B, pn)
println("==================== Comparing $pn")
@show Ai
@show Bi
@show abs.(Ai .- Bi)
end
end

function compare(A, B)
pass = true
for pn in propertynames(A)
pass =
pass &&
all(parent(getproperty(A, pn)) .== parent(getproperty(B, pn)))
end
pass || show_diff(A, B)
return pass
end
function test_kernel!(; fused!, unfused!, X, Y)
for pn in propertynames(X)
rand_field!(getproperty(X, pn))
Expand All @@ -122,8 +101,10 @@ function test_kernel!(; fused!, unfused!, X, Y)
unfused!(X_unfused, Y_unfused)
fused!(X_fused, Y_fused)
@testset "Test correctness of $(nameof(typeof(fused!)))" begin
@test compare(X_fused, X_unfused)
@test compare(Y_fused, Y_unfused)
Fields.@rprint_diff(X_fused, X_unfused)
Fields.@rprint_diff(Y_fused, Y_unfused)
@test Fields.rcompare(X_fused, X_unfused)
@test Fields.rcompare(Y_fused, Y_unfused)
end
end

Expand Down

0 comments on commit bb4d8b3

Please sign in to comment.