ReverseDiff cannot differentiate maxpool
#484
We could look into relaxing that condition, but it won't help much because ReverseDiff will still be very slow. The correct solution is adding a rule for pooling on the ReverseDiff side. Also, I'm not seeing what Lux and ComponentArrays have to do with any of this? An MWE with just NNlib and ReverseDiff would be better.
You are right, I messed up. Meant to open this issue in the Lux repo 🙈 Should I close the issue here, and re-open at Lux or ReverseDiff?
Again, my recommendation would be to remove Lux and ComponentArrays from your example. In other words, make the entrypoint the first call which only uses ReverseDiff and NNlib types/functions. You can find this from the stacktrace quite easily. It'll also help you with debugging why |
Lux.MaxPool
and ComponentArraysmaxpool
Alright. Was not as hard as expected 😅
import NNlib
import ReverseDiff
x = rand(1,1,1,1)
sz = (1,1)
pdims = NNlib.PoolDims(x, sz)
NNlib.maxpool(x, pdims) # isa Array{Float64, 4}
ReverseDiff.gradient(x) do _x
    only(NNlib.maxpool(_x, pdims))
end # throws
The stacktrace now is
pointing to the same issue as before: strict method signatures for
I still cannot get
import ReverseDiff: maxpool, PoolDims, @grad_from_chainrules
@grad_from_chainrules maxpool(x::TrackedArray, pdims::PoolDims; kwargs...)
gives
You're importing |
Oops, yes, that's just a typo, should have been
import NNlib: maxpool, PoolDims
import ReverseDiff: @grad_from_chainrules
@grad_from_chainrules maxpool(x::TrackedArray, pdims::PoolDims; kwargs...)
That's where I get
If that's the case, then |
That's what I tried first, but then I get
Regarding the method signatures: I ran a very rudimentary test in which I broadened the method signatures for some pooling operations. In most places this should not pose problems.
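To illustrate why a strict method signature can block a wrapper array type, here is a minimal sketch in plain Julia. It does not use NNlib or ReverseDiff; `Wrapped`, `strict_sum` and `broad_sum` are hypothetical names invented for this example, with `Wrapped` standing in for a tracked array type. It only shows the dispatch behaviour and is not a description of how the pooling methods are actually written.

```julia
# Hypothetical wrapper type standing in for a tracked array.
struct Wrapped{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end

# Minimal AbstractArray interface so generic code such as `sum` works.
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int...) = w.data[i...]

# Strict signature: only accepts a concrete 4-dimensional `Array`.
strict_sum(x::Array{T,4}) where {T} = sum(x)

# Broadened signature: accepts any 4-dimensional `AbstractArray`.
broad_sum(x::AbstractArray{T,4}) where {T} = sum(x)

w = Wrapped(rand(1, 1, 1, 1))

# No method of `strict_sum` matches the wrapper, so calling it would
# raise a MethodError; `applicable` checks this without throwing.
println(applicable(strict_sum, w))

# The broadened method accepts the wrapper.
println(applicable(broad_sum, w))
println(broad_sum(w) == sum(w.data))
```

Broadening a signature in this way only makes dispatch succeed. It says nothing about performance, which is the separate concern raised in this thread about needing a custom rule.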
Widening the signature should be fine and I'd be happy to review a PR, but presumably you want acceptable performance and you won't get that without a custom rule. One thing would be to file an issue on the ReverseDiff side about the macro expansion problem. In the meantime, maybe try using |
Thanks! There is an initial draft to fix this now. I am in no hurry to get this working with ReverseDiff. My initial motivation was curiosity mainly: I often ran into situations where I had to restart my REPL or messed up in some other way, and Zygote would take ages for initial gradients. I wanted to see if ReverseDiff could do the initial gradient faster to better suit my messy interactive development at the time. That's also why overall performance would not have been an issue, if just the initial gradients were faster. |
I'd be curious to know how the performance stacks up :). Will have a look at #485 soon. |
* broaden pooling method signatures
* pooling: revert type predition for target array
* add ReverseDiff as test dependency
* Update test/pooling.jl reformat whitespace in tests Co-authored-by: Brian Chen <[email protected]>
* Update test/pooling.jl remove unused type parameters from function signatures Co-authored-by: Brian Chen <[email protected]>
* Update src/impl/pooling_direct.jl remove unused type parameters from function signatures Co-authored-by: Brian Chen <[email protected]>
* remove unused type parameters in pooling methods
--------- Co-authored-by: Brian Chen <[email protected]>
Edit
I originally thought this to be an issue with Lux. That's not the case.
Please see this message for a brief example where ReverseDiff fails to take the gradient of a pooling operation.
Original Description
ReverseDiff cannot take derivatives of simple models involving Pooling Layers with respect to model parameters.
In the following example, two problems occur:
maxpool!
(see stacktrace below). I believe it is due to the strict type restrictions in the method signatures and the way ReverseDiff handles similar
Example for Julia 1.8.5,
Here's the script:
And the trace for