diff --git a/Project.toml b/Project.toml index 2962af9..a56056c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,14 @@ version = "4.0.7" [deps] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] Calculus = "0.5" +CommonSubexpressions = "0.3" NaNMath = "0.3" SpecialFunctions = "0.10, 1.0, 2" julia = "1" diff --git a/src/HyperDualNumbers.jl b/src/HyperDualNumbers.jl index 017d690..a683c50 100644 --- a/src/HyperDualNumbers.jl +++ b/src/HyperDualNumbers.jl @@ -3,6 +3,7 @@ module HyperDualNumbers using SpecialFunctions, LinearAlgebra import NaNMath import Calculus +import CommonSubexpressions include("derivatives_list.jl") include("hyperdual.jl") diff --git a/src/derivatives_list.jl b/src/derivatives_list.jl index 49a3400..342aeaf 100644 --- a/src/derivatives_list.jl +++ b/src/derivatives_list.jl @@ -21,7 +21,7 @@ which the derivative should be evaluated. ``` """ symbolic_derivative_list = [ - (:sqrt, :(1/2/sqrt(x)), :(-1/4/x^(3/2))) + (:sqrt, :(1/2/sqrt(x)), :(-1/4/(sqrt(x) * x))) (:cbrt, :(1/3/x^(2/3)), :(-2/9/x^(5/3))) (:abs2, :(2*x), :(2)) (:inv, :(-1/x^2), :(2/x^3)) diff --git a/src/hyperdual.jl b/src/hyperdual.jl index 59926b0..fcd407e 100644 --- a/src/hyperdual.jl +++ b/src/hyperdual.jl @@ -307,13 +307,18 @@ Base.:*(h::Hyper, x::Bool) = x * h function Base.:*(h₁::Hyper, h₂::Hyper) x, y, z, w = value(h₁), ε₁part(h₁), ε₂part(h₁), ε₁ε₂part(h₁) a, b, c, d = value(h₂), ε₁part(h₂), ε₂part(h₂), ε₁ε₂part(h₂) - return Hyper(a*x, a*y+b*x, a*z+c*x, a*w+d*x+c*y+b*z) + return Hyper(a*x, muladd(a, y, b*x), muladd(a, z, c*x), muladd(a, w, muladd(d, x, muladd(c, y, b*z)))) end Base.:*(n::Number, h::Hyper) = Hyper(n*value(h), n*ε₁part(h), n*ε₂part(h), n*ε₁ε₂part(h)) Base.:*(h::Hyper, n::Number) = n * h Base.one(h::Hyper) = Hyper(one(realpart(h))) +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{0}) = one(typeof(x)) +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{1}) = x +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{2}) = x*x +@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{3}) = x*x*x + function Base.:/(h₁::Hyper, h₂::Hyper) x, y, z, w = value(h₁), ε₁part(h₁), ε₂part(h₁), ε₁ε₂part(h₁) a, b, c, d = value(h₂), ε₁part(h₂), ε₂part(h₂), ε₁ε₂part(h₂) @@ -370,7 +375,7 @@ function Base.:^(h::Hyper, a::Number) a^2*x^(a - 2)*y*z - a*x^(a - 2)*y*z + a*w*x^(a - 1)) end -# Below definition is necesssaty to resolve a conflict with the +# Below definition is necesssary to resolve a conflict with the # definition in MathConstants.jl function Base.:^(x::Irrational{:ℯ}, h::Hyper) a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) @@ -414,20 +419,17 @@ end to_nanmath(x) = x for (fsym, dfexp, d²fexp) in symbolic_derivative_list - if isdefined(SpecialFunctions, fsym) - @eval function SpecialFunctions.$(fsym)(h::Hyper) - x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) - Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) - end - elseif isdefined(Base, fsym) - @eval function Base.$(fsym)(h::Hyper) - x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) - Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) - end - elseif isdefined(Base.Math, fsym) - @eval function Base.Math.$(fsym)(h::Hyper) + mod = isdefined(SpecialFunctions, fsym) ? SpecialFunctions : + isdefined(Base, fsym) ? Base : + isdefined(Base.Math, fsym) ? Base.Math : + nothing + if mod !== nothing && fsym !== :sin && fsym !== :cos # (we define out own sin and cos) + expr = :(Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp)) + cse_expr = CommonSubexpressions.cse(expr, warn=false) + + @eval function $mod.$(fsym)(h::Hyper) x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) - Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp) + $cse_expr end end # extend corresponding NaNMath methods @@ -440,6 +442,17 @@ for (fsym, dfexp, d²fexp) in symbolic_derivative_list end end +# Can use sincos for cos and sin +function Base.cos(h::Hyper) + a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) + si, co = sincos(a) + return Hyper(co, -si*b, -si*c, -si*d - co*b*c) +end +function Base.sin(h::Hyper) + a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h) + si, co = sincos(a) + return Hyper(si, co*b, co*c, co*d - si*b*c) +end # only need to compute exp/cis once (removed exp from derivatives_list) function Base.exp(h::Hyper) a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)