Adjoint and Projections #1068
What does StaticArrays.jl/test/ambiguities.jl Line 11 in 76ce2c5
|
It's a standard Julia function: https://docs.julialang.org/en/v1/stdlib/Test/#Test.detect_ambiguities . |
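For readers unfamiliar with it, here is a minimal sketch of what `detect_ambiguities` reports. The module `AmbiguityDemo` and the function `combine` are hypothetical names made up for this illustration; they are not part of StaticArrays.jl.

```julia
using Test

# Hypothetical module, used only to demonstrate an ambiguity.
module AmbiguityDemo
# Neither method is more specific than the other for a call like combine(1, 2).
combine(x::Int, y) = "first"
combine(x, y::Int) = "second"
end

# Returns a collection of (Method, Method) pairs that are ambiguous with each other.
ambiguities = detect_ambiguities(AmbiguityDemo)

for (m1, m2) in ambiguities
    println(m1, "  <->  ", m2)
end
```

A package test file typically asserts that this collection is empty (or no larger than a known baseline) for the package's own module.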
By the way, I think it's unlikely any maintainer of StaticArrays.jl will merge a PR that adds a new external dependency. I would suggest putting it in a separate package. |
Is it certain that this won't be merged? I think the alternative would then be to send a PR to ChainRulesCore for the |
Well, I can't speak for other people with write access. Anyway there is |
I think it would be okay to add the dependency on ChainRulesCore.jl.
|
ChainRulesCore.jl really shouldn't depend on anything. I wish it were smaller. I think the policy for ChainRules.jl is only to define rules for things in Base/std.lib., in the hope that everyone else will depend on CRC, and out of fear that otherwise it'll suffer endless mission creep with sort-of important packages... like this one, or FillArrays: JuliaArrays/FillArrays.jl#153 still has nowhere to live, sadly. |
I have removed the ambiguity causing methods, which includes the |
Bump — we just ran into this ourselves. |
(Needs a rebase.) |
I really think that being differentiable is a reasonable expectation for most packages these days, and ChainRulesCore is a lightweight dependency that is very likely to be already installed. |
It is not all that lightweight though? It takes ~0.3 s to load ChainRulesCore on my laptop versus 0.75 s for StaticArrays itself, so including it as a dependency would add ~40% to StaticArrays' load time. It's true that it is very likely to be loaded by other packages though since it is e.g., in SpecialFunctions.jl (but I don't think the decision to include it there was entirely clear-cut either, cf. JuliaMath/SpecialFunctions.jl#310). |
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[targets]
-test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays"]
+test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Zygote", "ForwardDiff"]
Regardless of where these rules end up living, they should be tested with ChainRulesTestUtils and not ad-hoc via Zygote/ForwardDiff/what have you. That library does far more robust testing than most of us would think to write by hand.
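As a rough illustration of that suggestion, the sketch below defines a rule for a toy function and checks it with `test_rrule` from ChainRulesTestUtils. The function `sum_of_squares` is a hypothetical stand-in, not one of the rules in this PR, and the snippet assumes both ChainRulesCore and ChainRulesTestUtils are installed.

```julia
using ChainRulesCore
using ChainRulesTestUtils

# Hypothetical function, used only for illustration.
sum_of_squares(x::AbstractVector) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(sum_of_squares), x::AbstractVector)
    y = sum_of_squares(x)
    project_x = ProjectTo(x)
    function sum_of_squares_pullback(ȳ)
        # d/dx sum(x.^2) = 2x, scaled by the incoming cotangent.
        x̄ = project_x(2 .* x .* unthunk(ȳ))
        return NoTangent(), x̄
    end
    return y, sum_of_squares_pullback
end

# Compares the hand-written pullback against finite differences.
test_rrule(sum_of_squares, randn(4))
```

The same call pattern should apply to rules involving static arrays, with the random input replaced by an `SVector` or `SMatrix`.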
function lotka(u, p, svec=true)
    du1 = p[1]*u[1] - p[2]*u[1]*u[2]
    du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
    if svec
        @SVector [du1, du2]
    else
        @SMatrix [du1 du2 du1; du2 du1 du1]
    end
end

#SVector constructor adjoint
function loss(p)
    u = lotka(u0, p)
    sum(1 .- u)
end
Also, this seems a rather non-minimal and domain-specific test?
How did you measure the time? In my environment, the load time doesn't change that much.
On the current master:
julia> @time_imports using StaticArrays
2.7 ms StaticArraysCore
659.3 ms StaticArrays
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab76 (2022-08-17 13:38 UTC)
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 16 × AMD Ryzen 7 2700X Eight-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
Threads: 1 on 16 virtual cores
With this PR:
julia> @time_imports using StaticArrays
0.5 ms Compat
73.7 ms ChainRulesCore
2.1 ms StaticArraysCore
651.7 ms StaticArrays
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab76 (2022-08-17 13:38 UTC)
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 16 × AMD Ryzen 7 2700X Eight-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
Threads: 1 on 16 virtual cores |
I just did plain
With Revise loaded:
tchr@mit:~$ julia
julia> @time using ChainRulesCore
0.283618 seconds (109.83 k allocations: 6.930 MiB, 78.45% compilation time)
julia> @time using StaticArrays
0.643347 seconds (2.27 M allocations: 162.173 MiB, 31.34% compilation time: 100% of which was recompilation)
tchr@mit:~$ julia
julia> @time using StaticArrays
0.679030 seconds (2.06 M allocations: 151.050 MiB, 30.93% compilation time)
julia> @time using ChainRulesCore
0.065993 seconds (89.80 k allocations: 5.816 MiB)
Without Revise loaded:
tchr@mit:~$ julia --startup-file=no
julia> @time using ChainRulesCore
0.061022 seconds (96.67 k allocations: 8.075 MiB)
julia> @time using StaticArrays
0.710017 seconds (2.27 M allocations: 162.195 MiB, 2.32% gc time, 31.62% compilation time: 100% of which was recompilation)
tchr@mit:~$ julia --startup-file=no
julia> @time using StaticArrays
0.476433 seconds (2.05 M allocations: 152.639 MiB)
julia> @time using ChainRulesCore
0.063983 seconds (87.06 k allocations: 5.672 MiB)
The loading time I originally mentioned might then be mainly due to invalidations between ChainRulesCore and Revise, I suppose? |
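One way to make such measurements repeatable is to time each `using` in a fresh process that skips the startup file, so that anything loaded from `startup.jl` (Revise, for example) cannot interfere. The helper `fresh_load_time` below is a hypothetical name for this sketch. It assumes the package is available in whatever environment the child process picks up by default, and it is demonstrated with the `Test` standard library so that it runs without extra installs.

```julia
# Time `using <pkg>` in a new Julia process started with --startup-file=no.
# `fresh_load_time` is a hypothetical helper, not part of any package.
function fresh_load_time(pkg::AbstractString)
    # Each statement runs at top level in the child process.
    code = "t0 = time_ns(); using $pkg; print((time_ns() - t0) / 1e9)"
    cmd = `$(Base.julia_cmd()) --startup-file=no -e $code`
    return parse(Float64, read(cmd, String))
end

# Load time in seconds; the value depends on the machine and Julia version.
t = fresh_load_time("Test")
println("using Test took ", t, " s")
```

Calling it once per package, each in its own process, avoids the order effects visible in the sessions above, where whichever package loads second is faster.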
BTW, some experiments with loading time just before CRC 1.0 last year: JuliaDiff/ChainRulesCore.jl#413 . These are just Maybe a much lighter CRC is worth some thought. I think the downside of moving everything except It has 148 direct deps so we shouldn't do 2.0 lightly. I wonder how many would actually notice if it were stripped? |
function (project::ChainRulesCore.ProjectTo{SArray})(dx::AbstractArray{S,M}) where {S,M}
    return SArray{project.axes}(dx)
To match what projection does for Array, I think this ought to use the element projector to correct eltype.
The CRC one also checks that the size differs only in trivial ways (i.e. size-1 trailing dimensions), otherwise errors. Sometimes this is helpful for finding bugs in AD rules. I think this will accept any shape with the right length, which most bugs will still hit...
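For reference, the behaviour being described for a plain `Array` can be sketched as follows. This only exercises ChainRulesCore's existing projector for `Vector{Float64}`; it does not show the `SArray` method from this PR, and it assumes ChainRulesCore 1.x is installed.

```julia
using ChainRulesCore

x = [1.0, 2.0, 3.0]
project = ProjectTo(x)

# Element projection: a complex cotangent for a real primal is reduced to
# its real part and converted to the primal's element type.
dx_real = project([1 + 2im, 3 + 4im, 5 + 6im])

# Shape projection: a trailing size-1 dimension is accepted and dropped.
dx_vec = project(ones(3, 1))

# Same length but an incompatible shape is rejected rather than reshaped.
rejected = try
    project(ones(1, 3))
    false
catch err
    err isa DimensionMismatch
end
```

A projector that accepts any input of the right length would return a result for the last case instead of raising, which is the gap the comment points out.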

### Project SArray to SArray
function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S, T}
    return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
If anyone is relying on ProjectTo(x).axes actually containing axes(x), then this will be very surprising. Maybe this Tuple{size...} thing ought to have a different name:
-    return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=S)
+    return ChainRulesCore.ProjectTo{SArray}(; element=ChainRulesCore._eltype_projectto(T), axes=axes(x), static_size=S)
@@ -0,0 +1,23 @@
### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here
function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray)
What's the MWE that hits this?
And might it be simpler just to do something like (p::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) = p(Tuple(dx)), correct the type & then take the same path as other tuples?
Now that Julia 1.9 and its package extensions (plus weak dependencies) are around the corner, would the main blocker for this be cleared? |
I think |
I can turn this PR into a package extension if someone can answer these two questions:
|
It was added to SciMLSensitivity |
Fixed in #1224 |
Solves a few AD-related issues:
reshape during pullback
Example