From d8a6ef089e59b99a61d4b04ce113cc25c054c554 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 23 Aug 2024 09:07:26 -0400 Subject: [PATCH] Define == for same-typed fieldvectors --- src/Fields/fieldvector.jl | 4 ++++ test/Fields/unit_field.jl | 4 ++-- test/Fields/utils_field_multi_broadcast_fusion.jl | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index e551f1f7cd..158f317b80 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -455,3 +455,7 @@ Returns `true` if `x == y` recursively. """ rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} = _rcompare(true, x, y) + +# Define == to call rcompare for two fieldvectors of the same +# exact type. +Base.:(==)(x::T, y::T) where {T <: FieldVector} = rcompare(x, y) diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index 37f987701f..123384f328 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -310,10 +310,10 @@ end Y.k.z = 3.0 @test Y.k.z === 3.0 - @test Fields.rcompare(Y, Y) + @test Y == Y Ydc = deepcopy(Y) Ydc.k.z += 1 - @test !Fields.rcompare(Ydc, Y) + @test !(Ydc == Y) # Fields.@rprint_diff(Ydc, Y) s = sprint( Fields._rprint_diff, diff --git a/test/Fields/utils_field_multi_broadcast_fusion.jl b/test/Fields/utils_field_multi_broadcast_fusion.jl index 8238da96b2..35abea33be 100644 --- a/test/Fields/utils_field_multi_broadcast_fusion.jl +++ b/test/Fields/utils_field_multi_broadcast_fusion.jl @@ -103,8 +103,8 @@ function test_kernel!(; fused!, unfused!, X, Y) @testset "Test correctness of $(nameof(typeof(fused!)))" begin Fields.@rprint_diff(X_fused, X_unfused) Fields.@rprint_diff(Y_fused, Y_unfused) - @test Fields.rcompare(X_fused, X_unfused) - @test Fields.rcompare(Y_fused, Y_unfused) + @test X_fused == X_unfused + @test Y_fused == Y_unfused end end