Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjoint and Projections #1068

Closed
wants to merge 5 commits into from
Closed

Conversation

ba2tripleO
Copy link

Solves a few AD related issues:

  • Adjoint of SArray constructor
  • SArray projected to SizedArray in pullback due to this reshape during pullback
  • SMatrix tuple dimensions don't match the matrix dimensions when the above is fixed in this projection

Example

using Zygote
using StaticArrays

u0 = @SVector [1.0f0, 1.0f0]
p = @SVector [1.5f0, 1.0f0, 3.0f0, 1.0f0]

function lotka(u, p)
    du1 = p[1]*u[1] - p[2]*u[1]*u[2]
    du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
    @SVector [du1, du2]
    # @SMatrix [du1 du2 du1; du2 du1 du1]
end

function loss(p)
    u = lotka(u0, p)
    sum(1 .- u)
end

loss(p)

grad = Zygote.gradient(loss, p)
ERROR: Need an adjoint for constructor SVector{2, Float32}. Gradient is of type SizedVector{2, Float32, Vector{Float32}}
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] (::Zygote.Jnew{SVector{2, Float32}, Nothing, false})(Δ::SizedVector{2, Float32, Vector{Float32}})
    @ Zygote C:\Users\user\.julia\packages\Zygote\D7j8v\src\lib\lib.jl:326
  [3] (::Zygote.var"#1928#back#222"{Zygote.Jnew{SVector{2, Float32}, Nothing, false}})(Δ::SizedVector{2, Float32, Vector{Float32}})
    @ Zygote C:\Users\user\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [4] Pullback
    @ C:\Users\user\.julia\packages\StaticArraysCore\gkLqH\src\StaticArraysCore.jl:106 [inlined]
  [5] (::typeof(∂(SVector{2, Float32})))(Δ::SizedVector{2, Float32, Vector{Float32}})
    @ Zygote C:\Users\user\.julia\packages\Zygote\D7j8v\src\compiler\interface2.jl:0
  [6] Pullback
    @ C:\Users\user\.julia\dev\StaticArrays\src\convert.jl:163 [inlined]
  [7] Pullback
    @ c:\Users\user\.julia\dev\buffer.jl:668 [inlined]
  [8] (::typeof(∂(lotka)))(Δ::SizedVector{2, Float32, Vector{Float32}})
    @ Zygote C:\Users\user\.julia\packages\Zygote\D7j8v\src\compiler\interface2.jl:0
  [9] Pullback
    @ c:\Users\user\.julia\dev\buffer.jl:673 [inlined]
 [10] (::typeof(∂(loss)))(Δ::Float32)
    @ Zygote C:\Users\user\.julia\packages\Zygote\D7j8v\src\compiler\interface2.jl:0
 [11] (::Zygote.var"#60#61"{typeof(∂(loss))})(Δ::Float32)
    @ Zygote C:\Users\user\.julia\packages\Zygote\D7j8v\src\compiler\interface.jl:41
 [12] gradient(f::Function, args::SVector{4, Float32})
    @ Zygote C:\Users\user\.julia\packages\Zygote\D7j8v\src\compiler\interface.jl:76
 [13] top-level scope
    @ c:\Users\user\.julia\dev\buffer.jl:679

@ba2tripleO
Copy link
Author

What does detect_ambiguities do?

