From e0e809dfa210dba8dc0b565adbc5090200185040 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 26 Sep 2023 22:04:41 +0200 Subject: [PATCH] Further simplify rules --- src/rulesets/LinearAlgebra/dense.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 463aa0975..e7a9b9414 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -416,7 +416,7 @@ end function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) function kron_pullback(z̄) dz = reshape(unthunk(z̄), length(y), size(x)...) - x̄ = @thunk Ref(y') .* eachslice(dz; dims = (2, 3)) + x̄ = @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y))) ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x))) return NoTangent(), x̄, ȳ end @@ -427,7 +427,7 @@ function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<: function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2)) x̄ = @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y))) - ȳ = @thunk Ref(x') .* eachslice(dz; dims = (1, 3)) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback @@ -437,7 +437,7 @@ function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<: function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) x̄ = @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) - ȳ = @thunk dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x))) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback