Add chain rules for function calls without dims #83
Codecov Report
Additional details and impacted files:
@@ Coverage Diff @@
## master #83 +/- ##
==========================================
+ Coverage 84.13% 87.71% +3.57%
==========================================
Files 2 2
Lines 208 236 +28
==========================================
+ Hits 175 207 +32
+ Misses 33 29 -4
Could we document that downstream packages have to implement the two-argument methods but not the ones without Generally, the approach in the PR won't work anyway if a package has only implemented the one-argument version.
I'm not following how your suggested approach could mean that the extra rules here aren't needed?
One approach could be to replace this line (AbstractFFTs.jl/src/definitions.jl, line 62 in 7d698db) with
$f(x::AbstractArray) = $f(x::AbstractArray, 1:ndims(x))
Then, we wouldn't need the extra rule for no
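The forwarding idea in this comment can be sketched as follows. This is only an illustration under stated assumptions: `toyfft` and `toyifft` are hypothetical stand-ins, not the AbstractFFTs functions, and the "downstream" side is assumed to implement only the method that takes `dims`.

```julia
# Hypothetical stand-ins for transforms; NOT the AbstractFFTs implementation.
# The "downstream package" implements only the two-argument methods.
toyfft(x::AbstractArray, dims) = x .* prod(size(x, d) for d in dims)
toyifft(x::AbstractArray, dims) = x ./ prod(size(x, d) for d in dims)

# The generic one-argument methods forward to the two-argument ones,
# acting over every dimension of x.
for f in (:toyfft, :toyifft)
    @eval $f(x::AbstractArray) = $f(x, 1:ndims(x))
end

x = ones(2, 3)
toyfft(x) == toyfft(x, 1:2)   # true: both calls reach the same method
```

With this layout, a rule written once for the two-argument method would also cover the one-argument call, because that call never reaches a separate implementation.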
I actually went in the opposite direction and generalized the chain rules to directly work with and without a
It makes the rules here a bit more complex, but now no assumptions whatsoever are made on what signatures downstream implementations support, so this is arguably the most robust solution.
IMO this PR is still suboptimal and a better design would be desirable. With the latest changes, the signatures of the rules now differ from the signatures of fft etc.
I think the cleanest solution is to only work with versions of fft etc. that implement dims, and to forward fft(x) etc. to the two-argument version. Otherwise we have to copy all rules and just remove dims everywhere. I think we should avoid such code duplication.
# we explicitly handle both unprovided and provided dims arguments in all rules, which
# results in some additional complexity here but means no assumptions are made on what
# signatures downstream implementations support.
function ChainRulesCore.frule(Δargs, ::typeof(fft), x::AbstractArray, dims=nothing)
I'm not happy about this PR because it means the signature of the AD rules is different from the signatures of fft etc.: we do not support dims = nothing in any of these methods.
A default positional argument simply expands to separate dispatches on the signatures fft(x, dims) and fft(x). The dims=nothing is just a way of sharing logic in these cases.
So I would not say the signatures are different?
My point is: You can call frule(.., fft, x, nothing) but you cannot call fft(x, nothing). This breaks the correspondence between the primal function and the rules, and makes the signatures inconsistent.
There is no clean way to share code as long as fft(x) and fft(x, dims) are completely separate. Introducing fft(x) = fft(x, 1:ndims(x)) or fft(x) = fft(x, nothing), and demanding that downstream packages implement fft(x, dims) only would solve these issues. Otherwise you have to copy the code or use something like @eval to do it for you.
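The mismatch raised in this comment can be reproduced in a toy setting. The names `toyprimal` and `toyprimal_frule` are hypothetical, the rule is written by hand without ChainRulesCore, and the primal is assumed to be a simple linear map so that its tangent is the same map applied to the input tangent.

```julia
# Hypothetical primal with two separate methods; neither accepts `nothing`.
const ToyDims = Union{Integer,Tuple,AbstractRange}
toyprimal(x::AbstractArray) = 2 .* x
toyprimal(x::AbstractArray, dims::ToyDims) = 2 .* x

# Hypothetical hand-written forward rule using the `dims=nothing` default.
function toyprimal_frule(Δx, x::AbstractArray, dims=nothing)
    y = dims === nothing ? toyprimal(x) : toyprimal(x, dims)
    Δy = dims === nothing ? toyprimal(Δx) : toyprimal(Δx, dims)
    return y, Δy
end

x = [1.0, 2.0]
applicable(toyprimal_frule, x, x, nothing)   # true: the rule accepts `nothing`
applicable(toyprimal, x, nothing)            # false: the primal does not
```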
That's a fair point, I hadn't considered the nothing case. Sharing code would be easy enough with a shared helper function, e.g. replacing my current function with something like _fft_rrule and calling it in both cases, so that all the dispatches are correct. If you're opposed to that, I can look into how to modify src/definitions.jl to support your solution.
See my response to your comment -- I don't really agree that the signatures are different, and even explicitly writing out separate rules for Also, I see it as an inherent benefit to avoid modifying
So it's possible to avoid code copying and get the dispatches right if one makes a helper function e.g.
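A rough sketch of the helper-function approach, not the code in this PR: `toyop`, `toyop_frule`, and `_toyop_frule` are hypothetical names, and the rule is written by hand rather than through ChainRulesCore. The two public rule methods mirror the two primal signatures exactly, while the shared logic lives in one helper that splats whatever trailing arguments it was given.

```julia
# Hypothetical primal with the same two signatures as in the discussion.
const ToyOpDims = Union{Integer,Tuple,AbstractRange}
toyop(x::AbstractArray) = 2 .* x
toyop(x::AbstractArray, dims::ToyOpDims) = 2 .* x

# Shared helper: `args` is either () or (dims,), and is passed through unchanged.
# toyop is linear, so the tangent is toyop applied to the input tangent.
_toyop_frule(Δx, x, args...) = (toyop(x, args...), toyop(Δx, args...))

# Public rule methods that match the primal signatures one to one.
toyop_frule(Δx, x::AbstractArray) = _toyop_frule(Δx, x)
toyop_frule(Δx, x::AbstractArray, dims::ToyOpDims) = _toyop_frule(Δx, x, dims)

x = [1.0, 2.0]
applicable(toyop_frule, x, x, nothing)   # false: no rule method accepts `nothing`
```

Compared with a `dims=nothing` default, this keeps the rule uncallable wherever the primal is uncallable, at the cost of one extra method definition per rule.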
Addresses issue with existing chain rules observed in FluxML/Zygote.jl#1386