-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use ArrayInterface.restructure in update! #1613
Conversation
It would be interesting to test this on some odd case like where Params contains a ComponentArray and the user just assumes that the gradient will be able to do |
Better handling is present in optimisers.jl |
function update!(opt, x, x̄) | ||
x .-= apply!(opt, x, x̄) | ||
x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there solutions without ArrayInterface? Since this is (I think) intended as a quick fix pending a better design.
Perhaps it could just test ismutable
, and then use broadcasting to make a similar array (since x
is already assumed mutable)?
function update!(opt, x, x̄)
# quick fix for #1510, Zygote's output may not be mutable:
x̄r = ismutable(x̄) ? x̄ : x̄ .+ false .* x
x .-= apply!(opt, x, x̄r)
end
Or perhaps better to test x̄ isa DenseArray
or something, maybe there are mutable structs which aren't writeable...
function update!(opt, x, x̄::DenseArray)
x .-= apply!(opt, x, x̄r) # as before
end
function update!(opt, x, x̄)
x .-= apply!(opt, x, x̄ .+ false .* x) # fix for #1510, Zygote's output may not be mutable
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ismutable
is wrong: it doesn't tell you if an array is mutable. It tells you if a type is a mutable type, which is different.
julia> ismutable(Adjoint(rand(4)))
false
Base Julia does not have the right verbiage to answer these kinds of questions in an adequately generic way, which is the reason for ArrayInterface.jl becoming a standard trait interface to cover these kinds of issues.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ArrayInterface.ismutable(x)
does signify "array is mutable"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not worried about carrying on the ArrayInterface dependence, but I do want to get the semantics right.
Is ArrayInterface.ismutable(x)
more appropriate than ArrayInterface.destructure
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes to Adjoint, but that would fail to the safe side, of making one too many copies.
ArrayInterface alone is about 0.5 seconds of loading time, Flux about 8.5 s without.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok! I guess I see this PR as a band-aid for the current structure of things.
I'd forgotten about that weird issue caused by Zeros. However, honest parameters expressed as x::SArray
might not be crazy, and a design that allowed them might be nice. Perhaps a smaller step would be a design that didn't need to mutate dx
to do x .+= ε .* dx
. If the former happened before the latter then some variants of this PR would need re-visiting... which seems OK.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's a tough fix, essentially https://github.com/FluxML/Flux.jl/blob/v0.12.4/src/optimise/train.jl#L29 needs to be rewritten to something like:
xs[x] = update!(opt, x, gs[x])
Where update always returns the transformed parameter and can optionally mutate it to save an allocation. This would require implementing setindex
on Params
, but it's the only solution I can think of that preserves both the API and semantics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you replace what object is stored in the IdDict
, this won't be picked up by the model. I agree that you want some mutate-where-possible-but-always-return scheme, but maybe it has to take & return the whole model? IIRC this was being discussed at some point, Functors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, the path we're walking down where dx
is not mutated and neither is x
when it is immutable is Optimisers.jl (based on Functors.jl to facilitate taking & returning the whole model). I would say it is not something we needed to address in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty much, you'd still need to reconstruct afterwards. It's the inherent limitation of implicit params, which is why Optimisers.jl focuses on explicit params and being able to traverse through structures (à la JAX's pytrees) instead. I would love it if we could nuke implicit params overnight, but there are still a few UX things to figure out with explicit params before then.
@ChrisRackauckas is the test I added what you meant here? Although probably that was already working, the returned gradient from Zygote is a ComponentArray |
Yup that's the test. It was not working because if you did that on the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this has come up before and Optimisers.jl needs some more time to bake, shall we get this merged?
@darsnack or @DhairyaLGandhi can I get an approval here? |
we don't have a timeline for the switch to Optimisers.jl yet, it will probably take months, so considerations about Optimisers.jl shouldn't stop us from fixing issues here and now EDIT: Instead of replying to Dhairya comment I inadvertently edited and replied inside his own comment, sorry for that (Carlo Lucibello) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree let's merge and not wait around for Optimisers.jl.
bors r+ |
1613: use ArrayInterface.restructure in update! r=CarloLucibello a=CarloLucibello Suggestion coming from @ChrisRackauckas in FluxML/Zygote.jl#989. Now `update!` handles basically any gradient Zygote emits, e.g. FillArrays and Zygote.OneElement. Fix #1510 Co-authored-by: CarloLucibello <[email protected]>
bors r- |
Canceled. |
bors r+ |
Seeing as we have come around to switching out the zeros and have momentum now, I would say we should focus on moving ahead with Optimisers.jl sooner rather than later. |
Build succeeded: |
Suggestion coming from @ChrisRackauckas in FluxML/Zygote.jl#989.
Now
update!
handles basically any gradient Zygote emits, e.g. FillArrays and Zygote.OneElement.Fix #1510