diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index da153d14e..245335774 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -96,7 +96,7 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 5fb7a3748..06de7a135 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -160,10 +160,3 @@ function rrule(::typeof(spdiagm), v::AbstractVector) end return spdiagm(v), spdiagm_pullback end - - -function _diagm_back(p, ȳ) - k, v = p - d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix - return Tangent{typeof(p)}(second = d) -end \ No newline at end of file