Optimizing over AbstractMatrix subtypes #2559
Ordinarily I would move this issue to the Zygote.jl repo since almost none of the code for the "Flux.jl AD and optimization engine" actually lives in Flux itself (AD is in Zygote/Enzyme, and optimization rules are in Optimisers.jl). But it looks like there are two separate issues about AD and model parameters, so I'll leave it here for now.
Tackling the easier issue first:
I think the error message there is relatively clear, but if not, PRs welcome; the docstring of the function involved should spell out the details. So your options would be to either define the method the error message asks for, or declare the trainable fields explicitly, as sketched below.
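Here is a sketch of the two standard ways to declare trainable fields in current Flux; that these are what the truncated suggestion refers to is my assumption:

```julia
# Both assume the MyMatrix type from later in this thread is in scope.

# Option 1: the @layer macro with an explicit trainable list:
Flux.@layer MyMatrix trainable=(A,B)

# Option 2: a manual method (Flux re-exports Optimisers.trainable):
Flux.trainable(M::MyMatrix) = (A = M.A, B = M.B)
```

Note that the issue text below reports that neither of these takes effect for an `AbstractMatrix` subtype, which is the crux of the question.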
Part of what's going on here is that Functors.jl now (v0.5, i.e. Flux >= 0.15) recurses into all structs, except for a small list of leaf types. But I think you want it to unpack down to the `A` and `B` fields.
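As a small self-contained illustration of that default (the `TwoFields` struct is a made-up stand-in, not a type from this thread):

```julia
using Functors  # v0.5

struct TwoFields{U}  # hypothetical example
    A::U
    B::U
end

t = TwoFields(rand(2, 2), rand(2, 2))

Functors.children(t)         # (A = ..., B = ...): recursion into a plain struct
Functors.isleaf(rand(2, 2))  # true: numeric arrays are on the default leaf list
Functors.isleaf(t)           # false: ordinary structs are unpacked
```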
You do need to define these. Independent of Functors.jl, all things Zygote/CR believe that any `AbstractMatrix` can be treated like a plain array of numbers, although a custom rule may return a "structural gradient" (*). Zygote/CR's point of view is that you have opted into all these generic rules by defining a supertype `<: AbstractMatrix`.

(*) I say may return a "structural gradient", but probably you will get an error when it tries to convert this back to "natural" form & doesn't know how.
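To make the "structural gradient" idea concrete, a hedged sketch with an invented type that is deliberately not an `AbstractMatrix`:

```julia
using Zygote

struct PlainPair{U}  # hypothetical: NOT <: AbstractMatrix
    A::U
    B::U
end

combine(p::PlainPair, b) = p.A * b .+ p.B * b

p = PlainPair(rand(2, 2), rand(2, 2))

# With no array supertype, no generic array rules apply, and Zygote returns
# a structural gradient: a NamedTuple mirroring the struct's fields.
g, = Zygote.gradient(q -> sum(combine(q, ones(2))), p)
g.A  # a 2×2 matrix, same shape as p.A
```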
Thank you both!
Yes, a specialization for …
Yeah, sure. Sorry for not being clear about this. As far as I understand it, there are several routes one could take. The error message suggests one that is certainly possible, but I was asking about the bigger picture and which solution would be recommended (and least painful going forward). There are three possibilities:
This is what I ended up with for the latter:

```julia
using Flux, ChainRulesCore
import Base: *

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

# Compact printing, since size/getindex are not defined for this AbstractMatrix.
Base.show(io::IO, ::MyMatrix) = print(io, "MyMatrix")
Base.show(io::IO, ::MIME"text/plain", ::MyMatrix) = print(io, "MyMatrix")

# Expose the fields A and B to Functors.jl (and hence to Flux).
Flux.Functors.@functor MyMatrix (A, B)

M::MyMatrix * b::AbstractVector = my_mul(M, b)
my_mul(M::MyMatrix, b::AbstractVector) = M.A * b .+ M.B * b

# Opt out of ChainRules' generic matrix-vector rule for MyMatrix.
ChainRulesCore.@opt_out ChainRulesCore.rrule(
    ::typeof(Base.:*), ::MyMatrix,
    ::AbstractVecOrMat{<:Union{Real, Complex}}
)

# Custom (dummy) rrule: the pullback returns zero tangents for A and B.
function ChainRulesCore.rrule(::typeof(my_mul), M::MyMatrix, b::AbstractVector)
    result = M.A * b .+ M.B * b
    result, Δ -> (
        NoTangent(),
        Tangent{typeof(M)}(A = zero(M.A), B = zero(M.B)),
        NoTangent(),  # dummy: no tangent for b either
    )
end

ChainRulesCore.ProjectTo(M::MyMatrix) = ChainRulesCore.ProjectTo{typeof(M)}(
    A = ChainRulesCore.ProjectTo(M.A),
    B = ChainRulesCore.ProjectTo(M.B)
)
```
With these definitions in place, the parameters, gradient, optimiser state and training loop all work:

```julia-repl
julia> Flux.trainables(M)
2-element Vector{AbstractArray}:
 [0.13690525217120197 0.03255308066433349 0.42366049897378943; 0.5216667429485774 0.24198352680094837 0.7630237433363294; 0.7263307224554326 0.6955152108498136 0.2181187077029928]
 [0.9270942999694958 0.03030816621932353 0.6213207428554142; 0.17782636752227254 0.8921299827557096 0.7451039263343977; 0.2222917269319613 0.7359569279723447 0.6166570979464053]

julia> Flux.gradient(m -> sum(m * x), M)
((A = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], B = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]),)

julia> opt = Flux.setup(Adam(), M)
(A = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], (0.9, 0.999))), B = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], (0.9, 0.999))))

julia> Flux.train!((m, x) -> sum(m * x), M, [(x)], opt)
Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
```

(The gradients are all zero because the rrule above is a dummy.)
This issue is related to #2045, but for brevity I will use a simpler example here.

The question is how to define a custom `AbstractMatrix` type and plug it into the Flux.jl AD and optimization engine. Say I have a type `MyMatrix` (sketched below) which performs standard matrix-vector multiplication, but where the matrix is represented as two matrices summed together. So the following will work:
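The definition was presumably along these lines (a reconstruction based on the full listing earlier in this thread, here without the supertype):

```julia
using Flux

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}}  # no supertype yet
    A::U
    B::U
end

# The represented matrix is A + B, so matrix-vector multiplication is:
Base.:*(M::MyMatrix, b::AbstractVector) = M.A * b .+ M.B * b

M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)
M * x  # works: equivalent to (M.A + M.B) * x
```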
Now, computing the gradient is possible:
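Presumably with a call along these lines (an assumption, mirroring the transcript earlier in the thread):

```julia
# Returns a structural gradient, a NamedTuple over the fields A and B:
Flux.gradient(m -> sum(m * x), M)
# ((A = ..., B = ...),)
```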
The trouble begins when I want to make `MyMatrix` a subtype of `AbstractMatrix` (see for example #2045 for why that would make sense). The gradient cannot be computed now:
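The change presumably amounts to adding the supertype (again a reconstruction; the error itself is an assumption):

```julia
# Same fields, now declared as an AbstractMatrix subtype:
struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

Base.:*(M::MyMatrix, b::AbstractVector) = M.A * b .+ M.B * b

# ChainRules' generic matrix rules now apply to MyMatrix, and this errors
# rather than returning the structural (A = ..., B = ...) gradient:
Flux.gradient(m -> sum(m * x), M)
```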
This is due to some default `ChainRules` rule, so let's opt out of it (as discussed e.g. in FluxML/Zygote.jl#1146) and create a custom (dummy) `rrule` and `ProjectTo`; the definitions are the ones in the full listing earlier in this thread.
So far so good, but how do I plug this into the Flux.jl machinery? `Flux.trainables(M)` returns a single-element array containing `M`, not `M.A` and `M.B`. Neither of the following works:

- `Flux.@layer MyMatrix trainable=(A,B)` has no effect
- `Flux.trainable(M::MyMatrix) = (A=M.A, B=M.B)` together with `Flux.Optimisers.isnumeric(::MyMatrix) = false` leads to some internal error

And if I try to set up a simple training example, I run into another error:
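The training attempt was presumably of this shape (a hedged sketch; the exact call and the error output are assumptions on my part):

```julia
# A minimal training attempt of the kind described (hypothetical data x):
x = rand(3)
opt = Flux.setup(Adam(), M)
Flux.train!((m, x) -> sum(m * x), M, [x], opt)  # roughly where the reported error appears
```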
In #2045 there is some discussion regarding `AbstractArray` subtypes, but it is no longer relevant, since the implicit-parameter style is now deprecated.

Thank you very much in advance!