Skip to content

Commit

Permalink
making Bi/Tri/Sym times Diag output a structured matrix (not sparse) (#…
Browse files Browse the repository at this point in the history
…31889)

* making Bi/Tri/Sym times Diag output a structured matrix (not sparse)

* fix test errors/typos

* adding diag*bi/tri/sym and removing unused union

* add mul!(C,::BiTriSym, ::*Diagonal)

* fixing ambiguious methods and adjoint/transpose

* space

* test error

changing `transpose`/`adjoint` to `Transpose`/`Adjoint`

This will fail until #31889 is merged.

* remove transpose
  • Loading branch information
mcognetta authored and andreasnoack committed Jul 6, 2019
1 parent e6379da commit a060344
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 9 deletions.
82 changes: 73 additions & 9 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ end
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
mul!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::BiTri, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B)
Expand All @@ -347,12 +346,12 @@ mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractTriangular}, B::BiTriSym) = A
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractTriangular}, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym) = A_mul_B_td!(C, A, B)
mul!(C::AbstractVector, A::BiTri, B::AbstractVector) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
mul!(C::AbstractVecOrMat, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::BiTri, B::Transpose{<:Any,<:AbstractVecOrMat}) = A_mul_B_td!(C, A, B) # around bidiag line 330
mul!(C::AbstractMatrix, A::BiTri, B::Adjoint{<:Any,<:AbstractVecOrMat}) = A_mul_B_td!(C, A, B)
mul!(C::AbstractVector, A::BiTri, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))
mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}) = A_mul_B_td!(C, A, B) # around bidiag line 330
mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}) = A_mul_B_td!(C, A, B)
mul!(C::AbstractVector, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))

function check_A_mul_B!_sizes(C, A, B)
require_one_based_indexing(C)
Expand Down Expand Up @@ -437,6 +436,39 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
C
end

function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return mul!(C, Array(A), Array(B))
fill!(C, zero(eltype(C)))
Al = _diag(A, -1)
Ad = _diag(A, 0)
Au = _diag(A, 1)
Bd = B.diag
@inbounds begin
# first row of C
C[1,1] = A[1,1]*B[1,1]
C[1,2] = A[1,2]*B[2,2]
# second row of C
C[2,1] = A[2,1]*B[1,1]
C[2,2] = A[2,2]*B[2,2]
C[2,3] = A[2,3]*B[3,3]
for j in 3:n-2
C[j, j-1] = Al[j-1]*Bd[j-1]
C[j, j ] = Ad[j ]*Bd[j ]
C[j, j+1] = Au[j ]*Bd[j+1]
end
# row before last of C
C[n-1,n-2] = A[n-1,n-2]*B[n-2,n-2]
C[n-1,n-1] = A[n-1,n-1]*B[n-1,n-1]
C[n-1,n ] = A[n-1, n]*B[n ,n ]
# last row of C
C[n,n-1] = A[n,n-1]*B[n-1,n-1]
C[n,n ] = A[n,n ]*B[n, n ]
end # inbounds
C
end

function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat)
require_one_based_indexing(C)
require_one_based_indexing(B)
Expand Down Expand Up @@ -497,7 +529,39 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
C
end

const SpecialMatrix = Union{Bidiagonal,SymTridiagonal,Tridiagonal}
function A_mul_B_td!(C::AbstractMatrix, A::Diagonal, B::BiTriSym)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return mul!(C, Array(A), Array(B))
fill!(C, zero(eltype(C)))
Ad = A.diag
Bl = _diag(B, -1)
Bd = _diag(B, 0)
Bu = _diag(B, 1)
@inbounds begin
# first row of C
C[1,1] = A[1,1]*B[1,1]
C[1,2] = A[1,1]*B[1,2]
# second row of C
C[2,1] = A[2,2]*B[2,1]
C[2,2] = A[2,2]*B[2,2]
C[2,3] = A[2,2]*B[2,3]
for j in 3:n-2
Ajj = Ad[j]
C[j, j-1] = Ajj*Bl[j-1]
C[j, j ] = Ajj*Bd[j]
C[j, j+1] = Ajj*Bu[j]
end
# row before last of C
C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2]
C[n-1,n-1] = A[n-1,n-1]*B[n-1,n-1]
C[n-1,n ] = A[n-1,n-1]*B[n-1,n ]
# last row of C
C[n,n-1] = A[n,n]*B[n,n-1]
C[n,n ] = A[n,n]*B[n,n ]
end # inbounds
C
end

function *(A::AbstractTriangular, B::Union{SymTridiagonal, Tridiagonal})
TS = promote_op(matprod, eltype(A), eltype(B))
Expand Down Expand Up @@ -545,7 +609,7 @@ function *(A::Bidiagonal, B::LowerTriangular)
end
end

function *(A::Bidiagonal, B::Diagonal)
function *(A::BiTri, B::Diagonal)
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(similar(A, TS), A, B)
end
Expand Down
25 changes: 25 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,31 @@ Random.seed!(1)
C = Matrix{elty}(undef, n, n)
Dia = Diagonal(T.dv)
@test mul!(C, Dia, T) Array(Dia)*Array(T)

# Issue #31870
# Bi/Tri/Sym times Diagonal
Diag = Diagonal(rand(elty, 10))
BidiagU = Bidiagonal(rand(elty, 10), rand(elty, 9), 'U')
BidiagL = Bidiagonal(rand(elty, 10), rand(elty, 9), 'L')
Tridiag = Tridiagonal(rand(elty, 9), rand(elty, 10), rand(elty, 9))
SymTri = SymTridiagonal(rand(elty, 10), rand(elty, 9))

mats = [Diag, BidiagU, BidiagL, Tridiag, SymTri]
for a in mats
for b in mats
@test a*b Matrix(a)*Matrix(b)
end
end

@test typeof(BidiagU*Diag) <: Bidiagonal
@test typeof(BidiagL*Diag) <: Bidiagonal
@test typeof(Tridiag*Diag) <: Tridiagonal
@test typeof(SymTri*Diag) <: Tridiagonal

@test typeof(BidiagU*Diag) <: Bidiagonal
@test typeof(Diag*BidiagL) <: Bidiagonal
@test typeof(Diag*Tridiag) <: Tridiagonal
@test typeof(Diag*SymTri) <: Tridiagonal
end

@test inv(T)*Tfull Matrix(I, n, n)
Expand Down

0 comments on commit a060344

Please sign in to comment.