AD through custom AbstractMatrix fails #815

Closed
simonmandlik opened this issue Oct 24, 2020 · 2 comments

Comments

@simonmandlik

Is it possible with Zygote to build custom AbstractMatrix subtypes with custom adjoints? I would like to have a matrix that behaves like a standard dense matrix, except that before left-multiplying a data matrix, it fills every missing element of the data matrix with values from a vector of parameters (missing entries in row i get b[i]):

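```julia
using Flux, Zygote
using ChainRulesCore
import ChainRulesCore: rrule
import Base: *

# Dense matrix W plus a parameter vector b used to fill missing inputs
struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}}
    W::U
    b::V
end

# Fill the missing entries of B from b, then multiply by W
A::MyMatrix * B::AbstractMatrix{Union{Missing, Float64}} = A.W * _fill_in(A.b, B)

_fill_in(b, B) = _fill_mask(b, B)[1]
function rrule(::typeof(_fill_in), b, B)
    X, m = _fill_mask(b, B)
    # NO_FIELDS is the tangent of _fill_in itself (later ChainRulesCore versions use NoTangent())
    X, Δ -> (NO_FIELDS, @thunk(_fill_in_db(Δ, m)), @thunk(_fill_in_dB(Δ, .!m)))
end

# Gradient w.r.t. b: keep Δ only at filled-in (missing) positions, summed along each row
_fill_in_db(Δ, m) = (db = deepcopy(Δ); db[m] .= 0; sum(db, dims=2))
# Gradient w.r.t. B: keep Δ only at observed positions
_fill_in_dB(Δ, m) = (dB = deepcopy(Δ); dB[m] .= 0; dB)

# Return the filled-in matrix X and the mask m of observed (non-missing) entries
function _fill_mask(b, B)
    m = .!ismissing.(B)
    X = repeat(b, 1, size(B, 2))
    X[m] = B[m]
    X, m
end
```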
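A quick way to see what _fill_mask does is to run it on a small input without Flux or Zygote loaded. The sketch below repeats the helper from the issue and calls it with made-up values for b and B: each missing entry in row i is replaced by b[i], and m marks the observed entries.

```julia
# Standalone copy of the issue's helper, so it runs in plain Julia.
function _fill_mask(b, B)
    m = .!ismissing.(B)            # true where B has an observed value
    X = repeat(b, 1, size(B, 2))   # every column starts as a copy of b
    X[m] = B[m]                    # overwrite observed positions with B's values
    X, m
end

X, m = _fill_mask([10.0, 20.0], [1.0 missing; missing 4.0])
# X == [1.0 10.0; 20.0 4.0]
# m == [true false; false true]
```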

and everything works as expected:

```julia
julia> B = [1.0 2.0; missing 3.0]
2×2 Array{Union{Missing, Float64},2}:
 1.0       2.0
  missing  3.0

julia> W, b = rand(2,2), rand(2)
([0.7271867875139644 0.9966562367102048; 0.7544782519268483 0.9148648966908355], [0.1202546404757574, 0.5564718739401007])

julia> MyMatrix(W, b) * B
2×2 Array{Float64,2}:
 1.2818   4.44434
 1.26357  4.25355

julia> gradient(A -> sum(A*B), MyMatrix(W, b))
((W = [3.0 3.556471873940101; 3.0 3.556471873940101], b = [0.0; 1.9115211334010402]),)

julia> gradient(B -> sum(MyMatrix(W,b)*B), B)
([1.4816650394408126 1.4816650394408126; 0.0 1.9115211334010402],)
```
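The gradients printed above can be cross-checked by hand. With X denoting B after its missing entries are filled from b, Y = W * X, and the seed Δ = ones coming from sum, the chain rule gives dW = Δ * X' and dX = W' * Δ; dX then splits between b (missing positions, summed along each row) and B (observed positions). The sketch below assumes only Base Julia and copies W, b and B from the session. It wraps db in vec so that it has the shape of b, whereas the issue's _fill_in_db returns an n×1 matrix, which is why b is printed as [0.0; 1.91...].

```julia
# Values copied from the REPL session above.
W = [0.7271867875139644 0.9966562367102048; 0.7544782519268483 0.9148648966908355]
b = [0.1202546404757574, 0.5564718739401007]
B = [1.0 2.0; missing 3.0]

missing_mask = ismissing.(B)         # entries filled in from b
observed = .!missing_mask            # entries taken from B
X = repeat(b, 1, size(B, 2))
X[observed] = B[observed]            # the filled-in data matrix

Δ = ones(size(W, 1), size(B, 2))     # seed: gradient of sum(Y) w.r.t. Y = W * X
dW = Δ * X'                          # gradient w.r.t. W
dX = W' * Δ                          # gradient w.r.t. the filled-in matrix X
db = vec(sum(dX .* missing_mask, dims = 2))  # missing slots send their gradient to b
dB = dX .* observed                  # observed slots send their gradient to B
```

dW, db and dB should agree, up to floating-point rounding, with the W, b and B gradients reported by Zygote in the session.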

However, if I change the struct definition so that MyMatrix subtypes AbstractMatrix:

```julia
struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}} <: AbstractMatrix{T}
    W::U
    b::V
end
```

both gradient calls stop working and fail with this error:

```julia
julia> gradient(B -> sum(MyMatrix(W,b)*B), B)
ERROR: MethodError: no method matching copy(::Missing)
...

julia> gradient(A -> sum(A*B), MyMatrix(W, b))
ERROR: MethodError: no method matching copy(::Missing)
...
```

Is it possible to create custom AbstractMatrix subtypes in this way, or does Zygote make it impossible?
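Separately from the AD question, once MyMatrix declares itself a subtype of AbstractMatrix, Julia's AbstractArray interface expects at least size and getindex methods; without them, even displaying a MyMatrix in the REPL raises a MethodError for size. A minimal sketch, assuming the entries the type should expose are those of W (the issue does not say). On its own this is not expected to change which adjoint Zygote selects for *.

```julia
struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}} <: AbstractMatrix{T}
    W::U
    b::V
end

# Minimal AbstractArray interface (assumption: the visible entries are those of W)
Base.size(A::MyMatrix) = size(A.W)
Base.getindex(A::MyMatrix, i::Int, j::Int) = A.W[i, j]

A = MyMatrix([1.0 2.0; 3.0 4.0], [0.5, 0.5])
size(A)                            # (2, 2)
A[2, 1]                            # 3.0
Matrix(A) == [1.0 2.0; 3.0 4.0]    # true
```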

@simonmandlik
Author

Probably related to #811

@simonmandlik
Author

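The problem was that Zygote already defines an adjoint for the product of two matrices, and once MyMatrix subtypes AbstractMatrix, that adjoint's pullback is used for A * B instead of differentiating through my method. I tried to implement an adjoint specifically for this * method as an rrule, but that does not work:

```julia
rrule(::typeof(*), A::MyMatrix, B::AbstractMatrix) = ...
```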

Once it is defined with Zygote's @adjoint instead, it works as expected:

```julia
Zygote.@adjoint A::MyMatrix * B::AbstractMatrix = ...
```

So I suspect this is a duplicate of #811.
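The diagnosis in this comment can be illustrated with a toy model of method dispatch in plain Julia. It assumes, as this comment and #811 suggest, that Zygote's built-in adjoint for AbstractMatrix * AbstractMatrix is a more specific method than the generic fallback that looks up ChainRulesCore rrules, while an @adjoint rule is a method of that same dispatch and can therefore be made more specific still. Every name below (select_rule, user_rrule and the three types) is hypothetical and not part of Zygote's API.

```julia
# Toy model of rule selection; all names are hypothetical, not Zygote internals.
struct NotAMatrix end                              # like the first MyMatrix
struct OnlyRRule <: AbstractMatrix{Float64} end    # subtype with only an rrule-style hook
struct WithAdjoint <: AbstractMatrix{Float64} end  # subtype with an @adjoint-style rule

# Hook consulted by the generic fallback (the role ChainRulesCore.rrule plays).
user_rrule(f, args...) = nothing
user_rrule(::typeof(*), ::NotAMatrix, ::AbstractMatrix) = :user_rrule
user_rrule(::typeof(*), ::OnlyRRule, ::AbstractMatrix) = :user_rrule

# Generic fallback: use a user rrule if one exists, otherwise trace the source code.
select_rule(f, args...) = something(user_rrule(f, args...), :trace_source)

# Built-in rule for any two AbstractMatrix arguments (like the matmul adjoint).
select_rule(::typeof(*), ::AbstractMatrix, ::AbstractMatrix) = :builtin_matmul

# @adjoint-style rule: a method of the same function, more specific than the built-in.
select_rule(::typeof(*), ::WithAdjoint, ::AbstractMatrix) = :user_adjoint

B = [1.0 2.0; missing 3.0]
select_rule(*, NotAMatrix(), B)   # :user_rrule
select_rule(*, OnlyRRule(), B)    # :builtin_matmul (the rrule-style hook is never reached)
select_rule(*, WithAdjoint(), B)  # :user_adjoint
```

For a type that is not an AbstractMatrix, only the fallback matches, so the rrule-style hook is found. Once the type subtypes AbstractMatrix, the built-in method is more specific than the fallback and the hook is never consulted; only a rule attached to the same dispatch, as @adjoint is in this model, takes precedence.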
