Skip to content

Commit

Permalink
Add projections
Browse files Browse the repository at this point in the history
  • Loading branch information
Simone Carlo Surace committed Sep 27, 2023
1 parent fff05c2 commit 649e797
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,12 @@ end
end

function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
project_x = ProjectTo(x)
project_y = ProjectTo(y)
function kron_pullback(z̄)
dz = reshape(unthunk(z̄), length(y), length(x))
= @thunk conj.(dz' * y)
ȳ = @thunk dz * conj.(x)
= @thunk(project_x(conj.(dz' * y)))
ȳ = @thunk(project_y(dz * conj.(x)))
return NoTangent(), x̄, ȳ
end
return kron(x, y), kron_pullback
Expand Down

0 comments on commit 649e797

Please sign in to comment.