From 209edafd046bdeb54437cc70f08f39d54d3249d4 Mon Sep 17 00:00:00 2001 From: Mike Boyle Date: Tue, 22 Oct 2024 01:05:52 -0400 Subject: [PATCH] Ensure that StaticArrays rules are loaded correctly --- Project.toml | 4 +++- ext/QuaternionicChainRulesCoreExt.jl | 28 +++++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 6814df8..22b3bad 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.5.3" [deps] LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -23,8 +24,8 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" QuaternionicChainRulesCoreExt = "ChainRulesCore" QuaternionicFastDifferentiationExt = "FastDifferentiation" QuaternionicForwardDiffExt = "ForwardDiff" -QuaternionicSymbolicsExt = "Symbolics" QuaternionicLatexifyExt = "Latexify" +QuaternionicSymbolicsExt = "Symbolics" [compat] ChainRulesCore = "1" @@ -34,6 +35,7 @@ GenericLinearAlgebra = "0.3.11" LaTeXStrings = "1" Latexify = "0.15, 0.16" LinearAlgebra = "1" +Pkg = "1.11.0" PrecompileTools = "1.1.1" Random = "1" Requires = "1" diff --git a/ext/QuaternionicChainRulesCoreExt.jl b/ext/QuaternionicChainRulesCoreExt.jl index 44eb1ee..ade868e 100644 --- a/ext/QuaternionicChainRulesCoreExt.jl +++ b/ext/QuaternionicChainRulesCoreExt.jl @@ -1,5 +1,6 @@ module QuaternionicChainRulesCoreExt +using Pkg using Quaternionic import Quaternionic: _sincu, _cossu using StaticArrays @@ -12,10 +13,8 @@ isdefined(Base, :get_extension) ? # It's likely that StaticArrays will have its own ChainRulesCore extension someday, so we # need to check if there is already a ProjectTo defined for SArray. If so, we'll use that. # If not, we'll define one here. -@info [repr(method) for method in methods(ProjectTo) if occursin("SArray", repr(method.sig))] -@show any(method->occursin("SArray", repr(method.sig)), methods(ProjectTo)) -if !any(method->occursin("SArray", repr(method.sig)), methods(ProjectTo)) - @info "Defining ProjectTo for SArray" +staticarrays_info = Pkg.dependencies()[Base.UUID("90137ffa-7385-5640-81b9-e52037218182")] +if staticarrays_info.version < v"1.8.1" # These are ripped from https://github.com/JuliaArrays/StaticArrays.jl/pull/1068 function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray @@ -23,15 +22,22 @@ if !any(method->occursin("SArray", repr(method.sig)), methods(ProjectTo)) return ChainRulesCore.project_type(project)(dz...) end function ProjectTo(x::SArray{S,T}) where {S, T} - return ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S) - end - function (project::ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M} - return SArray{project.axes}(dx) + return ProjectTo{SArray}(; + element=ChainRulesCore._eltype_projectto(T), + axes=axes(x), size=StaticArrays.Size(x) + ) end - function rrule(::Type{T}, x::Tuple) where {T<:SArray} + @inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx) + (project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx) + function rrule(::Type{T}, x::Tuple) where {T <: SArray} project_x = ProjectTo(x) - Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) - return T(x), Array_pullback + ∇Array(∂y) = (NoTangent(), project_x(∂y)) + return T(x), ∇Array + end + function rrule(::Type{T}, xs::Number...) where {T <: SVector} + project_x = ProjectTo(xs) + ∇Array(∂y) = (NoTangent(), project_x(∂y)...) + return T(xs...), ∇Array end end