use ArrayInterface.restructure in update! #1613

Merged
merged 6 commits into from
Jun 17, 2021

CarloLucibello
Member

@CarloLucibello CarloLucibello commented Jun 10, 2021

Suggestion coming from @ChrisRackauckas in FluxML/Zygote.jl#989.
Now update! handles basically any gradient Zygote emits, e.g. FillArrays and Zygote.OneElement.

Fix #1510
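
A minimal, Base-only sketch of the failure mode behind #1510 and of the idea behind the fix. `ConstVec` and `materialize_like` are hypothetical names introduced for illustration: the first stands in for a lazy, read-only gradient such as the one FillArrays provides, and the second only approximates what a restructuring step might do. It is not the implementation of `ArrayInterface.restructure`.

```julia
# Hypothetical stand-in for a lazy, read-only gradient:
# it defines size and getindex, but no setindex!.
struct ConstVec <: AbstractVector{Float64}
    val::Float64
    len::Int
end
Base.size(v::ConstVec) = (v.len,)
Base.getindex(v::ConstVec, i::Int) = (checkbounds(v, i); v.val)

# Hypothetical helper: copy dx into a freshly allocated array shaped like x.
function materialize_like(x::AbstractArray, dx::AbstractArray)
    out = similar(x, eltype(dx))
    copyto!(out, dx)
    return out
end

x  = [1.0, 2.0, 3.0]
dx = ConstVec(1.0, 3)      # e.g. what the gradient of sum(x) could look like

# An optimiser rule that rescales the gradient in place fails on dx ...
failed = try
    dx .*= 0.1
    false
catch
    true
end

# ... but succeeds on a materialized, writable copy.
dxr = materialize_like(x, dx)
dxr .*= 0.1
x .-= dxr
```

The copy costs one allocation per step, which seems consistent with the "quick fix" framing used later in this thread.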

@ChrisRackauckas
Member

It would be interesting to test this on some odd case like where Params contains a ComponentArray and the user just assumes that the gradient will be able to do grad.a == grad[1].

@DhairyaLGandhi
Member

Better handling is present in Optimisers.jl.

Comment on lines 23 to +24
 function update!(opt, x, x̄)
-  x .-= apply!(opt, x, x̄)
+  x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
Member

Are there solutions without ArrayInterface? Since this is (I think) intended as a quick fix pending a better design.

Perhaps it could just test ismutable, and then use broadcasting to make a similar array (since x is already assumed mutable)?

function update!(opt, x, x̄)
  # quick fix for #1510, Zygote's output may not be mutable:
  x̄r = ismutable(x̄) ? x̄ : x̄ .+ false .* x
  x .-= apply!(opt, x, x̄r)
end

Or perhaps better to test x̄ isa DenseArray or something, maybe there are mutable structs which aren't writeable...

function update!(opt, x, x̄::DenseArray)
  x .-= apply!(opt, x, x̄)  # as before
end
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄ .+ false .* x)  # fix for #1510, Zygote's output may not be mutable
end
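
A small sketch of what the `x̄ .+ false .* x` trick does, using only Base. The range `dx` is an assumed stand-in for an immutable gradient; the names `dx` and `dxr` are illustrative.

```julia
x  = [1.0, NaN, 3.0]     # mutable parameter; the NaN shows why `false` is used
dx = 1.0:3.0             # read-only stand-in for an immutable gradient

# Broadcasting against the mutable x allocates an ordinary Array.
# `false` acts as a strong zero, so false * NaN is 0.0 and x never leaks into dxr.
dxr = dx .+ false .* x
materialized = copy(dxr)

# The result can now be scaled in place, as an optimiser rule might do.
dxr .*= 0.1
```

Multiplying by `0` instead of `false` would propagate the `NaN`, which is presumably why the suggestion uses `false`.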

Member

ismutable is wrong: it doesn't tell you if an array is mutable. It tells you if a type is a mutable type, which is different.

julia> ismutable(Adjoint(rand(4)))
false

Base Julia does not have the right verbiage to answer these kinds of questions in an adequately generic way, which is the reason for ArrayInterface.jl becoming a standard trait interface to cover these kinds of issues.
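
The distinction can be checked with Base and LinearAlgebra alone. In this sketch the variable names are illustrative; it only shows that `Base.ismutable` and "can be written to" are different questions.

```julia
using LinearAlgebra

v = rand(4)
a = Adjoint(v)        # lazy wrapper: an immutable struct holding a reference to v

# Base.ismutable asks about the wrapper value itself, not about writability.
wrapper_flag = ismutable(a)      # false
array_flag   = ismutable(v)      # true

# The wrapper can still be written to; the write goes through to v.
a[2] = 0.5
```

A check based on `Base.ismutable` would therefore copy an `Adjoint` gradient even though it is writable.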

Member

ArrayInterface.ismutable(x) does signify "array is mutable"

Member Author

I'm not worried about carrying the ArrayInterface dependency, but I do want to get the semantics right.
Is ArrayInterface.ismutable(x) more appropriate than ArrayInterface.restructure?

Member

@mcabbott mcabbott Jun 10, 2021

Yes about Adjoint, but that would fail on the safe side, by making one copy too many.

ArrayInterface alone adds about 0.5 s of loading time; Flux takes about 8.5 s without it.

Member

Ok! I guess I see this PR as a band-aid for the current structure of things.

I'd forgotten about that weird issue caused by Zeros. However, honest parameters expressed as x::SArray might not be crazy, and a design that allowed them might be nice. Perhaps a smaller step would be a design that didn't need to mutate dx to do x .+= ε .* dx. If the former happened before the latter then some variants of this PR would need re-visiting... which seems OK.
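
One way to picture a rule that never mutates `dx`: keep all writes in optimiser-owned state. This is an illustrative momentum sketch under that assumption, not Flux's actual `apply!`; the range is a stand-in for an immutable gradient.

```julia
x  = [1.0, 2.0, 3.0]      # mutable parameter
dx = 1.0:3.0              # read-only stand-in for an immutable gradient
v  = zero(x)              # optimiser-owned momentum buffer
η, ρ = 0.1, 0.9           # step size and momentum, arbitrary values

for _ in 1:2
    v .= ρ .* v .+ η .* dx   # only the optimiser's own state is written
    x .-= v                  # dx is only ever read
end
```

Because the broadcasts are fused, no temporary copy of `dx` is needed in this form.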

Member

I don't think it's a tough fix. Essentially, https://github.com/FluxML/Flux.jl/blob/v0.12.4/src/optimise/train.jl#L29 needs to be rewritten to something like:

xs[x] = update!(opt, x, gs[x])

where update! always returns the transformed parameter and can optionally mutate it to save an allocation. This would require implementing setindex! on Params, but it's the only solution I can think of that preserves both the API and semantics.
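
A rough sketch of the "always return, mutate when possible" contract. `sgd_step` is a hypothetical helper, not Flux's `update!`, and a plain `Vector{Any}` stands in for the parameter container.

```julia
# Hypothetical helper: mutable arrays are updated in place and returned;
# anything else (here a tuple) yields a new value.
sgd_step(x::Array, dx, η) = (x .-= η .* dx; x)
sgd_step(x, dx, η) = x .- η .* dx

w = [1.0, 2.0]
params = Any[w, (3.0, 4.0)]             # one mutable array, one immutable tuple
grads  = Any[[1.0, 1.0], (1.0, 1.0)]

for i in eachindex(params)
    # Always store the returned value, whether or not mutation happened.
    params[i] = sgd_step(params[i], grads[i], 0.5)
end
```

The caller does not need to know which branch ran: the array entry is still the same object as `w`, while the tuple entry has been replaced.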

Member

If you replace what object is stored in the IdDict, this won't be picked up by the model. I agree that you want some mutate-where-possible-but-always-return scheme, but maybe it has to take & return the whole model? IIRC this was being discussed at some point, Functors?
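
The identity problem can be reproduced in a few lines. `Layer` is a hypothetical one-field model, and a plain `IdDict` is assumed here as a stand-in for an identity-keyed parameter container.

```julia
# Hypothetical one-field "model" holding a reference to its weight array.
struct Layer
    w::Vector{Float64}
end

m  = Layer([1.0, 2.0])
ps = IdDict(m.w => :tracked)     # keyed by object identity

# In-place update: the model sees it, and the key is still found.
m.w .-= 0.5
tracked_after_mutation = haskey(ps, m.w)     # true

# Out-of-place update: a new array is created; the model still holds the old one.
w_new = m.w .- 0.5
same_object   = w_new === m.w                # false
tracked_fresh = haskey(ps, w_new)            # false
```

Storing `w_new` in the container would leave `m.w` untouched, which is the gap described above.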

Member

Yeah, the path we're walking down where dx is not mutated and neither is x when it is immutable is Optimisers.jl (based on Functors.jl to facilitate taking & returning the whole model). I would say it is not something we needed to address in this PR.

Member

@ToucheSir ToucheSir Jun 11, 2021

Pretty much, you'd still need to reconstruct afterwards. It's the inherent limitation of implicit params, which is why Optimisers.jl focuses on explicit params and being able to traverse through structures (à la JAX's pytrees) instead. I would love it if we could nuke implicit params overnight, but there are still a few UX things to figure out with explicit params before then.
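
To make the explicit-parameter idea concrete, here is a minimal sketch that takes and returns the whole model by walking nested NamedTuples. The function name `sgd` and the model layout are assumptions for illustration; this is not the Optimisers.jl or Functors.jl API.

```julia
# Out-of-place SGD step on a leaf array: returns a new array, mutates nothing.
sgd(x::AbstractArray, dx::AbstractArray, η) = x .- η .* dx

# Recurse through nested NamedTuples, rebuilding the same structure.
sgd(m::NamedTuple, g::NamedTuple, η) = map((mi, gi) -> sgd(mi, gi, η), m, g)

model = (layer = (w = [1.0, 2.0], b = [0.0]), scale = [2.0])
grads = (layer = (w = [1.0, 1.0], b = [1.0]), scale = [1.0])

new_model = sgd(model, grads, 0.5)
```

Since the updated model is returned rather than patched in place, immutable leaves pose no problem, at the cost of the caller having to rebind the model.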

@CarloLucibello
Member Author

> It would be interesting to test this on some odd case like where Params contains a ComponentArray and the user just assumes that the gradient will be able to do grad.a == grad[1].

@ChrisRackauckas is the test I added what you meant here? Although probably that was already working, the returned gradient from Zygote is a ComponentArray

@ChrisRackauckas
Member

> @ChrisRackauckas is the test I added what you meant here? Although probably that was already working, the returned gradient from Zygote is a ComponentArray

Yup, that's the test. It was not working before: if you did that on the sum example, previous versions gave you back a FillArray. So it's not that it never worked; Zygote just got a little fussy and chose when it wanted to obey those rules 😆

Member

@ToucheSir ToucheSir left a comment

Since this has come up before and Optimisers.jl needs some more time to bake, shall we get this merged?

@darsnack darsnack mentioned this pull request Jun 14, 2021
@CarloLucibello
Member Author

@darsnack or @DhairyaLGandhi can I get an approval here?

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jun 17, 2021

What needs to bake in Optimisers.jl?

We don't have a timeline for the switch to Optimisers.jl yet. It will probably take months, so considerations about Optimisers.jl shouldn't stop us from fixing issues here and now.

EDIT: Instead of replying to Dhairya's comment, I inadvertently edited it and replied inside his own comment. Sorry for that. (Carlo Lucibello)

Member

@darsnack darsnack left a comment

Agreed, let's merge and not wait around for Optimisers.jl.

@CarloLucibello
Member Author

bors r+

bors bot added a commit that referenced this pull request Jun 17, 2021
1613: use ArrayInterface.restructure in update! r=CarloLucibello a=CarloLucibello
Co-authored-by: CarloLucibello
@CarloLucibello
Member Author

bors r-

@bors
Contributor

bors bot commented Jun 17, 2021

Canceled.

@CarloLucibello
Member Author

bors r+

@DhairyaLGandhi
Member

Seeing as we have come around to switching out the zeros and have momentum now, I would say we should focus on moving ahead with Optimisers.jl sooner rather than later.

@darsnack darsnack mentioned this pull request Jun 17, 2021
@bors
Contributor

bors bot commented Jun 17, 2021

Build succeeded:

Development

Successfully merging this pull request may close these issues.

Minimizing sum fails
6 participants