Optimizing over AbstractMatrix subtypes #2559

Closed
simonmandlik opened this issue Dec 29, 2024 · 4 comments

Comments

@simonmandlik
Contributor

simonmandlik commented Dec 29, 2024

This issue is related to #2045, but for brevity I will use a simpler example here.

The question is how to define a custom AbstractMatrix type and plug it into Flux.jl's AD and optimization engine.

Say I have

using Flux
import Base: *

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}}
    A::U
    B::U
end

Base.show(io::IO, ::MyMatrix) = print(io, "MyMatrix")
Base.show(io::IO, ::MIME"text/plain", ::MyMatrix) = print(io, "MyMatrix")
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i, j) = M.A[i, j] + M.B[i, j]

Flux.@layer MyMatrix

M::MyMatrix * b::AbstractVector = my_mul(M.A, M.B, b)
my_mul(A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector) = A * b .+ B * b

This type performs standard matrix-vector multiplication, but the matrix is represented as the sum of two matrices, so the following works:

M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)
M * x
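A quick sanity check in Base Julia (not part of the original report; A, B and x are random test data): the two-matrix representation should agree with multiplying by the dense matrix A + B, up to floating-point rounding.

```julia
# Same computation as the issue's my_mul: A*b + B*b, without materialising A + B
my_mul(A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector) = A * b .+ B * b

A, B = rand(3, 3), rand(3, 3)
x = rand(3)

# By distributivity, (A + B) * x equals A * x + B * x up to rounding
y_split = my_mul(A, B, x)
y_dense = (A + B) * x
@assert y_split ≈ y_dense
```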

Computing the gradient also works:

julia> Flux.gradient(m -> sum(m * x), M)

((A = [0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861], B = [0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861]),)
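Every row of the printed gradient equals x': for f(A, B) = sum(A*x + B*x), the derivative with respect to A[i,j] is x[j], and the same holds for B. The following Base-Julia sketch checks that against central finite differences; `fd_grad` is a hypothetical helper written only for this check, and no AD is involved.

```julia
# The issue's loss, written directly on the two fields A and B
f(A, B, x) = sum(A * x .+ B * x)

# Hypothetical helper: central finite-difference gradient of the scalar function g at matrix X
function fd_grad(g, X; h = 1e-6)
    G = similar(X)
    for k in eachindex(X)
        Xp = copy(X); Xp[k] += h
        Xm = copy(X); Xm[k] -= h
        G[k] = (g(Xp) - g(Xm)) / (2h)
    end
    return G
end

A, B, x = rand(3, 3), rand(3, 3), rand(3)

# Analytic gradient: d/dA[i,j] of sum(A*x) is x[j], so every row equals x'
g_analytic = ones(3) * x'

gA = fd_grad(Ahat -> f(Ahat, B, x), A)
gB = fd_grad(Bhat -> f(A, Bhat, x), B)
@assert isapprox(gA, g_analytic; atol = 1e-6)
@assert isapprox(gB, g_analytic; atol = 1e-6)
```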

The trouble begins when I make MyMatrix a subtype of AbstractMatrix (see e.g. #2045 for why that would make sense):

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

Now the gradient cannot be computed:

julia> Flux.gradient(m -> sum(m * x), M)

ERROR: MethodError: no method matching size(::MyMatrix{Float64, Matrix{Float64}})
The function `size` exists, but no method is defined for this combination of argument types.
You may need to implement the `length` and `size` methods for `IteratorSize` `HasShape`.

This is due to some default ChainRules rule. Let's opt out of it (as discussed e.g. in FluxML/Zygote.jl#1146) and create a custom (dummy) rrule and ProjectTo:

using ChainRulesCore

ChainRulesCore.@opt_out ChainRulesCore.rrule(::typeof(Base.:*), ::MyMatrix, ::ChainRulesCore.AbstractVecOrMat{<:Union{Real, Complex}})

function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector)
    result = A * b .+ B * b
    result, Δ -> (NoTangent(), zero(A), zero(B), zero(b)) # dummy
end

ChainRulesCore.ProjectTo(M::MyMatrix) = ChainRulesCore.ProjectTo{typeof(M)}(
    A = ChainRulesCore.ProjectTo(M.A),
    B = ChainRulesCore.ProjectTo(M.B)
)

julia> Flux.gradient(m -> sum(m * x), M)
((A = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], B = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]),)
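The rrule above deliberately returns zeros. For comparison, the non-dummy cotangents of my_mul follow directly from y = A b + B b; writing ȳ for the incoming cotangent Δ:

```latex
y = A b + B b, \qquad
\bar{A} = \bar{y}\, b^{\top}, \qquad
\bar{B} = \bar{y}\, b^{\top}, \qquad
\bar{b} = A^{\top}\bar{y} + B^{\top}\bar{y} = (A + B)^{\top}\bar{y}
```

For the loss sum(m * x), ȳ is a vector of ones and b = x, so Ā = B̄ = ones(3) * x', which is the pattern visible in the first (non-AbstractMatrix) gradient above.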

So far so good, but how do I plug this into the Flux.jl machinery?

Flux.trainables(M) returns a single-element array containing M, not M.A and M.B. None of the following works:

  • Flux.@layer MyMatrix trainable=(A,B) has no effect
  • defining Flux.trainable(M::MyMatrix) = (A=M.A, B=M.B) has no effect either
  • defining Flux.Optimisers.isnumeric(::MyMatrix) = false leads to an internal error:
    ERROR: MethodError: no method matching _trainable(::Tuple{}, ::@NamedTuple{A::Matrix{Float64}, B::Matrix{Float64}})
    The function `_trainable` exists, but no method is defined for this combination of argument types.

And if I try to set up a simple training example, I run into another error:

model = Dense(M, rand(3));
julia> opt_state = Flux.setup(Adam(), model)

ERROR: model must be fully mutable for `train!` to work, got `x::MyMatrix{Float64, Matrix{Float64}}`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::MyMatrix{Float64, Matrix{Float64}}) = true`

In #2045 there is some discussion regarding AbstractArray subtypes, but it is no longer relevant as implicit parametrization is now deprecated.

Thank you very much in advance

@ToucheSir
Member

Ordinarily I would move this issue to the Zygote.jl repo since almost none of the code for the "Flux.jl AD and optimization engine" actually lives in Flux itself (AD is in Zygote/Enzyme, and optimization rules are in Optimisers.jl). But it looks like there are two separate issues about AD and model parameters, so I'll leave it here for now.

@ToucheSir
Member

ToucheSir commented Dec 30, 2024

Tackling the easier issue first:

And if I try to set up a simple training example, I run into another error:
...

I think the error message there is relatively clear, but if not, PRs are welcome. I think the docstring for Flux.setup is more helpful here:

This is a version of Optimisers.setup, and is the first step before using train!. It differs from Optimisers.setup in that it:

  • has one extra check for mutability (since Flux expects to mutate the model in-place, while Optimisers.jl is designed to return an updated model)

So your options would be to either define Optimisers.maywrite if MyMatrix is mutable, or to use Optimisers.jl directly since it's more flexible with the array types it accepts (e.g. StaticArrays).
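A minimal sketch of the second option (using Optimisers.jl directly), assuming the parameters are plain arrays held in a NamedTuple rather than the issue's MyMatrix, and using a hand-written gradient instead of AD. Optimisers.update is the non-mutating entry point: it returns the new optimiser state and the updated model instead of writing into the existing arrays.

```julia
using Optimisers

# A toy "model": a NamedTuple of two plain matrices, standing in for the fields A and B
model = (A = rand(3, 3), B = rand(3, 3))
x = rand(3)
A0, B0 = copy(model.A), copy(model.B)   # snapshot of the starting values for comparison

# Plain gradient descent with step size 0.1
state = Optimisers.setup(Optimisers.Descent(0.1), model)

# Hand-written gradient of sum(A*x + B*x): every row of dA and dB equals x'
grad = (A = ones(3) * x', B = ones(3) * x')

# Returns (new_state, new_model); the step is new = old - 0.1 * grad
state, new_model = Optimisers.update(state, model, grad)
@assert new_model.A ≈ A0 .- 0.1 .* grad.A
```

For models made of ordinary arrays, Optimisers.update! has the same call shape and may mutate the old model for speed.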

@mcabbott
Member

mcabbott commented Dec 30, 2024

Flux.@layer MyMatrix trainable=(A,B) has no effect

Part of what's going on here is that Functors.jl now (v0.5, i.e. Flux >= 0.15) recurses into all structs, except for a small list of leaf types including AbstractArray{<:Number}. Hence that package, and Optimisers.jl, regard your second MyMatrix as a leaf type and want to update it by e.g. M .+= ....

But I think you want it to unpack down to A and B, and regard those as the real parameters. You can probably tell Functors.jl to do this by defining a method of functor(::Type{<:MyMatrix}, M), which is what @functor does. This is also what @layer did on Flux <= 0.14. But I'm not certain that @functor MyMatrix will override the leaf-type definitions; I haven't tried.
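A minimal sketch of that suggestion, assuming Functors.jl v0.5 (where AbstractArray{<:Number} types are leaves by default): a method of Functors.functor for Type{<:MyMatrix} is more specific than the leaf default, so Functors sees A and B as children. This is roughly what `@functor MyMatrix (A, B)` generates; the sketch only exercises fmap and isleaf, not Flux.setup.

```julia
using Functors

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

# Minimal AbstractMatrix interface: the represented matrix is A + B
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i::Int, j::Int) = M.A[i, j] + M.B[i, j]

# Expose A and B as children, plus a function that rebuilds a MyMatrix from new children
Functors.functor(::Type{<:MyMatrix}, M) = (A = M.A, B = M.B), nt -> MyMatrix(nt.A, nt.B)

M = MyMatrix(rand(3, 3), rand(3, 3))

@assert !Functors.isleaf(M)          # no longer treated as a single leaf
M2 = fmap(x -> 2 .* x, M)            # recurses into A and B, then rebuilds
@assert M2 isa MyMatrix
```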

This is due to some default ChainRules rule

You do need to define size and getindex, as you did above for the other type that you also called MyMatrix.
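For reference, a minimal sketch of that interface using only Base Julia and the LinearAlgebra standard library: with size and a scalar getindex defined, generic fallbacks such as Matrix(M), sum(M) and M * x work. They go element by element through getindex rather than through the faster my_mul.

```julia
using LinearAlgebra

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

# The two methods required by the AbstractArray interface (IndexCartesian style)
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i::Int, j::Int) = M.A[i, j] + M.B[i, j]

A, B, x = rand(3, 3), rand(3, 3), rand(3)
M = MyMatrix(A, B)

# Generic fallbacks now work, all via element-wise getindex
@assert Matrix(M) ≈ A + B
@assert sum(M) ≈ sum(A) + sum(B)
@assert M * x ≈ (A + B) * x   # generic matrix-vector product, not the issue's my_mul
```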

Independent of Functors.jl, all things Zygote/CR believe that any x::AbstractArray{<:Number} should have as its "natural gradient" dx::AbstractArray{<:Number}. While something like Zygote.gradient(x -> sum(my_mul(x, b)), M) may (*) return a "structural gradient" (; A, B), because my_mul accesses the fields, something like Zygote.gradient(x -> sum(hcat(x, b)), M) will not. Zygote will not see your definition of getindex, it will see hcat and call the rule for that. That's the core reason it thinks this way -- letting Zygote differentiate all the way down to getindex would be a performance disaster.

Zygote/CR's point of view is that you have opted into all these generic rules by defining a supertype AbstractMatrix. The simplest way to opt out is not to define this supertype.

(*) I say it may return a "structural gradient", but you will probably get an error when Zygote tries to convert this back to "natural" form and doesn't know how.
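To make the natural/structural distinction concrete, here is a Base-Julia sketch (the functions and perturbation below are hypothetical, chosen only for illustration). For a function that only sees the dense matrix A + B, the natural gradient is one matrix dM, and the structural gradient with respect to the fields is (A = dM, B = dM), checked here on one entry by central finite differences.

```julia
# Hypothetical illustration of natural vs structural gradients for M = A + B
A, B = rand(3, 3), rand(3, 3)
h = 1e-6
E = zeros(3, 3); E[2, 1] = 1.0          # perturb a single entry

# f only sees the dense matrix A + B
f(A, B) = sum(abs2, A + B)
dM = 2 .* (A + B)                        # natural gradient with respect to M = A + B

# Perturbing A or B moves A + B identically, so the structural gradient is (A = dM, B = dM)
dA_fd = (f(A + h * E, B) - f(A - h * E, B)) / (2h)
dB_fd = (f(A, B + h * E) - f(A, B - h * E)) / (2h)
@assert isapprox(dA_fd, dM[2, 1]; atol = 1e-6)
@assert isapprox(dB_fd, dM[2, 1]; atol = 1e-6)
```

By contrast, a function such as sum(M.A) + 2sum(M.B) reads the fields separately; its structural gradient (A = ones, B = 2 .* ones) has dA ≠ dB, so there is no single natural gradient it could be converted back to.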

@simonmandlik
Contributor Author

Thank you both!

Zygote will not see your definition of getindex, it will see hcat and call the rule for that. That's the core reason it thinks this way -- letting Zygote differentiate all the way down to getindex would be a performance disaster.

Yes, a specialization for hcat would be needed as well if I wanted to AD through it.

I think the error message there is relatively clear

Yeah, sure. Sorry for not being clear about this. As far as I understand it, there are several routes one could take. The error message suggests one that is certainly possible, but I was asking about the bigger picture: which solution would be recommended (and least painful going forward). There are three possibilities:

  1. avoid subtyping AbstractMatrix, as @mcabbott suggests. Certainly possible, but while this subtyping brings something "bad" from the AD point of view, it allows reuse of many other useful definitions across the ecosystem. Some things are not even possible for types that are not AbstractMatrix; for example, Flux.Dense only accepts an AbstractMatrix.
  2. represent the gradient of MyMatrix structurally as a MyMatrix and equip it with everything it needs for updating itself (in-place mutation, etc.)
  3. represent the gradient as the two subarrays (A and B) and make sure that it is treated that way everywhere.

This is what I ended up with for the last option (3):

using Flux, ChainRulesCore
import Base: *

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

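Base.show(io::IO, ::MyMatrix) = print(io, "MyMatrix")
Base.show(io::IO, ::MIME"text/plain", ::MyMatrix) = print(io, "MyMatrix")

# Minimal AbstractMatrix interface, required for a valid AbstractMatrix subtype
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i::Int, j::Int) = M.A[i, j] + M.B[i, j]

# Expose A and B as children (the trainable parameters) instead of treating M as a leaf
Flux.Functors.@functor MyMatrix (A, B)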

M::MyMatrix * b::AbstractVector = my_mul(M, b)
my_mul(M::MyMatrix, b::AbstractVector) = M.A * b .+ M.B * b

ChainRulesCore.@opt_out ChainRulesCore.rrule(
    ::typeof(Base.:*), ::MyMatrix,
    ::ChainRulesCore.AbstractVecOrMat{<:Union{Real, Complex}}
)

function ChainRulesCore.rrule(::typeof(my_mul), M::MyMatrix, b::AbstractVector)
    result = M.A * b .+ M.B * b
    result, Δ -> (
        NoTangent(),
        Tangent{typeof(M)}(A = zero(M.A), B = zero(M.B)),
        NoTangent(),
    ) # dummy
end

ChainRulesCore.ProjectTo(M::MyMatrix) = ChainRulesCore.ProjectTo{typeof(M)}(
    A = ChainRulesCore.ProjectTo(M.A),
    B = ChainRulesCore.ProjectTo(M.B)
)

# Example data used in the REPL session below
M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)

julia> Flux.trainables(M)
2-element Vector{AbstractArray}:
 [0.13690525217120197 0.03255308066433349 0.42366049897378943; 0.5216667429485774 0.24198352680094837 0.7630237433363294; 0.7263307224554326 0.6955152108498136 0.2181187077029928]
 [0.9270942999694958 0.03030816621932353 0.6213207428554142; 0.17782636752227254 0.8921299827557096 0.7451039263343977; 0.2222917269319613 0.7359569279723447 0.6166570979464053]

julia> Flux.gradient(m -> sum(m * x), M)
((A = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], B = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]),)

julia> opt = Flux.setup(Adam(), M)
(A = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], (0.9, 0.999))), B = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], (0.9, 0.999))))

julia> Flux.train!((m, x) -> sum(m * x), M, [(x)], opt)
Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
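The rrule in this version is still a dummy that returns zero tangents, so the gradients above are all zero and train! leaves M unchanged. A sketch of a non-dummy pullback, using the cotangents that follow from y = A b + B b, is below. The struct is repeated so the block runs on its own with only ChainRulesCore loaded; the @opt_out, ProjectTo and Flux parts are omitted, and b is still treated as data.

```julia
using ChainRulesCore

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i::Int, j::Int) = M.A[i, j] + M.B[i, j]

my_mul(M::MyMatrix, b::AbstractVector) = M.A * b .+ M.B * b

function ChainRulesCore.rrule(::typeof(my_mul), M::MyMatrix, b::AbstractVector)
    y = my_mul(M, b)
    function my_mul_pullback(Δ)
        dy = unthunk(Δ)
        dAB = dy * b'      # y = A*b + B*b, so A and B get the same cotangent dy * b'
        # copy so the two fields do not alias one array; b is data, as in the original
        return NoTangent(), Tangent{typeof(M)}(A = dAB, B = copy(dAB)), NoTangent()
    end
    return y, my_mul_pullback
end

M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)

y, back = ChainRulesCore.rrule(my_mul, M, x)
_, dM, _ = back(ones(3))   # ones(3) is the cotangent of sum(y)
```

Used in place of the dummy rule in the code above, Flux.gradient(m -> sum(m * x), M) would be expected to return (A = ones(3) * x', B = ones(3) * x') rather than zeros.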
