Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canonicalize Composite{SVD} #388

Closed
wants to merge 3 commits into from
Closed

Canonicalize Composite{SVD} #388

wants to merge 3 commits into from

Conversation

mzgubic
Copy link
Member

@mzgubic mzgubic commented Mar 23, 2021

At the moment:

julia> x = Composite{SVD{Float64,Float64,Array{Float64,2}}}(V = [0.6662838804680036 -0.5008780874375076; 0.729969913633919 0.228331721577141],)

julia> ChainRules.canonicalize(x)
ERROR: ArgumentError: Composite fields do not match primal fields.
Composite fields: (:V,). Primal (SVD{Float64,Float64,Array{Float64,2}}) fields: (:U, :S, :Vt)

This seems to make sense because V is not a field of SVD. However, the rrule for svd expects the Ȳ::Composite to contain U, S, and V fields:

function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback::Composite)
# `getproperty` on `Composite`s ensures we have no thunks.
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
end

I don't know much about svd factorisation so can't really tell why this is the case.

It kind of feels wrong to overload canonicalize in this way, but I can't quite tell why. Any thoughts? Also, which file to put this in, it's not really a rule so I don't think it belongs in the rulesets folder.

Came across this in FluxML/Zygote.jl#922 (comment)

@oxinabox oxinabox added the needs-careful-thought A reminder that this thing is not obviouis and care must be taken when working on it label Mar 23, 2021
@sethaxen
Copy link
Member

It seems like it would be more straightforward to modify the rrule for getproperty(::SVD, ::Symbol) to store the cotangent of Vt instead of V. After all SVD actually has Vt as a field, and V is just its Adjoint.

@mzgubic
Copy link
Member Author

mzgubic commented Mar 23, 2021

It seems like it would be more straightforward to modify the rrule for getproperty(::SVD, ::Symbol) to store the cotangent of Vt instead of V. After all SVD actually has Vt as a field, and V is just its Adjoint.

Thanks for taking a look, that makes sense to me. For some reason I assumed that there is a complicated reason behind using V, should have checked that.

For getproperty pullback:

function getproperty_svd_pullback(Ȳ)
C = Composite{T}
∂F = if x === :U
C(U=Ȳ,)
elseif x === :S
C(S=Ȳ,)
elseif x === :V
C(V=Ȳ,)
elseif x === :Vt
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
end
return NO_FIELDS, ∂F, DoesNotExist()
end

can we just do something like:

....
        elseif x === :V
            C(Vt=Ȳ',)
        elseif x === :Vt
            C(Vt=Ȳ,)
...

Or is there more to it?

@sethaxen
Copy link
Member

can we just do something like:

....
        elseif x === :V
            C(Vt=Ȳ',)
        elseif x === :Vt
            C(Vt=Ȳ,)
...

Or is there more to it?

I think that's all there is to it. You could add copy(Ȳ') to make Vt not an Adjoint, but this is unnecessary, so long as the rrules for svd still work when the cotangent for Vt is an Adjoint.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs-careful-thought A reminder that this thing is not obviouis and care must be taken when working on it
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants