Skip to content

Commit

Permalink
Merge pull request #37 from KristofferC/kc/perf
Browse files Browse the repository at this point in the history
  • Loading branch information
goedman authored Nov 24, 2021
2 parents 1236c95 + 02d7671 commit 9ae20ee
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 16 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ version = "4.0.7"

[deps]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Calculus = "0.5"
CommonSubexpressions = "0.3"
NaNMath = "0.3"
SpecialFunctions = "0.10, 1.0, 2"
julia = "1"
Expand Down
1 change: 1 addition & 0 deletions src/HyperDualNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module HyperDualNumbers
using SpecialFunctions, LinearAlgebra
import NaNMath
import Calculus
import CommonSubexpressions

include("derivatives_list.jl")
include("hyperdual.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/derivatives_list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ which the derivative should be evaluated.
```
"""
symbolic_derivative_list = [
(:sqrt, :(1/2/sqrt(x)), :(-1/4/x^(3/2)))
(:sqrt, :(1/2/sqrt(x)), :(-1/4/(sqrt(x) * x)))
(:cbrt, :(1/3/x^(2/3)), :(-2/9/x^(5/3)))
(:abs2, :(2*x), :(2))
(:inv, :(-1/x^2), :(2/x^3))
Expand Down
43 changes: 28 additions & 15 deletions src/hyperdual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,18 @@ Base.:*(h::Hyper, x::Bool) = x * h
function Base.:*(h₁::Hyper, h₂::Hyper)
x, y, z, w = value(h₁), ε₁part(h₁), ε₂part(h₁), ε₁ε₂part(h₁)
a, b, c, d = value(h₂), ε₁part(h₂), ε₂part(h₂), ε₁ε₂part(h₂)
return Hyper(a*x, a*y+b*x, a*z+c*x, a*w+d*x+c*y+b*z)
return Hyper(a*x, muladd(a, y, b*x), muladd(a, z, c*x), muladd(a, w, muladd(d, x, muladd(c, y, b*z))))
end
Base.:*(n::Number, h::Hyper) = Hyper(n*value(h), n*ε₁part(h), n*ε₂part(h), n*ε₁ε₂part(h))
Base.:*(h::Hyper, n::Number) = n * h

Base.one(h::Hyper) = Hyper(one(realpart(h)))

@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{0}) = one(typeof(x))
@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{1}) = x
@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{2}) = x*x
@inline Base.literal_pow(::typeof(^), x::Hyper, ::Val{3}) = x*x*x

function Base.:/(h₁::Hyper, h₂::Hyper)
x, y, z, w = value(h₁), ε₁part(h₁), ε₂part(h₁), ε₁ε₂part(h₁)
a, b, c, d = value(h₂), ε₁part(h₂), ε₂part(h₂), ε₁ε₂part(h₂)
Expand Down Expand Up @@ -370,7 +375,7 @@ function Base.:^(h::Hyper, a::Number)
a^2*x^(a - 2)*y*z - a*x^(a - 2)*y*z + a*w*x^(a - 1))
end

# Below definition is necesssaty to resolve a conflict with the
# Below definition is necesssary to resolve a conflict with the
# definition in MathConstants.jl
function Base.:^(x::Irrational{:ℯ}, h::Hyper)
a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
Expand Down Expand Up @@ -414,20 +419,17 @@ end
to_nanmath(x) = x

for (fsym, dfexp, d²fexp) in symbolic_derivative_list
if isdefined(SpecialFunctions, fsym)
@eval function SpecialFunctions.$(fsym)(h::Hyper)
x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp)
end
elseif isdefined(Base, fsym)
@eval function Base.$(fsym)(h::Hyper)
x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp)
end
elseif isdefined(Base.Math, fsym)
@eval function Base.Math.$(fsym)(h::Hyper)
mod = isdefined(SpecialFunctions, fsym) ? SpecialFunctions :
isdefined(Base, fsym) ? Base :
isdefined(Base.Math, fsym) ? Base.Math :
nothing
if mod !== nothing && fsym !== :sin && fsym !== :cos # (we define out own sin and cos)
expr = :(Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp))
cse_expr = CommonSubexpressions.cse(expr, warn=false)

@eval function $mod.$(fsym)(h::Hyper)
x, y, z, w = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
Hyper($(fsym)(x), y*$dfexp, z*$dfexp, w*$dfexp + y*z*$d²fexp)
$cse_expr
end
end
# extend corresponding NaNMath methods
Expand All @@ -440,6 +442,17 @@ for (fsym, dfexp, d²fexp) in symbolic_derivative_list
end
end

# Can use sincos for cos and sin
function Base.cos(h::Hyper)
a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
si, co = sincos(a)
return Hyper(co, -si*b, -si*c, -si*d - co*b*c)
end
function Base.sin(h::Hyper)
a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
si, co = sincos(a)
return Hyper(si, co*b, co*c, co*d - si*b*c)
end
# only need to compute exp/cis once (removed exp from derivatives_list)
function Base.exp(h::Hyper)
a, b, c, d = value(h), ε₁part(h), ε₂part(h), ε₁ε₂part(h)
Expand Down

0 comments on commit 9ae20ee

Please sign in to comment.