From e63f62bc0be02d5db2ec81e072adb81f72a81d06 Mon Sep 17 00:00:00 2001 From: Lukas Weber <49278367+lukas-weber@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:48:20 -0500 Subject: [PATCH] tangent_arithmetic: add `norm` for `NoTangent` and `ZeroTangent` --- src/tangent_types/abstract_zero.jl | 2 ++ test/tangent_types/abstract_zero.jl | 3 +++ test/tangent_types/tangent.jl | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index bc1a4239b..77c455c04 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -23,6 +23,8 @@ Base.last(x::AbstractZero) = x Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() +LinearAlgebra.norm(::AbstractZero) = 0 + # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 390b4887b..028d942ea 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -120,6 +120,9 @@ @test dot(ZeroTangent(), dne) == ZeroTangent() @test dot(dne, ZeroTangent()) == ZeroTangent() + @test norm(ZeroTangent()) == 0 + @test norm(ZeroTangent(), 0.4) == 0 + for x in dne @test x === dne end diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 8a5f1666a..b0cb5577e 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -329,11 +329,15 @@ end @test c * NoTangent() == NoTangent() @test dot(NoTangent(), c) == NoTangent() @test dot(c, NoTangent()) == NoTangent() + @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y + @test norm(NoTangent(), Inf) == 0 @test ZeroTangent() * c == ZeroTangent() @test c * ZeroTangent() == ZeroTangent() @test dot(ZeroTangent(), c) == ZeroTangent() @test dot(c, ZeroTangent()) == ZeroTangent() + @test norm(ZeroTangent()) == 0 + @test norm(ZeroTangent(), 0.4) == 0 @test true * c === c @test c * true === c