Skip to content

Commit

Permalink
Further simplify rules
Browse files Browse the repository at this point in the history
  • Loading branch information
Simone Carlo Surace committed Sep 26, 2023
1 parent 3ccbdfc commit e0e809d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ end
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
function kron_pullback(z̄)
dz = reshape(unthunk(z̄), length(y), size(x)...)
= @thunk Ref(y') .* eachslice(dz; dims = (2, 3))
= @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y)))
ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
return NoTangent(), x̄, ȳ
end
Expand All @@ -427,7 +427,7 @@ function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:
function kron_pullback(z̄)
dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
= @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
ȳ = @thunk Ref(x') .* eachslice(dz; dims = (1, 3))
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
return NoTangent(), x̄, ȳ
end
return kron(x, y), kron_pullback
Expand All @@ -437,7 +437,7 @@ function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:
function kron_pullback(z̄)
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
= @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
ȳ = @thunk dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x)))
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
return NoTangent(), x̄, ȳ
end
return kron(x, y), kron_pullback
Expand Down

0 comments on commit e0e809d

Please sign in to comment.