Is it possible with Zygote to build custom AbstractMatrix subtypes with custom adjoints? I would like to have a matrix that behaves like a standard dense matrix, except that before left-multiplying a data matrix it fills all missing elements of that matrix with values from a vector of parameters:
using Flux, Zygote
using ChainRulesCore
import ChainRulesCore: rrule
import Base: *

struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}}
    W::U
    b::V
end

A::MyMatrix * B::AbstractMatrix{Union{Missing, Float64}} = A.W * _fill_in(A.b, B)

_fill_in(b, B) = _fill_mask(b, B)[1]

function rrule(::typeof(_fill_in), b, B)
    X, m = _fill_mask(b, B)
    X, Δ -> (NO_FIELDS, @thunk(_fill_in_db(Δ, m)), @thunk(_fill_in_dB(Δ, .!m)))
end

_fill_in_db(Δ, m) = (db = deepcopy(Δ); db[m] .= 0; sum(db, dims=2))
_fill_in_dB(Δ, m) = (dB = deepcopy(Δ); dB[m] .= 0; dB)

function _fill_mask(b, B)
    m = .!ismissing.(B)
    X = repeat(b, 1, size(B, 2))
    X[m] = B[m]
    X, m
end
and everything works as expected:
julia> B = [1.0 2.0; missing 3.0]
2×2 Array{Union{Missing, Float64},2}:
 1.0       2.0
  missing  3.0

julia> W, b = rand(2,2), rand(2)
([0.7271867875139644 0.9966562367102048; 0.7544782519268483 0.9148648966908355], [0.1202546404757574, 0.5564718739401007])

julia> MyMatrix(W, b) * B
2×2 Array{Float64,2}:
 1.2818   4.44434
 1.26357  4.25355

julia> gradient(A -> sum(A*B), MyMatrix(W, b))
((W = [3.0 3.556471873940101; 3.0 3.556471873940101], b = [0.0; 1.9115211334010402]),)

julia> gradient(B -> sum(MyMatrix(W,b)*B), B)
([1.4816650394408126 1.4816650394408126; 0.0 1.9115211334010402],)
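The gradient with respect to b above can be sanity-checked without any AD package. The following is a minimal sketch in plain Julia, assuming the same fill-in logic as `_fill_mask`; the names `fill_mask`, `pullback_b`, and `loss`, the small test values, and the step size `h` are illustrative choices and not part of the original code.

```julia
# Forward pass: replace missing entries of B with the matching row entry of b.
function fill_mask(b::AbstractVector, B::AbstractMatrix)
    m = .!ismissing.(B)               # true where B holds an observed value
    X = repeat(b, 1, size(B, 2))      # start from b copied into every column
    X[m] = B[m]                       # overwrite with the observed values
    return X, m
end

# Pullback w.r.t. b: keep the incoming gradient only at missing positions,
# then sum over columns (each b[i] is reused in every column of row i).
pullback_b(Δ, m) = vec(sum(Δ .* .!m, dims=2))

# Small illustrative inputs (assumed values, chosen for easy arithmetic).
B = [1.0 2.0; missing 3.0]
b = [0.5, 0.25]
W = [1.0 2.0; 3.0 4.0]

loss(b) = sum(W * fill_mask(b, B)[1])

X, m = fill_mask(b, B)
Δ = W' * ones(2, 2)                   # gradient of sum(W * X) w.r.t. X
analytic = pullback_b(Δ, m)

# Central finite differences as an independent check.
h = 1e-6
numeric = map(1:length(b)) do i
    e = zeros(length(b)); e[i] = h
    (loss(b .+ e) - loss(b .- e)) / (2h)
end

println(analytic)
println(numeric)
```

Only the second row of B has a missing entry, so only the second component of the gradient is nonzero, which matches the pattern of the `b` gradient reported above.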
However if I change the struct definition to:
struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}} <: AbstractMatrix{T}
    W::U
    b::V
end
this stops working and gives this error:
julia> gradient(B -> sum(MyMatrix(W,b)*B), B)
ERROR: MethodError: no method matching copy(::Missing)
...
julia> gradient(A -> sum(A*B), MyMatrix(W, b))
ERROR: MethodError: no method matching copy(::Missing)
...
Is it possible to create custom AbstractMatrix subtypes in this way or does Zygote make it impossible?
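A possible factor here, offered as an assumption rather than a confirmed diagnosis, is method dispatch: once MyMatrix subtypes AbstractMatrix, any rule registered for two AbstractMatrix arguments also applies to it, whereas the plain struct never matches such a rule. The sketch below uses only Base Julia to show that effect; `generic_rule`, `Plain`, and `Wrapped` are hypothetical names and are not part of Zygote or ChainRulesCore.

```julia
# A plain wrapper: not part of the AbstractMatrix type hierarchy.
struct Plain
    W::Matrix{Float64}
end

# The same wrapper, now declared as an AbstractMatrix subtype.
struct Wrapped <: AbstractMatrix{Float64}
    W::Matrix{Float64}
end
Base.size(A::Wrapped) = size(A.W)
Base.getindex(A::Wrapped, i::Int, j::Int) = A.W[i, j]

# Hypothetical stand-in for a generic rule defined on any pair of matrices.
generic_rule(A::AbstractMatrix, B::AbstractMatrix) = :generic_matmul_rule

W = [1.0 2.0; 3.0 4.0]
B = [1.0 2.0; missing 3.0]

# The generic method matches the subtype but not the plain struct.
matches_wrapped = applicable(generic_rule, Wrapped(W), B)
matches_plain = applicable(generic_rule, Plain(W), B)

println(matches_wrapped)
println(matches_plain)
```

Under that assumption, a generic matrix-multiplication rule would receive the matrix B with its missing entries directly, bypassing the fill-in step that the custom rule relies on.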
The problem was that there is already an adjoint defined for the multiplication of two matrices. I have tried to implement an adjoint specifically for * using rrule, but this fails: