From 1d64223fb43b358b4bf4c1c0ec95ed19a84cc128 Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 24 Nov 2021 13:22:51 +0100 Subject: [PATCH 1/6] avoid repeating function body when defining function from rules in different modules --- src/hyperdual.jl | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/hyperdual.jl b/src/hyperdual.jl index 59926b0..e448226 100644 --- a/src/hyperdual.jl +++ b/src/hyperdual.jl @@ -414,18 +414,12 @@ end to_nanmath(x) = x for (fsym, dfexp, d²fexp) in symbolic_derivative_list - if isdefined(SpecialFunctions, fsym) - @eval function SpecialFunctions.$(fsym)(h::Hyper) - x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) - Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) - end - elseif isdefined(Base, fsym) - @eval function Base.$(fsym)(h::Hyper) - x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) - Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) - end - elseif isdefined(Base.Math, fsym) - @eval function Base.Math.$(fsym)(h::Hyper) + mod = isdefined(SpecialFunctions, fsym) ? SpecialFunctions : + isdefined(Base, fsym) ? Base : + isdefined(Base.Math, fsym) ? Base.Math : + nothing + if mod !== nothing + @eval function $mod.$(fsym)(h::Hyper) x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) end From 400e91c8922ff48718bde8fd699a15f05b7428ed Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 24 Nov 2021 13:31:32 +0100 Subject: [PATCH 2/6] run CSE on expression to avoid computing the same thing many times For example, before this change, ``` code_warntype(cos, Tuple{Hyper256}) ``` showed 2 `cos` evaluations and 3 `sin` evaluations. After, there is only one cos and one sin evaluation. --- Project.toml | 2 ++ src/HyperDualNumbers.jl | 1 + src/hyperdual.jl | 7 +++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 2962af9..a56056c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,14 @@ version = "4.0.7" [deps] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] Calculus = "0.5" +CommonSubexpressions = "0.3" NaNMath = "0.3" SpecialFunctions = "0.10, 1.0, 2" julia = "1" diff --git a/src/HyperDualNumbers.jl b/src/HyperDualNumbers.jl index 017d690..a683c50 100644 --- a/src/HyperDualNumbers.jl +++ b/src/HyperDualNumbers.jl @@ -3,6 +3,7 @@ module HyperDualNumbers using SpecialFunctions, LinearAlgebra import NaNMath import Calculus +import CommonSubexpressions include("derivatives_list.jl") include("hyperdual.jl") diff --git a/src/hyperdual.jl b/src/hyperdual.jl index e448226..f7387f8 100644 --- a/src/hyperdual.jl +++ b/src/hyperdual.jl @@ -370,7 +370,7 @@ function Base.:^(h::Hyper, a::Number) a^2*x^(a - 2)*y*z - a*x^(a - 2)*y*z + a*w*x^(a - 1)) end -# Below definition is necesssaty to resolve a conflict with the +# Below definition is necesssary to resolve a conflict with the # definition in MathConstants.jl function Base.:^(x::Irrational{:ℯ}, h::Hyper) a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) @@ -419,9 +419,12 @@ for (fsym, dfexp, d²fexp) in symbolic_derivative_list isdefined(Base.Math, fsym) ? Base.Math : nothing if mod !== nothing + expr = :(Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp)) + cse_expr = CommonSubexpressions.cse(expr, warn=false) + @eval function $mod.$(fsym)(h::Hyper) x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) - Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) + $cse_expr end end # extend corresponding NaNMath methods From cee58646538ef92bafae8433b874ad28e69b162e Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 24 Nov 2021 13:38:24 +0100 Subject: [PATCH 3/6] avoid computing a `x^3/2` when we already have computed `y=sqrt(x)` and can compute the power as `y*x` --- src/derivatives_list.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/derivatives_list.jl b/src/derivatives_list.jl index 49a3400..342aeaf 100644 --- a/src/derivatives_list.jl +++ b/src/derivatives_list.jl @@ -21,7 +21,7 @@ which the derivative should be evaluated. ``` """ symbolic_derivative_list = [ - (:sqrt, :(1/2/sqrt(x)), :(-1/4/x^(3/2))) + (:sqrt, :(1/2/sqrt(x)), :(-1/4/(sqrt(x) * x))) (:cbrt, :(1/3/x^(2/3)), :(-2/9/x^(5/3))) (:abs2, :(2*x), :(2)) (:inv, :(-1/x^2), :(2/x^3)) From 1cbbaf5aade237244e0efe74af22a093b3918f1c Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 24 Nov 2021 13:38:44 +0100 Subject: [PATCH 4/6] Use `sincos` when computing `sin` or `cos` to take advantage of the improved performance --- src/hyperdual.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/hyperdual.jl b/src/hyperdual.jl index f7387f8..00ce50a 100644 --- a/src/hyperdual.jl +++ b/src/hyperdual.jl @@ -418,7 +418,7 @@ for (fsym, dfexp, d²fexp) in symbolic_derivative_list isdefined(Base, fsym) ? Base : isdefined(Base.Math, fsym) ? Base.Math : nothing - if mod !== nothing + if mod !== nothing && fsym !== :sin && fsym !== :cos # (we define out own sin and cos) expr = :(Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp)) cse_expr = CommonSubexpressions.cse(expr, warn=false) @@ -437,6 +437,17 @@ for (fsym, dfexp, d²fexp) in symbolic_derivative_list end end +# Can use sincos for cos and sin +function Base.cos(h::Hyper) + a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) + si, co = sincos(a) + return Hyper(co, -si*b, -si*c, -si*d - co*b*c) +end +function Base.sin(h::Hyper) + a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) + si, co = sincos(a) + return Hyper(si, co*b, co*c, co*d - si*b*c) +end # only need to compute exp/cis once (removed exp from derivatives_list) function Base.exp(h::Hyper) a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) From 125e0fdf821bc916a6607828a196ffb0b7198690 Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 24 Nov 2021 13:44:09 +0100 Subject: [PATCH 5/6] allow literal pows to be expanded --- src/hyperdual.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/hyperdual.jl b/src/hyperdual.jl index 00ce50a..695a606 100644 --- a/src/hyperdual.jl +++ b/src/hyperdual.jl @@ -314,6 +314,11 @@ Base.:*(h::Hyper, n::Number) = n * h Base.one(h::Hyper) = Hyper(one(realpart(h))) +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{0}) = one(typeof(x)) +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{1}) = x +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{2}) = x*x +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{3}) = x*x*x + function Base.:/(h₁::Hyper, h₂::Hyper) x, y, z, w = value(h₁), ε₁part(h₁), ε₂part(h₁), ε₁ε₂part(h₁) a, b, c, d = value(h₂), ε₁part(h₂), ε₂part(h₂), ε₁ε₂part(h₂) From 02d7671661b2d79099f961e7156b3a9640c2b511 Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 24 Nov 2021 13:49:55 +0100 Subject: [PATCH 6/6] use muladd to allow for fused multiply add in epsilon values --- src/hyperdual.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hyperdual.jl b/src/hyperdual.jl index 695a606..fcd407e 100644 --- a/src/hyperdual.jl +++ b/src/hyperdual.jl @@ -307,7 +307,7 @@ Base.:*(h::Hyper, x::Bool) = x * h function Base.:*(h₁::Hyper, h₂::Hyper) x, y, z, w = value(h₁), ε₁part(h₁), ε₂part(h₁), ε₁ε₂part(h₁) a, b, c, d = value(h₂), ε₁part(h₂), ε₂part(h₂), ε₁ε₂part(h₂) - return Hyper(a*x, a*y+b*x, a*z+c*x, a*w+d*x+c*y+b*z) + return Hyper(a*x, muladd(a, y, b*x), muladd(a, z, c*x), muladd(a, w, muladd(d, x, muladd(c, y, b*z)))) end Base.:*(n::Number, h::Hyper) = Hyper(n*value(h), n*ε₁part(h), n*ε₂part(h), n*ε₁ε₂part(h)) Base.:*(h::Hyper, n::Number) = n * h