From 53690905a67c84d79721435226cae636340c799e Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 17 Sep 2024 09:55:24 -0400 Subject: [PATCH] Add promotion rules for ZeroTangent and NoTangent --- src/tangent_types/abstract_zero.jl | 1 + test/tangent_types/abstract_zero.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index d526abffe..6838168af 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -32,6 +32,7 @@ Base.:/(z::AbstractZero, ::Any) = z Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) # (::Type{T})(::AbstractZero, ::AbstractZero...) where {T<:Number} = zero(T) +Base.promote_rule(T::Type{<:Number}, S::Type{<:AbstractZero}) = T (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7da7cfadc..94df8e2ac 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -82,6 +82,15 @@ @test convert(Float32, ZeroTangent()) === 0.0f0 @test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im + @test promote_type(ZeroTangent, Bool) == Bool + @test promote_type(Bool, ZeroTangent) == Bool + @test promote_type(ZeroTangent, Int64) == Int64 + @test promote_type(Int64, ZeroTangent) == Int64 + @test promote_type(ZeroTangent, Float32) == Float32 + @test promote_type(Float32, ZeroTangent) == Float32 + @test promote_type(ZeroTangent, ComplexF64) == ComplexF64 + @test promote_type(ComplexF64, ZeroTangent) == ComplexF64 + @test z[1] === z @test z[1:3] === z @test z[1, 2] === z @@ -110,6 +119,15 @@ @test dot(dne, 17.2) == dne @test dot(11.9, dne) == dne + @test promote_type(NoTangent, Bool) == Bool + @test promote_type(Bool, NoTangent) == Bool + @test promote_type(NoTangent, Int64) == Int64 + @test promote_type(Int64, NoTangent) == Int64 + @test promote_type(NoTangent, Float32) == Float32 + @test promote_type(Float32, NoTangent) == Float32 + @test promote_type(NoTangent, ComplexF64) == ComplexF64 + @test promote_type(ComplexF64, NoTangent) == ComplexF64 + @test ZeroTangent() + dne == dne @test dne + ZeroTangent() == dne @test ZeroTangent() - dne == dne