-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chain rules for FFT plans via AdjointPlans #67
Conversation
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## master JuliaLang/julia#67 +/- ##
==========================================
+ Coverage 87.32% 91.45% +4.13%
==========================================
Files 3 3
Lines 213 281 +68
==========================================
+ Hits 186 257 +71
+ Misses 27 24 -3
☔ View full report in Codecov by Sentry. |
c5c3755
to
6c81dfd
Compare
534ddd4
to
af74c54
Compare
40cce00
to
7cba04e
Compare
This should be ready for another review (with #69 as a dependency) |
7149781
to
675c61a
Compare
test/runtests.jl
Outdated
for f in (fft, ifft, bfft) | ||
test_frule(f, x, dims) | ||
test_rrule(f, x, dims) | ||
test_frule(f, complex_x, dims) | ||
test_rrule(f, complex_x, dims) | ||
end | ||
for pf in (plan_fft, plan_ifft, plan_bfft) | ||
test_frule(*, pf(x, dims) ⊢ NoTangent(), x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are the NoTangent
needed here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It tells ChainRulesTestUtils
to use a no tangent for the plan, i.e. do not try to differentiate w.r.t. the plan. See https://juliadiff.org/ChainRulesTestUtils.jl/dev/#Specifying-Tangents. It's necessary to manually specify this because with all the caching stuff and plans being mutable, ChainRules gets confused about the structure of Plan
and rand_tangent
errors
Unfortunately, there is one case where we do want to differentiate w.r.t. plan (as far as I can tell this is the only case), when someone makes a ScaledPlan
whose scale depends on the parameter:
using AbstractFFTs
using AbstractFFTs: Plan
using Zygote
include("test/testplans.jl")
julia> function f(x)
return sum(abs.(P * x))
end
f (generic function with 1 method)
# correct
julia> Zygote.gradient(f, [1,2,3])
([-0.732050807568877, 0.9999999999999996, 2.7320508075688776],)
julia> function f(x)
return sum(abs.(x[1] * P * x))
end
f (generic function with 1 method)
# silently wrong :(
julia> Zygote.gradient(f, [1,2,3])
([-0.732050807568877, 0.9999999999999996, 2.7320508075688776],)
I just spent some time trying to write an adjoint for ScaledPlan
by replacing it with the right-associative P.scale * (P.p * x)
and using rrule_via_ad
, but the fundamental issue was that ChainRules
was unable to come up with a tangent type for ScaledPlan
, for the same reason (mutability, circular references, etc.)
A few thoughts:
- Really, even if we keep the
pinv
field around,ScaledPlan
ought to not be mutable as the caching can just happen at the level of the inner plan. If this were fixed, it would be easy to write anrrule
, but this would be a separate PR that would probably require a lot of careful thought - Maybe we can somehow get the
ScaledPlan
differentiation to work, by adding a custom tangent type for the mutable struct. Don't have much ChainRules knowhow but I can give this a shot - I really really don't like that this silently gives incorrect results (although the current plan rules in Zygote do too for real FFTs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe making ScaledPlan
immutable isn't so hard... give me a minute:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation. I assumed it's due to problems with CRTU but my main worry was exactly something like the ScaledPlan
case: that it masks differentiation issues.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To try to resolve this, I've:
a) Made ScaledPlan
s immutable in #72 (and done the same for AdjointPlan
s here)
b) Added a backwards rule for `ScaledPlan that differentiates w.r.t. the plan's scale too
c) Added tests. Because of JuliaDiff/ChainRulesTestUtils.jl#256, this turned out to be rather difficult. I couldn't come up with a way of using ChainRulesTestUtils
for properly checking the derivative w.r.t. ScaledPlan
without a PR there, and I have to pick my battles, so I ended up coming up with the best cludge I can think to preserve most of the automated testing w.r.t. the other tangents and adding a manual FD test for the plan scale. (In an ideal world I'd have been able to just get rid of the NoTangent
: once ChainRulesTestUtils
is able to accommodate this case, we should do that here.)
What do you think of the approach?
I guess this is still waiting for JuliaLang/julia#78? I think it would be good to also make ChainRulesCore a weak dependency on new Julia versions (see, e.g., how it is done in SpecialFunctions). |
@gaurav-arya you might want to rebase after 3a3f0e4. |
Question: what would the projection style be for real-to-real DCTs/ DSTs in FFTW.jl? |
I updated the PR (and fixed a few issues with types and functions that were not available in the extension). Since in-place plans are not supported (or at least not tested?) currently, maybe a final thing to add would be to check in the ChainRules definitions that |
I don't have the time right now to revisit this PR, but if it looks good and would be helpful, please feel free to fix anything remaining and merge. Thanks! |
Sure @devmotion, will take a quick look tonight. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of my comments are documentation suggestions, but there are also a few other minor ones.
if Base.mightalias(y, x) | ||
throw(ArgumentError("differentiation rules are not supported for in-place plans")) | ||
end | ||
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it inconsistent at all that here we use the tangent of the scale part of P
but none of the tangent of the wrapped plan?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm it seems plans are assumed to be constant (AFAICT from the initial version of the PR) but the scaling might change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess there's probably never a good reason a user would want (co)tangents for a Plan
. In almost every case the scale of a Plan
is just a constant that again a user would never want a (co)tangent for, but perhaps there is one user out there who does, so I can see the point in this.
scale = P.scale | ||
project_x = ChainRulesCore.ProjectTo(x) | ||
project_scale = ChainRulesCore.ProjectTo(scale) | ||
function mul_scaledplan_pullback(ȳ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice if the mul!(y, p, x, a, b)
API was supported by AbstractFFTs, because then ChainRules could also define an inplaceable thunk here, and Enzyme rules could avoid an allocation, but maybe outside the scope of this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that would require the FFT plan to support fused mul!
which isn't guaranteed. To create a fallback implementation, the plan must cache y
.
cache = get_cache(plan)
copy!(cache, y)
mul!(y, plan, x)
axpby!(b, cache, a, y)
Feels out of scope for this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that mul!
has to guarantee allocation-free or fused computations (but maybe I'm wrong). Usually, !
only indicates that some (usually but not necessarily the first) argument is updated in-place but sometimes other arguments are updated as well and/or the update is not allocation-free.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is my understanding that LinearAlgebra.mul!
is allocation-free. That is what gives it performance advantage over Base.*
. To my knowledge, no mutating LinearAlgebra
routine allocates a copy of the base array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is my understanding that
LinearAlgebra.mul!
is allocation-free.
I quickly checked the Julia repo, and there are a few open issues that show that at least in practice such a guarantee does not exist: https://github.com/JuliaLang/julia/issues/49332 JuliaLang/julia#46865 Arguably these are just bugs but on the other hand the docstring of mul!
also does not make any such guarantees.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For both cases, the allocation size is independent of the array size indicating that the arrays are not being allocated. Looks like a spurious size tuple allocation to me.
Examples:
julia> versioninfo()
Julia Version 1.9.1
Commit 147bdf428cd (2023-06-07 08:27 UTC)
Platform Info:
OS: macOS (arm64-apple-darwin22.4.0)
CPU: 8 × Apple M2
https://github.com/JuliaLang/julia/issues/49332
julia> using LinearAlgebra, BenchmarkTools
julia> A = rand(ComplexF64,4,4,1000,1000);
julia> B = similar(A);
julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));
julia> @btime mul!($b,$a,$a); # 4x4 * 4x4
311.283 ns (10 allocations: 608 bytes)
julia> A = rand(ComplexF64,128,128,10,10);
julia> B = similar(A);
julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));
julia> @btime mul!($b,$a,$a); # 128x128 * 128x128
170.542 μs (10 allocations: 608 bytes)
julia> N = 5_000;
julia> A = rand(N, N); B = rand(N, N); C = rand(N, N);
julia> @time mul!(C, A, B, true, true);
1.729141 seconds (1 allocation: 16 bytes)
julia> @time mul!(C, A, B);
1.637079 seconds
julia> @time A * B; # allocates N x N array
1.421422 seconds (2 allocations: 190.735 MiB, 0.13% gc time)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was my understanding from skimming through the issues - and why I wrote arguably these could be considered to be bugs. My main point: There are no guarantees in Julia regarding allocation, the language or the JIT-compiler does not enforce any contracts, so it's only possible to document interfaces and trust people to implement them accordingly. But in the case of mul!
no such guarantees are documented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do think it makes sense for AbstractFTTs to ultimately support downstream packages implementing either 3-arg or 5-arg mul!
, with each defaulting to the other (yes stackoverflow, but if implementing one of them is required, then no overflow can exist). But I do also think this needn't happen in this PR.
src/definitions.jl
Outdated
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], | ||
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) | ||
) | ||
return convert(typeof(x), scale) ./ N .* (p.p \ x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could a parentheses be added here to make this easier to understand?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where would you like to add them?
Now that adjoints are defined at the |
Co-authored-by: Seth Axen <[email protected]>
I'd argue no, they should be kept. Zygote also defined rules for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
The potentially incorrect Zygote rules for FFT (FluxML#899) can be removed now that comprehensive Chain Rules have been added in JuliaMath/AbstractFFTs.jl#67
An
rfft
can be written asPF
whereF
is then x n
Fourier transform andP
is a projection operator that removes the redundant information due to conjuagate symmetry. Because ofP
, the adjoint of real FFTs (real inverse FFTs) require a special scaling before (after) applying the backwards transformation. As discussed in #63 this motivates supporting theBase.adjoint
operation for plans to simplify the writing of backward rules for AD.The following functions must be implemented by backends in order for
output_size(p::Plan)
andAdjointPlan
to work:projection_style(p::Plan)
which can either be:none
,:real
, or:real_inv
.irfft_dim(p::Plan)
, only for those plans with:real_inv
projection style, which gives the original length of the halved dimension.Using the adjoint plan, we can simplify the writing of backwards rules. I test the adjoint plans both directly and indirectly through tests of the rrule's.
NB: The interface has changed since the initial PR message. See the updated implementation docs in the PR for accurate info.