From d0ce07812d8e4fc79198f1a75e891ff14d591e07 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 28 Mar 2024 10:44:14 -0400 Subject: [PATCH] Override chainrules rrule --- src/chainrules.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index c3e1970..4bd7d2c 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -115,6 +115,14 @@ function CRC.rrule(::typeof(*), X::UniformBlockDiagonalMatrix{<:Union{Real, Comp return X * Y, ∇times end +function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(*), + X::AbstractMatrix{<:Union{Real, Complex}}, + Y::UniformBlockDiagonalMatrix{<:Union{Real, Complex}}) + _f = @closure (x, y) -> dropdims( + batched_mul(reshape(x, :, 1, nbatches(x)), y.data); dims=1) + return CRC.rrule_via_ad(cfg, _f, X, Y) +end + # constructor function CRC.rrule(::Type{<:UniformBlockDiagonalMatrix}, data) function ∇UniformBlockDiagonalMatrix(Δ)