Skip to content

Commit

Permalink
Add rules and try to cover complex case
Browse files Browse the repository at this point in the history
  • Loading branch information
Simone Carlo Surace committed Sep 25, 2023
1 parent 73935c7 commit a3f0527
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,25 @@ function frule((_, Δx, Δy), ::typeof(kron), x, y)
return kron(x, y), kron(Δx, y) + kron(x, Δy)
end

function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
z = kron(x, y)
function rrule(::typeof(kron), x::AbstractVector, y::AbstractVector)
function kron_pullback(z̄)
= zero(x)
= zero(y)
m = firstindex(z̄)
@inbounds for i in eachindex(x)
xi = x[i]
for k in eachindex(y)
x̄[i] += y[k]' * z̄[m]
ȳ[k] += xi' * z̄[m]
m += 1
end
end
NoTangent(), x̄, ȳ
end
kron(x, y), kron_pullback
end

function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
function kron_pullback(z̄)
= zero(x)
= zero(y)
Expand All @@ -414,18 +430,16 @@ function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
xij = x[i,j]
for k in eachindex(y)
x̄[i, j] += y[k]' * z̄[m]
ȳ[k] += xij * z̄[m]
ȳ[k] += xij' * z̄[m]
m += 1
end
end
NoTangent(), x̄, ȳ
end
z, kron_pullback
kron(x, y), kron_pullback
end

function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
z = kron(x, y)

function kron_pullback(z̄)
= zero(x)
= zero(y)
Expand All @@ -434,11 +448,29 @@ function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
xi = x[i]
for k in axes(y,1)
x̄[i] += y[k, l]' * z̄[m]
ȳ[k, l] += xi * z̄[m]
ȳ[k, l] += xi' * z̄[m]
m += 1
end
end
NoTangent(), x̄, ȳ
end
kron(x, y), kron_pullback
end

function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractMatrix)
function kron_pullback(z̄)
= zero(x)
= zero(y)
m = firstindex(z̄)
@inbounds for l in axes(y,2), j in axes(x,2), i in axes(x,1)
xij = x[i, j]
for k in axes(y,1)
x̄[i, j] += y[k, l]' * z̄[m]
ȳ[k, l] += xij' * z̄[m]
m += 1
end
end
NoTangent(), x̄, ȳ
end
z, kron_pullback
kron(x, y), kron_pullback
end

0 comments on commit a3f0527

Please sign in to comment.