Skip to content

Commit

Permalink
LinearAlgebra: Make kron with Diagonal matrices more efficient (#46463)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
eschnett and dkarrasch authored Aug 29, 2022
1 parent f5db687 commit c79b995
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
6 changes: 5 additions & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,16 @@ _cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X
# ignored. However, some methods can fail if they read the entired ev
# rather than just the meaningful elements. This is a helper function
# for getting only the meaningful elements of ev. See #41089
_evview(S::SymTridiagonal) = @view S.ev[begin:length(S.dv) - 1]
_evview(S::SymTridiagonal) = @view S.ev[begin:begin + length(S.dv) - 2]

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))

# append a zero element / drop the last element
_pushzero(A) = (B = similar(A, length(A)+1); @inbounds B[begin:end-1] .= A; @inbounds B[end] = zero(eltype(B)); B)
_droplast!(A) = deleteat!(A, lastindex(A))

# General fallback definition for handling under- and overdetermined system as well as square problems
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
# which is required by LAPACK but not SuiteSpase which allows real-complex solves in some cases. Hence,
Expand Down
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ convert(T::Type{<:Bidiagonal}, m::AbstractMatrix) = m isa T ? m : T(m)
similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo)
similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...)

function kron(A::Diagonal, B::Bidiagonal)
kdv = kron(diag(A), B.dv)
kev = _droplast!(kron(diag(A), _pushzero(B.ev)))
Bidiagonal(kdv, kev, B.uplo)
end

###################
# LAPACK routines #
Expand Down
15 changes: 14 additions & 1 deletion stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,20 @@ end
return C
end

kron(A::Diagonal{<:Number}, B::Diagonal{<:Number}) = Diagonal(kron(A.diag, B.diag))
kron(A::Diagonal, B::Diagonal) = Diagonal(kron(A.diag, B.diag))

function kron(A::Diagonal, B::SymTridiagonal)
kdv = kron(diag(A), B.dv)
# We don't need to drop the last element
kev = kron(diag(A), _pushzero(_evview(B)))
SymTridiagonal(kdv, kev)
end
function kron(A::Diagonal, B::Tridiagonal)
kd = kron(diag(A), B.d)
kdl = _droplast!(kron(diag(A), _pushzero(B.dl)))
kdu = _droplast!(kron(diag(A), _pushzero(B.du)))
Tridiagonal(kdl, kd, kdu)
end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
require_one_based_indexing(B)
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,20 @@ end
@test kron(Ad, Ad).diag == kron([1, 2, 3], [1, 2, 3])
end

@testset "kron (issue #46456)" begin
A = Diagonal(randn(10))
BL = Bidiagonal(randn(10), randn(9), :L)
BU = Bidiagonal(randn(10), randn(9), :U)
C = SymTridiagonal(randn(10), randn(9))
Cl = SymTridiagonal(randn(10), randn(10))
D = Tridiagonal(randn(9), randn(10), randn(9))
@test kron(A, BL)::Bidiagonal == kron(Array(A), Array(BL))
@test kron(A, BU)::Bidiagonal == kron(Array(A), Array(BU))
@test kron(A, C)::SymTridiagonal == kron(Array(A), Array(C))
@test kron(A, Cl)::SymTridiagonal == kron(Array(A), Array(Cl))
@test kron(A, D)::Tridiagonal == kron(Array(A), Array(D))
end

@testset "svdvals and eigvals (#11120/#11247)" begin
D = Diagonal(Matrix{Float64}[randn(3,3), randn(2,2)])
@test sort([svdvals(D)...;], rev = true) svdvals([D.diag[1] zeros(3,2); zeros(2,3) D.diag[2]])
Expand Down

0 comments on commit c79b995

Please sign in to comment.