From b7bd291f85304fa48da703cc3ef0c215041bdc4f Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Thu, 12 Oct 2023 11:49:41 -0700 Subject: [PATCH] remove duplicate _diagm_back definition --- src/rulesets/LinearAlgebra/structured.jl | 2 +- src/rulesets/SparseArrays/sparsematrix.jl | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index da153d14e..245335774 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -96,7 +96,7 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 5fb7a3748..06de7a135 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -160,10 +160,3 @@ function rrule(::typeof(spdiagm), v::AbstractVector) end return spdiagm(v), spdiagm_pullback end - - -function _diagm_back(p, ȳ) - k, v = p - d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix - return Tangent{typeof(p)}(second = d) -end \ No newline at end of file