Skip to content

Commit

Permalink
Ensure that StaticArrays rules are loaded correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
moble committed Oct 22, 2024
1 parent 75fc1b5 commit 209edaf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "1.5.3"
[deps]
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -23,8 +24,8 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
QuaternionicChainRulesCoreExt = "ChainRulesCore"
QuaternionicFastDifferentiationExt = "FastDifferentiation"
QuaternionicForwardDiffExt = "ForwardDiff"
QuaternionicSymbolicsExt = "Symbolics"
QuaternionicLatexifyExt = "Latexify"
QuaternionicSymbolicsExt = "Symbolics"

[compat]
ChainRulesCore = "1"
Expand All @@ -34,6 +35,7 @@ GenericLinearAlgebra = "0.3.11"
LaTeXStrings = "1"
Latexify = "0.15, 0.16"
LinearAlgebra = "1"
Pkg = "1.11.0"
PrecompileTools = "1.1.1"
Random = "1"
Requires = "1"
Expand Down
28 changes: 17 additions & 11 deletions ext/QuaternionicChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module QuaternionicChainRulesCoreExt

using Pkg
using Quaternionic
import Quaternionic: _sincu, _cossu
using StaticArrays
Expand All @@ -12,26 +13,31 @@ isdefined(Base, :get_extension) ?
# It's likely that StaticArrays will have its own ChainRulesCore extension someday, so we
# need to check if there is already a ProjectTo defined for SArray. If so, we'll use that.
# If not, we'll define one here.
@info [repr(method) for method in methods(ProjectTo) if occursin("SArray", repr(method.sig))]
@show any(method->occursin("SArray", repr(method.sig)), methods(ProjectTo))
if !any(method->occursin("SArray", repr(method.sig)), methods(ProjectTo))
@info "Defining ProjectTo for SArray"
staticarrays_info = Pkg.dependencies()[Base.UUID("90137ffa-7385-5640-81b9-e52037218182")]
if staticarrays_info.version < v"1.8.1"
# These are ripped from https://github.com/JuliaArrays/StaticArrays.jl/pull/1068
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::SArray)
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return ChainRulesCore.project_type(project)(dz...)
end
function ProjectTo(x::SArray{S,T}) where {S, T}
return ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
end
function (project::ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M}
return SArray{project.axes}(dx)
return ProjectTo{SArray}(;
element=ChainRulesCore._eltype_projectto(T),
axes=axes(x), size=StaticArrays.Size(x)
)
end
function rrule(::Type{T}, x::Tuple) where {T<:SArray}
@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx)
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx)
function rrule(::Type{T}, x::Tuple) where {T <: SArray}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
∇Array(∂y) = (NoTangent(), project_x(∂y))
return T(x), ∇Array
end
function rrule(::Type{T}, xs::Number...) where {T <: SVector}
project_x = ProjectTo(xs)
∇Array(∂y) = (NoTangent(), project_x(∂y)...)
return T(xs...), ∇Array
end
end

Expand Down

0 comments on commit 209edaf

Please sign in to comment.