diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl
index 810c78b5a..003054cca 100644
--- a/src/rulesets/LinearAlgebra/dense.jl
+++ b/src/rulesets/LinearAlgebra/dense.jl
@@ -403,9 +403,28 @@ function frule((_, Δx, Δy), ::typeof(kron), x, y)
     return kron(x, y), kron(Δx, y) + kron(x, Δy)
 end
 
-function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
-    z = kron(x, y)
+# Reverse rule for `kron` of two vectors. `kron(x, y)` stores the products `x[i] * y[k]`
+# with `k` varying fastest, so the pullback walks `z̄` linearly (`i` outer, `k` inner) and
+# accumulates `x̄[i] += y[k]' * z̄[m]` and `ȳ[k] += x[i]' * z̄[m]`.
+function rrule(::typeof(kron), x::AbstractVector, y::AbstractVector)
+    function kron_pullback(z̄)
+        x̄ = zero(x)
+        ȳ = zero(y)
+        m = firstindex(z̄)
+        @inbounds for i in eachindex(x)
+            xi = x[i]
+            for k in eachindex(y)
+                x̄[i] += y[k]' * z̄[m]
+                ȳ[k] += xi' * z̄[m]
+                m += 1
+            end
+        end
+        NoTangent(), x̄, ȳ
+    end
+    kron(x, y), kron_pullback
+end
 
+function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
     function kron_pullback(z̄)
         x̄ = zero(x)
         ȳ = zero(y)
@@ -414,18 +433,16 @@ function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
             xij = x[i,j]
             for k in eachindex(y)
                 x̄[i, j] += y[k]' * z̄[m]
-                ȳ[k] += xij * z̄[m]
+                ȳ[k] += xij' * z̄[m]
                 m += 1
             end
         end
         NoTangent(), x̄, ȳ
     end
-    z, kron_pullback
+    kron(x, y), kron_pullback
 end
 
 function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
-    z = kron(x, y)
-
     function kron_pullback(z̄)
         x̄ = zero(x)
         ȳ = zero(y)
@@ -434,11 +451,33 @@ function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
             xi = x[i]
             for k in axes(y,1)
                 x̄[i] += y[k, l]' * z̄[m]
-                ȳ[k, l] += xi * z̄[m]
+                ȳ[k, l] += xi' * z̄[m]
+                m += 1
+            end
+        end
+        NoTangent(), x̄, ȳ
+    end
+    kron(x, y), kron_pullback
+end
+
+# Reverse rule for `kron` of two matrices. Entry `x[i, j] * y[k, l]` of `kron(x, y)` sits
+# at row `(i, k)` (`k` fastest) and column `(j, l)` (`l` fastest), so a column-major walk
+# over `z̄` must nest the loops as `j`, `l`, `i`, `k` from outermost to innermost.
+function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractMatrix)
+    function kron_pullback(z̄)
+        x̄ = zero(x)
+        ȳ = zero(y)
+        # `m` is the running linear index into `z̄`; the loop order must match its layout.
+        m = firstindex(z̄)
+        @inbounds for j in axes(x,2), l in axes(y,2), i in axes(x,1)
+            xij = x[i, j]
+            for k in axes(y,1)
+                x̄[i, j] += y[k, l]' * z̄[m]
+                ȳ[k, l] += xij' * z̄[m]
                 m += 1
             end
         end
         NoTangent(), x̄, ȳ
     end
-    z, kron_pullback
+    kron(x, y), kron_pullback
 end