Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure pullback of exp works for immutable arrays #381

Merged
merged 3 commits into from
Mar 1, 2021

Conversation

sethaxen
Copy link
Member

Fixes #380 by ensuring that the cotangent used by the pullback of exp is mutable.

@Roger-luo can you confirm that this fixes the bug?

@sethaxen sethaxen changed the title Expmut Ensure pullback of exp works for immutable arrays Feb 18, 2021
@Roger-luo
Copy link

yes! thanks!

@sethaxen sethaxen requested a review from oxinabox February 19, 2021 00:41
∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates)
# Ensures ∂X is mutable. The outer `adjoint` is unwrapped without copy by
# the default _matfun_frechet_adjoint!
∂X = Matrix(ΔX')'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weirdly copy promises to return something mutable and is probably cleaner than this?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess copy would always allocate, where a type conversion won't? or should there be some dispatch using a mutable trait ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On checking, I was apparently wrong about copy.
It doesn't nesc return something mutable.
I was confused by the fact that it makes some view-like things (including Adjoint and SubArray, but not including Diagonal) into an arrays.

convert vs Constructors is a thing though.

https://docs.julialang.org/en/v1/manual/conversion-and-promotion/#Mutable-collections

convert(T, x) is expected to return the original x if x is already of type T. In contrast, if T is a mutable collection type then T(x) should always make a new collection (copying elements from x).

So perhaps this should be a convert if we want to avoid allocating unnesc?


ChainRulesCore does actually have a trait that might be suitable for this. If we really wanted.
It's part of the mechanics for doing inplace gradient accumulation

is_inplaceable_destination(x) -> Bool

Returns true if x is suitable for for storing inplace accumulation of gradients.
For arrays this boils down x .= y if will work to mutate x, if y is an appropriate differential.
Wrapper array types do not need to overload this if they overload Base.parent, and are is_inplaceable_destination if and only if their parent array is.
Other types should overload this, as it defaults to false.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up going with a mixture of these both: Using the trait to decide if we should do anything and then using convert, which in this case will always allocate a copy. (convert(Matrix, X') will allocate unless X is an Adjoint{T,Matrix{T}}, but since that type is inplaceable, the trait will bypass the convert anyways).

@codecov-io
Copy link

codecov-io commented Feb 19, 2021

Codecov Report

Merging #381 (a17f9d1) into master (f0fead0) will decrease coverage by 10.42%.
The diff coverage is 100.00%.

Impacted file tree graph

@@             Coverage Diff             @@
##           master     #381       +/-   ##
===========================================
- Coverage   97.72%   87.29%   -10.43%     
===========================================
  Files          19       19               
  Lines        1495     1244      -251     
===========================================
- Hits         1461     1086      -375     
- Misses         34      158      +124     
Impacted Files Coverage Δ
src/rulesets/LinearAlgebra/matfun.jl 100.00% <100.00%> (ø)
src/rulesets/Base/evalpoly.jl 0.00% <0.00%> (-97.68%) ⬇️
src/rulesets/Base/utils.jl 0.00% <0.00%> (-80.00%) ⬇️
src/ChainRules.jl 66.66% <0.00%> (-33.34%) ⬇️
src/rulesets/Statistics/statistics.jl 66.66% <0.00%> (-23.34%) ⬇️
src/rulesets/LinearAlgebra/utils.jl 66.66% <0.00%> (-20.00%) ⬇️
src/rulesets/LinearAlgebra/symmetric.jl 84.15% <0.00%> (-14.18%) ⬇️
src/rulesets/LinearAlgebra/structured.jl 92.04% <0.00%> (-6.84%) ⬇️
src/rulesets/LinearAlgebra/factorization.jl 95.50% <0.00%> (-2.19%) ⬇️
src/rulesets/LinearAlgebra/norm.jl 98.24% <0.00%> (-1.76%) ⬇️
... and 8 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f0fead0...a17f9d1. Read the comment docs.

@sethaxen
Copy link
Member Author

@Roger-luo can you confirm that the latest version still works for you?

∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates)
# Ensures ∂X is mutable. The outer `adjoint` is unwrapped without copy by
# the default _matfun_frechet_adjoint!
∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just:

Suggested change
∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')'
∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because if ∂X' is an Adjoint, then _matfun_frechet_adjoint! will copy it to make it non-Adjoint. This way, we do only one allocation instead of 2. An alternative is to bypass _matfun_frechet_adjoint! to call _matfun_frechet! directly, but to me this seems cleaner. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if it is an Adjoint then isn't it going to be mutable?
Or not becuase it might be an Adjoint{FillArray} ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this makes sense, leave a comment to that effect?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you suggest a change that would make this comment clearer?

# Ensures ∂X is mutable. The outer `adjoint` is unwrapped without copy by
# the default _matfun_frechet_adjoint!

@GiggleLiu
Copy link

GiggleLiu commented Feb 20, 2021

Actually I and Roger discussed a lot about whether it is an issue of tr or exp. Are we fixing this issue for every function that calls setindex!? or just fix tr and sum?

@sethaxen
Copy link
Member Author

Actually I and Roger discussed a lot about whether it is an issue of tr or exp. Are we fixing this issue for every function that calls setindex!? or just fix tr and sum?

I'm not sure I understand the question. This will fix for the case where a user passes an immutable array to the pullback for exp. That crops up with Zygote's rule for tr but is generic. I wouldn't be surprised if there are other rrules in ChainRules that make incorrect assumptions about mutability, but perhaps that's best discussed in #380 or another issue.

@mcabbott
Copy link
Member

Does this mutate what's received by the pullback? Is that the cause of this:

julia> gradient(x -> 2sum(abs, exp(x)'), [1 2; 3 4.0])
([226.54923657197995 427.40643795569565; 309.849813509321 579.2191099506263],)

julia> gradient(x -> sum(abs, (exp(x) + exp(x))'), [1 2; 3 4.0])
([169.91192742898497 320.5548284667717; 232.38736013199076 434.4143324629697],)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

combining exp and tr cause setindex error from Zygote
6 participants