how to selectively take structural gradient #1042
Comments
What do you mean by gradient masking here? Also, this is needed because, while the gradient for an argument may not be explicitly asked for, it might still be required to compute the gradient of a different argument. Forcing that part of the computation to be skipped could silently give wrong gradients.
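To make that point concrete, here is a tiny illustration of my own (not from the thread): even when only the gradient with respect to `a` is requested, the pullback still has to propagate a cotangent through the intermediate `b`, because `b` is computed from `a`.

```julia
using Zygote

function f(a)
    b = 2 .* a          # intermediate built from a; think of it as a "non-trainable" value
    return sum(b .+ a)  # the loss touches both a and b
end

# Only d(f)/d(a) is requested, but computing it requires the sensitivity
# flowing back through b as well.
gradient(f, ones(3))    # ([3.0, 3.0, 3.0],)
```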
Preventing updates during optimization could be accomplished with a helper like https://optax.readthedocs.io/en/latest/api.html?highlight=mask#optax.masked. That said, it doesn't account for the scenario where you want to save memory by not holding onto gradients for certain parameters that won't be updated.
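For the update-masking half, a rough sketch of what an optax.masked-style helper could look like over Zygote's structural gradients (the `maskgrad` helper and the field names are hypothetical, not part of any package):

```julia
# Hypothetical helper: walk a NamedTuple gradient and a matching mask,
# replacing masked-out leaves with `nothing` so they are never applied.
maskgrad(grad::NamedTuple, mask::NamedTuple) =
    NamedTuple{keys(grad)}(map(maskgrad, values(grad), values(mask)))
maskgrad(grad, keep::Bool) = keep ? grad : nothing   # leaves: drop masked-out parts

grad = (a1 = ones(3), eps = 1.0,   b = (b1 = ones(2), b2 = ones(2)))
mask = (a1 = true,    eps = false, b = (b1 = true,    b2 = false))

maskgrad(grad, mask)
# (a1 = [1.0, 1.0, 1.0], eps = nothing, b = (b1 = [1.0, 1.0], b2 = nothing))
```

As the comment above notes, this only prevents updates: the full gradient has already been computed and materialized by the time the mask is applied.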
`Functors.trainable` and `ChainRulesCore.ProjectTo` have quite a bit in common; it's possible they should get to know each other better. I'm not precisely sure why this doesn't work today, but with #1044 it might go something like:

```julia
julia> import Zygote.ChainRulesCore: ProjectTo

julia> ProjectTo(a::A) = dA::NamedTuple -> ((; a1 = dA.a1, eps = nothing, b = nothing),);

julia> gradient(loss, a)
((a1 = Fill(1.0, 3), eps = 1.0, b = (b1 = Fill(1.0, 2), b2 = nothing)),)  # what I currently get
((a1 = Fill(1.0, 3), eps = nothing, b = nothing),)                        # is what I hoped for
```
The wrench in the works is that Functors doesn't have a …
Is it because …
Oh right, thanks both. I guess there are many details of this union I don't see yet. But Functors/Optimisers are interested in AD with nested structs, and there might be a nice ChainRules-level way to encode things.
There are two pieces here: (1) not updating non-trainable parameters, and (2) not computing gradients for non-trainable parameters.

For (1), Optimisers.jl uses Functors.jl to walk over the structure and the nested gradient tuple to apply updates. Thanks to FluxML/Functors.jl#14, we can now limit that walk to the parameters defined by `trainable` (see the sketch after this comment).

For (2), if the non-trainable values are themselves computed from things we do differentiate, their gradients cannot simply be dropped. For example:

```julia
function make_model(some_hyperparams...)
    # do stuff with the hyper-parameters to make W and b
    return Dense(W, b)  # let's suppose only W is trainable
end

gradient(ps -> loss(make_model(ps...)(x)), ps)
```

Here it wouldn't make sense for the pullback of the `Dense(W, b)` construction to drop the gradient for `b`, since that gradient is needed to compute the gradient with respect to `ps`. Also, if this is somewhere in the middle of the computation, then I would hope the memory gets re-used once that unnecessary gradient is no longer referenced by the following pullbacks. I think this is really a concern only for the inputs to the full computation.
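As a sketch of piece (1), assuming Optimisers.jl's `trainable` overload mechanism is the one in play (the `Scale` struct and its fields are my own illustration): restricting `trainable` means neither the optimizer state nor the update walk ever touches the frozen field.

```julia
using Functors, Optimisers

struct Scale
    w       # trainable weights
    eps     # frozen scalar we never want to update
end
Functors.@functor Scale
Optimisers.trainable(s::Scale) = (; w = s.w)    # only `w` takes part in updates

s  = Scale(ones(3), 0.5)
st = Optimisers.setup(Descent(0.1), s)          # no optimizer state is built for `eps`

grad  = (; w = ones(3), eps = nothing)          # the shape a structural gradient would have
st, s = Optimisers.update(st, s, grad)          # `s.eps` is left untouched
```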
This is the blessing and curse of Zygote supporting differentiation of arbitrary structs. AFAIK, there is no way to provide it additional information about which fields should (or should not) be differentiated.
Right, PyTorch autograd's `requires_grad` is the kind of per-tensor flag that carries exactly that information.
Yup. Now if we had a function like …
Another (possibly complementary) approach, more in line with …
Maybe #966 could bring some improvements in that regard.
For my own edification, do thunks help with deeply nested struct or tangent fields? I can wrap my head around how an entire argument might be excluded from evaluation, but not a piece of one.
I have to admit I'm not entirely sure myself, specifically whether it will be possible to make the pullback(s) for struct creation smart enough. Maybe we will need some kind of hinting procedure at some point, so the user can specify which quantities they want the gradient for, like Enzyme has.
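As a rough sketch of how a thunked field could defer work (my own ChainRulesCore example, not something Zygote currently produces): a structural `Tangent` can carry a `@thunk` for one field, and the expensive part only runs when something actually forces it.

```julia
using ChainRulesCore

struct MyLayer
    a1
    b
end

# Stand-in for an expensive gradient computation we may never need.
expensive_grad(x) = (println("computing gradient for b"); 2 .* x)

# A structural tangent whose `b` component is thunked: building `dA` is cheap.
dA = Tangent{MyLayer}(a1 = ones(3), b = @thunk expensive_grad(ones(1000)))

dA.a1               # the cheap field is available immediately
db = unthunk(dA.b)  # the expensive part runs only when it is actually needed
```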
In Flux, we typically apply `@functor` to a type for two purposes:

1. mapping over its fields with functions like `gpu`;
2. collecting its parameters into `Zygote.Params` for gradient calculation (this is done by `Flux.params(model)`).

When we want to distinguish the two behaviors, we use `Flux.trainable` for the parameters collection. This is an …
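For concreteness, a minimal setup consistent with the gradient shapes quoted in the comments above (struct names, field names, and the loss are assumptions on my part, not necessarily the original example) could be:

```julia
using Flux, Zygote

struct B
    b1
    b2
end

struct A
    a1
    eps
    b
end

Flux.@functor B
Flux.@functor A

# Only a.a1 is meant to be trainable (tuple-returning convention of the time).
Flux.trainable(a::A) = (a.a1,)

a = A(ones(3), 0.5, B(ones(2), ones(2)))

# A loss touching a1, eps and b.b1.
loss(a) = sum(a.a1) + a.eps + sum(a.b.b1)
```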
Now when one computes the gradient in the implicit form, supposedly only the gradient with respect to `a.a1` should be computed. This appears not to be exactly true at the moment: every gradient seems to be computed, but at least only the one with respect to `a.a1` is exposed. With the explicit gradient instead, everything is computed and exposed.
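Continuing the hypothetical setup above, the two call forms being contrasted are:

```julia
# Implicit form: Params contains only the trainable a.a1, so only its
# gradient is exposed (even if more was computed internally).
ps = Flux.params(a)
gs = gradient(() -> loss(a), ps)
gs[a.a1]                         # ≈ fill(1.0, 3)

# Explicit form: the whole structural gradient is returned.
gradient(loss, a)
# ((a1 = Fill(1.0, 3), eps = 1.0, b = (b1 = Fill(1.0, 2), b2 = nothing)),)
```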
This is bad, since we would like to feed this to an `update!` function, and it is also inefficient. How do we tell Zygote to drop some model parts from the gradient computation? I would like the following: …

I see two possibilities:

- make the explicit `gradient` `@functor`/`trainable` aware;
- …