Skip to content

Commit

Permalink
Update ForwardDiff approach (#37)
Browse files Browse the repository at this point in the history
`ForwardDiff` [recently added](JuliaDiff/ForwardDiff.jl#583) support for complex-valued functions of real arguments, which demonstrates the (presumably) right way to support geometric-algebra-valued functions, including quaternion-valued functions.  I was on roughly the right track, but the default method should return 0, and the specific methods of quaternions over `Dual`s should have extracted the `partials` field of each components.  This implements that approach.
  • Loading branch information
moble authored May 19, 2022
1 parent 65a1e0c commit 968190e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
6 changes: 0 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,3 @@ Requires = "1"
StaticArrays = "1"
Symbolics = "0.1, 1, 2, 3, 4"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
21 changes: 19 additions & 2 deletions src/Quaternionic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,25 @@ include("examples.jl")
function __init__()
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
# Let ForwardDiff act naturally on quaternions.
@inline function ForwardDiff.extract_derivative(::Type{T}, y::AbstractQuaternion) where {T}
Quaternion(ForwardDiff.extract_derivative(T, y.components))
# This is cribbed from similar expressions enabling differentiation of complex-valued functions in
# https://github.com/JuliaDiff/ForwardDiff.jl/blob/78c73afd9a21593daf54f61c7d0db67130cf29e1/src/derivative.jl#L83-L88
@inline ForwardDiff.extract_derivative(::Type{T}, y::AbstractQuaternion) where {T} = zero(y)
# Both Quaternion and Rotor, when differentiated, result in a Quaternion
@inline function ForwardDiff.extract_derivative(::Type{T}, y::AbstractQuaternion{TD}) where {T, TD <: ForwardDiff.Dual}
Quaternion(
ForwardDiff.partials(T, y.w, 1),
ForwardDiff.partials(T, y.x, 1),
ForwardDiff.partials(T, y.y, 1),
ForwardDiff.partials(T, y.z, 1)
)
end
# But QuatVec results in a QuatVec
@inline function ForwardDiff.extract_derivative(::Type{T}, y::QuatVec{TD}) where {T, TD <: ForwardDiff.Dual}
QuatVec(
ForwardDiff.partials(T, y.x, 1),
ForwardDiff.partials(T, y.y, 1),
ForwardDiff.partials(T, y.z, 1)
)
end
end

Expand Down
3 changes: 0 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
Quaternionic = "0756cd96-85bf-4b6f-a009-b5012ea7a443"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Expand Down

0 comments on commit 968190e

Please sign in to comment.