@test length(detect_ambiguities(#=LinearAlgebra, =#StaticArrays)) == allowable_ambiguities

@mateuszbaran
Copy link
Collaborator

What does detect_ambiguities do?

It's a standard Julia function: https://docs.julialang.org/en/v1/stdlib/Test/#Test.detect_ambiguities .

@mateuszbaran
Copy link
Collaborator

By the way, I think it's unlikely any maintainer of StaticArrays.jl will merge a PR that adds a new external dependency. I would suggest putting it in a separate package.

@ba2tripleO
Copy link
Author

ba2tripleO commented Aug 3, 2022

Is it certain that this won't be merged ? I think the alternative would then be to send a pr to ChainRulesCore for the ProjectTo and another pr to ChainRules for the rrule and frule. Also that would add StaticArrays as a dependency to both. @mcabbott any suggestions, as you contribute to both ChainRules and StaticArrays?

@mateuszbaran
Copy link
Collaborator

Well, I can't speak for other people with write access. Anyway there is StaticArraysCore.jl which should be a sufficient dependency to implement these methods.

@hyrodium
Copy link
Collaborator

hyrodium commented Aug 4, 2022

I think it would be okay to add the dependency on ChainRulesCore.jl.

  • The package is lightweight. (The package does not depend on anything other than the standard library.)
  • Its APIs are stable. (The latest version is v1.15.3.)

The ChainRulesCore package provides a light-weight dependency for defining sensitivities for functions in your packages, without you needing to depend on ChainRules itself.

@mcabbott
Copy link
Collaborator

mcabbott commented Aug 5, 2022

ChainRulesCore.jl really shouldn't depend on anything. I wish it were smaller.

I think the policy for ChainRules.jl is only to define rules for things in Base/std.lib., in the hope that everyone else will depend on CRC, and out of fear that otherwise it'll suffer endless mission creep with sort-of important packages... like this one, or FillArrays: JuliaArrays/FillArrays.jl#153 still has nowhere to live, sadly.

@ba2tripleO
Copy link
Author

I have removed the ambiguity causing methods, which includes the frule (1 ambiguity), but that works for me as I really only need rrule . Also adding to @hyrodium 's opinion on ChainRulesCore as a dependency, I think that allowing users to have out of the box AD compatibility has become pretty standard in Julia, so this will be helping in making the feature-set of StaticArrays more complete, lets please get it merged : )

@stevengj
Copy link
Contributor

stevengj commented Sep 9, 2022

Bump — we just ran into this ourselves.

@stevengj
Copy link
Contributor

stevengj commented Sep 9, 2022

(Needs a rebase.)

@stevengj
Copy link
Contributor

stevengj commented Sep 9, 2022

I really think that being differentiable is a reasonable expectation for most packages these days, and ChainRulesCore is a lightweight dependency that is very likely to be already installed.

@thchr
Copy link
Collaborator

thchr commented Sep 9, 2022

It is not all that lightweight though? It takes ~0.3 s to load ChainRulesCore on my laptop versus 0.75 s for StaticArrays itself, so including it as a dependency would add ~40% to StaticArrays' load time.
As context to this, I think there is already some sentiment that StaticArrays loads too slowly as-is.

It's true that it is very likely to be loaded by other packages though since it is e.g., in SpecialFunctions.jl (but I don't think the decision to include it there was entirely clear-cut either, cf. JuliaMath/SpecialFunctions.jl#310).

Comment on lines +21 to +25
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[targets]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays"]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Zygote", "ForwardDiff"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regardless of where these rules end up living, they should be tested with ChainRulesTestUtils and not ad-hoc via Zygote/ForwardDiff/what have you. That library does far more robust testing than most of us would think to write by hand.

Comment on lines +251 to +265
function lotka(u, p, svec=true)
du1 = p[1]*u[1] - p[2]*u[1]*u[2]
du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
if svec
@SVector [du1, du2]
else
@SMatrix [du1 du2 du1; du2 du1 du1]
end
end

#SVector constructor adjoint
function loss(p)
u = lotka(u0, p)
sum(1 .- u)
end

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this seems a rather not minimal and domain-specific test?

@hyrodium
Copy link
Collaborator

hyrodium commented Sep 9, 2022

It is not all that lightweight though? It takes ~0.3 s to load ChainRulesCore on my laptop versus 0.75 s for StaticArrays itself, so including it as a dependency would add ~40% to StaticArrays' load time.

How did you measure the time? In my environment, the load time doesn't change that much.

On the current master

julia> @time_imports using StaticArrays
      2.7 ms  StaticArraysCore
    659.3 ms  StaticArrays

julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab76 (2022-08-17 13:38 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × AMD Ryzen 7 2700X Eight-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
  Threads: 1 on 16 virtual cores

With this PR

julia> @time_imports using StaticArrays
      0.5 ms  Compat
     73.7 ms  ChainRulesCore
      2.1 ms  StaticArraysCore
    651.7 ms  StaticArrays

julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab76 (2022-08-17 13:38 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × AMD Ryzen 7 2700X Eight-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
  Threads: 1 on 16 virtual cores

@thchr
Copy link
Collaborator

thchr commented Sep 9, 2022

How did you measure the time? In my environment, the load time doesn't change that much.

I just did plain @time using on the individual packages (i.e., I didn't check out this PR) - but trying this again, I see that it depends on the load-order of StaticArrays and ChainRules, and also seems to depend on whether or not Revise is loaded. With the right ordering or without Revise, I see times closer to yours:

With Revise loaded
tchr@mit:~$ julia

julia> @time using ChainRulesCore
  0.283618 seconds (109.83 k allocations: 6.930 MiB, 78.45% compilation time)

julia> @time using StaticArrays
  0.643347 seconds (2.27 M allocations: 162.173 MiB, 31.34% compilation time: 100% of which was recompilation)
tchr@mit:~$ julia

julia> @time using StaticArrays
  0.679030 seconds (2.06 M allocations: 151.050 MiB, 30.93% compilation time)

julia> @time using ChainRulesCore
  0.065993 seconds (89.80 k allocations: 5.816 MiB)
Without Revise loaded
tchr@mit:~$ julia --startup-file=no

julia> @time using ChainRulesCore
  0.061022 seconds (96.67 k allocations: 8.075 MiB)

julia> @time using StaticArrays
  0.710017 seconds (2.27 M allocations: 162.195 MiB, 2.32% gc time, 31.62% compilation time: 100% of which was recompilation)
tchr@mit:~$ julia --startup-file=no

julia> @time using StaticArrays
  0.476433 seconds (2.05 M allocations: 152.639 MiB)

julia> @time using ChainRulesCore
  0.063983 seconds (87.06 k allocations: 5.672 MiB)

The loading time I originally mentioned might then be mainly due to invalidations between ChainRulesCore and Revise, I suppose?

@mcabbott
Copy link
Collaborator

mcabbott commented Sep 9, 2022

BTW, some experiments with loading time just before CRC 1.0 last year: JuliaDiff/ChainRulesCore.jl#413 . These are just @time using, could perhaps be done more carefully to count the cost of invalidations, Revise, etc.

Maybe a much lighter CRC is worth some thought. I think the downside of moving everything except function rrule end out to CR is essentially testing --- it is never useful without CR being loaded, but it is currently tested without loading CR, which makes a nice tree with no cycles.

It has 148 direct deps so we shouldn't do 2.0 lightly. I wonder how many would actually notice if it were stripped?

Comment on lines +13 to +14
function (project::ChainRulesCore.ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M}
return SArray{project.axes}(dx)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To match what projection does for Array, I think this ought to use the element projector to correct eltype.

The CRC one also checks that the size differs only in trivial ways (i.e. size-1 trailing dimensions), otherwise errors. Sometimes this is helpful for finding bugs in AD rules. I think this will accept any shape with the right length, which most bugs will still hit...


### Project SArray to SArray
function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S, T}
return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If anyone is relying on ProjectTo(x).axes actually containing axes(x), then this will be very surprising. Maybe this Tuple{size...} thing ought to have a different name:

Suggested change
return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=axes(x), static_size=S)

@@ -0,0 +1,23 @@
### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here
function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the MWE that hits this?

And might it be simpler just to do something like (p::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) = p(Tuple(dx)), correct the type & then take the same path as other tuples?

@trahflow
Copy link
Contributor

now with julia 1.9 and its extension packages (+ weak dependencies) around the corner, would the main blocker for this be cleared?

@mateuszbaran
Copy link
Collaborator

I think ChainRulesCore as a Julia 1.9-style package extension in StaticArrays would make sense.

@gdalle
Copy link

gdalle commented Aug 10, 2023

I can turn this PR into a package extension if someone can answer these two questions:

  • is the code otherwise operational?
  • what behavior do we want for Julia < 1.9? either nothing or Requires?

@ba2tripleO
Copy link
Author

ba2tripleO commented Aug 10, 2023

It was added to SciMLSensitivity
SciML/SciMLSensitivity.jl@d354ba7
The tests are passing there.

@ChrisRackauckas
Copy link
Member

Fixed in #1224

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants