From f47351465c1ed7b1f9488fb5b51b4c943173a8d7 Mon Sep 17 00:00:00 2001
From: jClugstor
Date: Fri, 1 Nov 2024 15:12:55 -0400
Subject: [PATCH 1/8] fix NaNMath exponentiation

---
 src/dual.jl | 102 +++++++++++++++++++++++++++++++++-------------------
 1 file changed, 66 insertions(+), 36 deletions(-)

diff --git a/src/dual.jl b/src/dual.jl
index ca3a2cbe..f21782a0 100644
--- a/src/dual.jl
+++ b/src/dual.jl
@@ -552,43 +552,73 @@ end
 # exponentiation #
 #----------------#
 
-for f in (:(Base.:^), :(NaNMath.pow))
-    @eval begin
-        @define_binary_dual_op(
-            $f,
-            begin
-                vx, vy = value(x), value(y)
-                expv = ($f)(vx, vy)
-                powval = vy * ($f)(vx, vy - 1)
-                if isconstant(y)
-                    logval = one(expv)
-                elseif iszero(vx) && vy > 0
-                    logval = zero(vx)
-                else
-                    logval = expv * log(vx)
-                end
-                new_partials = _mul_partials(partials(x), partials(y), powval, logval)
-                return Dual{Txy}(expv, new_partials)
-            end,
-            begin
-                v = value(x)
-                expv = ($f)(v, y)
-                if y == zero(y) || iszero(partials(x))
-                    new_partials = zero(partials(x))
-                else
-                    new_partials = partials(x) * y * ($f)(v, y - 1)
-                end
-                return Dual{Tx}(expv, new_partials)
-            end,
-            begin
-                v = value(y)
-                expv = ($f)(x, v)
-                deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
-                return Dual{Ty}(expv, deriv * partials(y))
-            end
-        )
+@define_binary_dual_op(
+    Base.:^,
+    begin
+        vx, vy = value(x), value(y)
+        expv = (^)(vx, vy)
+        powval = vy * (^)(vx, vy - 1)
+        if isconstant(y)
+            logval = one(expv)
+        elseif iszero(vx) && vy > 0
+            logval = zero(vx)
+        else
+            logval = expv * log(vx)
+        end
+        new_partials = _mul_partials(partials(x), partials(y), powval, logval)
+        return Dual{Txy}(expv, new_partials)
+    end,
+    begin
+        v = value(x)
+        expv = (^)(v, y)
+        if y == zero(y) || iszero(partials(x))
+            new_partials = zero(partials(x))
+        else
+            new_partials = partials(x) * y * (^)(v, y - 1)
+        end
+        return Dual{Tx}(expv, new_partials)
+    end,
+    begin
+        v = value(y)
+        expv = (^)(x, v)
+        deriv = (iszero(x) && v > 0) ? zero(expv) : expv * log(x)
+        return Dual{Ty}(expv, deriv * partials(y))
     end
-end
+)
+
+@define_binary_dual_op(
+    NaNMath.pow,
+    begin
+        vx, vy = value(x), value(y)
+        expv = NaNMath.pow(vx, vy)
+        powval = vy * NaNMath.pow(vx, vy - 1)
+        if isconstant(y)
+            logval = one(expv)
+        elseif iszero(vx) && vy > 0
+            logval = zero(vx)
+        else
+            logval = expv * NaNMath.log(vx)
+        end
+        new_partials = _mul_partials(partials(x), partials(y), powval, logval)
+        return Dual{Txy}(expv, new_partials)
+    end,
+    begin
+        v = value(x)
+        expv = NaNMath.pow(v, y)
+        if y == zero(y) || iszero(partials(x))
+            new_partials = zero(partials(x))
+        else
+            new_partials = partials(x) * y * NaNMath.pow(v, y - 1)
+        end
+        return Dual{Tx}(expv, new_partials)
+    end,
+    begin
+        v = value(y)
+        expv = NaNMath.pow(x, v)
+        deriv = (iszero(x) && v > 0) ? zero(expv) : expv*NaNMath.log(x)
+        return Dual{Ty}(expv, deriv * partials(y))
+    end
+)
 
 @inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} =
     Dual{T}(one(value(x)), zero(partials(x)))

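Note on the rule above: for pow the two partials are d/dx pow(x, y) = y*pow(x, y-1) and d/dy pow(x, y) = pow(x, y)*log(x), and pairing NaNMath.pow with NaNMath.log in the second one is what turns the old DomainError into a quiet NaN. A minimal standalone sketch of that pairing (not the package code; the helper name pow_partials is made up for illustration):

using NaNMath

# Both partials of pow(x, y), written with the NaNMath variants throughout so
# that out-of-domain inputs propagate NaN instead of throwing.
pow_partials(x, y) = (y * NaNMath.pow(x, y - 1), NaNMath.pow(x, y) * NaNMath.log(x))

pow_partials(-1.0, 0.5)  # (NaN, NaN); with Base.log this would raise a DomainError
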
From fd2c647d8eff24bec226999909bb6f00f2b13a0e Mon Sep 17 00:00:00 2001
From: jClugstor
Date: Fri, 1 Nov 2024 16:43:27 -0400
Subject: [PATCH 2/8] reuse code

---
 src/dual.jl | 103 +++++++++++++++++++---------------------------------
 1 file changed, 38 insertions(+), 65 deletions(-)

diff --git a/src/dual.jl b/src/dual.jl
index f21782a0..17983c66 100644
--- a/src/dual.jl
+++ b/src/dual.jl
@@ -552,73 +552,46 @@ end
 # exponentiation #
 #----------------#
 
-@define_binary_dual_op(
-    Base.:^,
-    begin
-        vx, vy = value(x), value(y)
-        expv = (^)(vx, vy)
-        powval = vy * (^)(vx, vy - 1)
-        if isconstant(y)
-            logval = one(expv)
-        elseif iszero(vx) && vy > 0
-            logval = zero(vx)
-        else
-            logval = expv * log(vx)
-        end
-        new_partials = _mul_partials(partials(x), partials(y), powval, logval)
-        return Dual{Txy}(expv, new_partials)
-    end,
-    begin
-        v = value(x)
-        expv = (^)(v, y)
-        if y == zero(y) || iszero(partials(x))
-            new_partials = zero(partials(x))
-        else
-            new_partials = partials(x) * y * (^)(v, y - 1)
-        end
-        return Dual{Tx}(expv, new_partials)
-    end,
-    begin
-        v = value(y)
-        expv = (^)(x, v)
-        deriv = (iszero(x) && v > 0) ? zero(expv) : expv * log(x)
-        return Dual{Ty}(expv, deriv * partials(y))
+for (f, log) in ((:(Base.:^), :(NaNMath.pow)), (:(NaNMath.pow), :(NaNMath.log)))
+    @eval begin
+        @define_binary_dual_op(
+            $f,
+            begin
+                vx, vy = value(x), value(y)
+                expv = ($f)(vx, vy)
+                powval = vy * ($f)(vx, vy - 1)
+                if isconstant(y)
+                    logval = one(expv)
+                elseif iszero(vx) && vy > 0
+                    logval = zero(vx)
+                else
+                    logval = expv * ($log)(vx)
+                end
+                new_partials = _mul_partials(partials(x), partials(y), powval, logval)
+                return Dual{Txy}(expv, new_partials)
+            end,
+            begin
+                v = value(x)
+                expv = ($f)(v, y)
+                if y == zero(y) || iszero(partials(x))
+                    new_partials = zero(partials(x))
+                else
+                    new_partials = partials(x) * y * ($f)(v, y - 1)
+                end
+                return Dual{Tx}(expv, new_partials)
+            end,
+            begin
+                v = value(y)
+                expv = ($f)(x, v)
+                deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x)
+                return Dual{Ty}(expv, deriv * partials(y))
+            end
+        )
     end
-)
+end
+
+
 
-@define_binary_dual_op(
-    NaNMath.pow,
-    begin
-        vx, vy = value(x), value(y)
-        expv = NaNMath.pow(vx, vy)
-        powval = vy * NaNMath.pow(vx, vy - 1)
-        if isconstant(y)
-            logval = one(expv)
-        elseif iszero(vx) && vy > 0
-            logval = zero(vx)
-        else
-            logval = expv * NaNMath.log(vx)
-        end
-        new_partials = _mul_partials(partials(x), partials(y), powval, logval)
-        return Dual{Txy}(expv, new_partials)
-    end,
-    begin
-        v = value(x)
-        expv = NaNMath.pow(v, y)
-        if y == zero(y) || iszero(partials(x))
-            new_partials = zero(partials(x))
-        else
-            new_partials = partials(x) * y * NaNMath.pow(v, y - 1)
-        end
-        return Dual{Tx}(expv, new_partials)
-    end,
-    begin
-        v = value(y)
-        expv = NaNMath.pow(x, v)
-        deriv = (iszero(x) && v > 0) ? zero(expv) : expv*NaNMath.log(x)
-        return Dual{Ty}(expv, deriv * partials(y))
-    end
-)
 
 @inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} =
     Dual{T}(one(value(x)), zero(partials(x)))

From c5372878b182a00dbcaf8aa0d8a6c1d906c6d0e3 Mon Sep 17 00:00:00 2001
From: jClugstor
Date: Fri, 1 Nov 2024 16:48:01 -0400
Subject: [PATCH 3/8] fix

---
 src/dual.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dual.jl b/src/dual.jl
index 17983c66..690257eb 100644
--- a/src/dual.jl
+++ b/src/dual.jl
@@ -552,7 +552,7 @@ end
 # exponentiation #
 #----------------#
 
-for (f, log) in ((:(Base.:^), :(NaNMath.pow)), (:(NaNMath.pow), :(NaNMath.log)))
+for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
     @eval begin
         @define_binary_dual_op(
             $f,

From 9864d7f26106493431d898ef3cc96f86ac1f833d Mon Sep 17 00:00:00 2001
From: jClugstor
Date: Fri, 1 Nov 2024 17:43:27 -0400
Subject: [PATCH 4/8] add tests

---
 test/DerivativeTest.jl | 11 +++++++++++
 test/GradientTest.jl   | 12 ++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl
index 4b7463c8..3266bf6b 100644
--- a/test/DerivativeTest.jl
+++ b/test/DerivativeTest.jl
@@ -93,6 +93,17 @@ end
     @test (x -> ForwardDiff.derivative(y -> x^y, 1.5))(0.0) === 0.0
 end
 
+@testset "exponentiation with NaNMath" begin
+    @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(NaN, x), 1.0))
+    @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,NaN), 1.0))
+    @test !isnan(ForwardDiff.derivative(x -> NaNMath.pow(1.0, x),1.0))
+    @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,0.5), -1.0))
+
+    @test isnan(ForwardDiff.derivative(x -> x^NaN, 2.0))
+    @test ForwardDiff.derivative(x -> x^2.0,2.0) == 4.0
+    @test_throws DomainError ForwardDiff.derivative(x -> x^0.5, -1.0)
+end
+
 @testset "dimension error for derivative" begin
     @test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3))
 end
diff --git a/test/GradientTest.jl b/test/GradientTest.jl
index 5adfc8c7..97b9fcc2 100644
--- a/test/GradientTest.jl
+++ b/test/GradientTest.jl
@@ -200,6 +200,18 @@ end
     @test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) ≈ [1.0 -1.3333333333333337; 0.0 1.666666666666667]
 end
 
+@testset "gradient for exponential with NaNMath"
+    @test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
+    @test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
+    @test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])
+
+    @test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
+    @test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
+    @test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
+
+end
+
+
 @testset "branches in mul!" begin
     a, b = rand(3,3), rand(3,3)

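Note: the DerivativeTest additions above pin down the intended split between the two exponentiation flavours. A usage sketch of what they assert (assuming the patched ForwardDiff plus NaNMath are loaded):

using ForwardDiff, NaNMath

ForwardDiff.derivative(x -> NaNMath.pow(x, 0.5), -1.0)    # NaN: the NaNMath rules propagate NaN
ForwardDiff.derivative(x -> x^2.0, 2.0)                   # 4.0: Base ^ is unchanged
# ForwardDiff.derivative(x -> x^0.5, -1.0)                # still throws DomainError
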
From 23b73512440c3e1df065f454f08c4c2102b4915d Mon Sep 17 00:00:00 2001
From: Jadon Clugston <34165782+jClugstor@users.noreply.github.com>
Date: Mon, 4 Nov 2024 09:26:19 -0500
Subject: [PATCH 5/8] Update src/dual.jl

Co-authored-by: David Widmann
---
 src/dual.jl | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/dual.jl b/src/dual.jl
index 690257eb..7e8ec110 100644
--- a/src/dual.jl
+++ b/src/dual.jl
@@ -590,9 +590,6 @@ for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
     end
 end
 
-
-
-
 @inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} =
     Dual{T}(one(value(x)), zero(partials(x)))

From 0f72cf97bec7d6087d90703434d8ba525f422e88 Mon Sep 17 00:00:00 2001
From: jClugstor
Date: Mon, 4 Nov 2024 09:46:29 -0500
Subject: [PATCH 6/8] import NaNMath

---
 test/DerivativeTest.jl | 1 +
 test/GradientTest.jl   | 1 +
 2 files changed, 2 insertions(+)

diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl
index 3266bf6b..4de1a6de 100644
--- a/test/DerivativeTest.jl
+++ b/test/DerivativeTest.jl
@@ -1,6 +1,7 @@
 module DerivativeTest
 
 import Calculus
+import NaNMath
 
 using Test
 using Random
diff --git a/test/GradientTest.jl b/test/GradientTest.jl
index 97b9fcc2..39016afa 100644
--- a/test/GradientTest.jl
+++ b/test/GradientTest.jl
@@ -1,6 +1,7 @@
 module GradientTest
 
 import Calculus
+import NaNMath
 
 using Test
 using LinearAlgebra

From 3fbc47514d09ca5b05216ad4a44a70336220cb09 Mon Sep 17 00:00:00 2001
From: jClugstor
Date: Mon, 4 Nov 2024 10:04:48 -0500
Subject: [PATCH 7/8] oops, no begin

---
 test/GradientTest.jl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/GradientTest.jl b/test/GradientTest.jl
index 39016afa..f34b4aa9 100644
--- a/test/GradientTest.jl
+++ b/test/GradientTest.jl
@@ -201,7 +201,7 @@ end
     @test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) ≈ [1.0 -1.3333333333333337; 0.0 1.666666666666667]
 end
 
-@testset "gradient for exponential with NaNMath"
+@testset "gradient for exponential with NaNMath" begin
     @test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
     @test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
     @test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])
 
     @test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
     @test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
     @test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
-
 end

From db4d9f97ace1b16334746e0903d254c2fbd4b3e9 Mon Sep 17 00:00:00 2001
From: Jadon Clugston <34165782+jClugstor@users.noreply.github.com>
Date: Mon, 4 Nov 2024 10:15:40 -0500
Subject: [PATCH 8/8] Update test/GradientTest.jl

Co-authored-by: David Widmann
---
 test/GradientTest.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/GradientTest.jl b/test/GradientTest.jl
index f34b4aa9..a386c479 100644
--- a/test/GradientTest.jl
+++ b/test/GradientTest.jl
@@ -211,7 +211,6 @@ end
     @test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
 end
-
 @testset "branches in mul!" begin
     a, b = rand(3,3), rand(3,3)
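
Note: after the full series, gradients behave the same way as the scalar derivatives. A sketch of the behaviour the final GradientTest covers (again assuming the patched ForwardDiff and NaNMath are loaded):

using ForwardDiff, NaNMath

ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])  # both partials are NaN, no error
ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0])                 # [1.0, 0.0]
# ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])              # still throws DomainError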