Skip to content

Commit

Permalink
[WIP] Optimizing {+,-,*} for structured matrices (#28883)
Browse files Browse the repository at this point in the history
* added sparse multiplication and division for triangular matrices. Fix #28451

* merge with master

* merge with master 2

* fixed symtridiagonal + bidiagonal

* improved find diagonal part

* refactored to purge name space of SparseArrays

* additional test cases and bug fix

* specializing some structured matrix operations

* added constructors for Triangular(::Diagonal). Removed redundant code from binops of special.jl so that broadcasting takes over. Cleaned up some of the tests for special.jl

* fix whitespace

* actually fixed whitespace

* fixed a typo in Diagonal*Bi/Tridiag. Changed the multiplication methods to more explicit constructors so that matrices with BigFloat dont error

* fixed bidiag+/-diag speed regression

* fixed +/- regressions for the other structured matrix types (bidiag, tridiag, symtridiag, diag)

* Revert "merged with master"

This reverts commit 3a58908, reversing
changes made to 0facd1d.

* Removing the speedups for sparse matrix multiplication and division. These should go in another PR so this one can be merged more quickly.

Revert "added sparse multiplication and division for triangular matrices. Fix #28451"

This reverts commit 11c1d1d.

* Revert "additional test cases and bug fix"

This reverts commit 21592db.

* reverting sparse changes

* removing extra whitespace and comments

* fixing BiTriSym*BiTriSym sparse eltype

* fixing the cases where we have two structured matrices and the resulting diagonals are of different types. This still fails when the representation is a range and we get a step size of 0

* Fixes the issue where we try to add structured matrices and one has an eltype <: AbstractArray

See PR 27289

* remove adjoint and transpose methods that I never changed

* fixing tridiagonal constructor to save time/memory

* fixing bidiag * diag return type

* adding multiplication to binops tests
  • Loading branch information
mcognetta authored and andreasnoack committed Dec 11, 2018
1 parent 5c5489e commit 469fa36
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 45 deletions.
75 changes: 70 additions & 5 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,15 +306,17 @@ function +(A::Bidiagonal, B::Bidiagonal)
if A.uplo == B.uplo
Bidiagonal(A.dv+B.dv, A.ev+B.ev, A.uplo)
else
Tridiagonal((A.uplo == 'U' ? (B.ev,A.dv+B.dv,A.ev) : (A.ev,A.dv+B.dv,B.ev))...)
newdv = A.dv+B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), newdv, typeof(newdv)(A.ev)) : (typeof(newdv)(A.ev), newdv, typeof(newdv)(B.ev)))...)
end
end

function -(A::Bidiagonal, B::Bidiagonal)
if A.uplo == B.uplo
Bidiagonal(A.dv-B.dv, A.ev-B.ev, A.uplo)
else
Tridiagonal((A.uplo == 'U' ? (-B.ev,A.dv-B.dv,A.ev) : (A.ev,A.dv-B.dv,-B.ev))...)
newdv = A.dv-B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, typeof(newdv)(A.ev)) : (typeof(newdv)(A.ev), newdv, typeof(newdv)(-B.ev)))...)
end
end

Expand Down Expand Up @@ -489,9 +491,72 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
end

const SpecialMatrix = Union{Bidiagonal,SymTridiagonal,Tridiagonal}
# to avoid ambiguity warning, but shouldn't be necessary
*(A::AbstractTriangular, B::SpecialMatrix) = Array(A) * Array(B)
*(A::SpecialMatrix, B::SpecialMatrix) = Array(A) * Array(B)

function *(A::AbstractTriangular, B::Union{SymTridiagonal, Tridiagonal})
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(zeros(TS, size(A)...), A, B)
end

function *(A::UpperTriangular, B::Bidiagonal)
TS = promote_op(matprod, eltype(A), eltype(B))
if B.uplo == 'U'
A_mul_B_td!(UpperTriangular(zeros(TS, size(A)...)), A, B)
else
A_mul_B_td!(zeros(TS, size(A)...), A, B)
end
end

function *(A::LowerTriangular, B::Bidiagonal)
TS = promote_op(matprod, eltype(A), eltype(B))
if B.uplo == 'L'
A_mul_B_td!(LowerTriangular(zeros(TS, size(A)...)), A, B)
else
A_mul_B_td!(zeros(TS, size(A)...), A, B)
end
end

function *(A::Union{SymTridiagonal, Tridiagonal}, B::AbstractTriangular)
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(zeros(TS, size(A)...), A, B)
end

function *(A::Bidiagonal, B::UpperTriangular)
TS = promote_op(matprod, eltype(A), eltype(B))
if A.uplo == 'U'
A_mul_B_td!(UpperTriangular(zeros(TS, size(A)...)), A, B)
else
A_mul_B_td!(zeros(TS, size(A)...), A, B)
end
end

function *(A::Bidiagonal, B::LowerTriangular)
TS = promote_op(matprod, eltype(A), eltype(B))
if A.uplo == 'L'
A_mul_B_td!(LowerTriangular(zeros(TS, size(A)...)), A, B)
else
A_mul_B_td!(zeros(TS, size(A)...), A, B)
end
end

function *(A::Bidiagonal, B::Diagonal)
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(similar(A, TS), A, B)
end

function *(A::Diagonal, B::BiTri)
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(similar(B, TS), A, B)
end

function *(A::Diagonal, B::SymTridiagonal)
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B)
end

function *(A::SymTridiagonal, B::Diagonal)
TS = promote_op(matprod, eltype(A), eltype(B))
A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B)
end

#Generic multiplication
*(A::Bidiagonal{T}, B::AbstractVector{T}) where {T} = *(Array(A), B)
Expand Down
242 changes: 203 additions & 39 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ SymTridiagonal(A::AbstractTriangular) = SymTridiagonal(Tridiagonal(A))
Tridiagonal(A::AbstractTriangular) =
isbanded(A, -1, 1) ? Tridiagonal(diag(A, -1), diag(A, 0), diag(A, 1)) : # is tridiagonal
throw(ArgumentError("matrix cannot be represented as Tridiagonal"))

UpperTriangular(A::Bidiagonal) =
A.uplo == 'U' ? UpperTriangular{eltype(A), typeof(A)}(A) :
throw(ArgumentError("matrix cannot be represented as UpperTriangular"))
LowerTriangular(A::Bidiagonal) =
A.uplo == 'L' ? LowerTriangular{eltype(A), typeof(A)}(A) :
throw(ArgumentError("matrix cannot be represented as LowerTriangular"))

const ConvertibleSpecialMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,AbstractTriangular}
const PossibleTriangularMatrix = Union{Diagonal, Bidiagonal, AbstractTriangular}

convert(T::Type{<:Diagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
convert(T::Type{<:SymTridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
Expand All @@ -67,6 +73,9 @@ convert(T::Type{<:Tridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m :
convert(T::Type{<:LowerTriangular}, m::Union{LowerTriangular,UnitLowerTriangular}) = m isa T ? m : T(m)
convert(T::Type{<:UpperTriangular}, m::Union{UpperTriangular,UnitUpperTriangular}) = m isa T ? m : T(m)

convert(T::Type{<:LowerTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m)
convert(T::Type{<:UpperTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m)

# Constructs two method definitions taking into account (assumed) commutativity
# e.g. @commutative f(x::S, y::T) where {S,T} = x+y is the same is defining
# f(x::S, y::T) where {S,T} = x+y
Expand All @@ -80,51 +89,206 @@ macro commutative(myexpr)
end

for op in (:+, :-)
SpecialMatrices = [:Diagonal, :Bidiagonal, :Tridiagonal, :Matrix]
for (idx, matrixtype1) in enumerate(SpecialMatrices) # matrixtype1 is the sparser matrix type
for matrixtype2 in SpecialMatrices[idx+1:end] # matrixtype2 is the denser matrix type
@eval begin # TODO quite a few of these conversions are NOT defined
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
for (matrixtype, uplo, converttype) in ((:UpperTriangular, 'U', :UpperTriangular),
(:UnitUpperTriangular, 'U', :UpperTriangular),
(:LowerTriangular, 'L', :LowerTriangular),
(:UnitLowerTriangular, 'L', :LowerTriangular))
@eval begin
function ($op)(A::$matrixtype, B::Bidiagonal)
if B.uplo == $uplo
($op)(A, convert($converttype, B))
else
($op).(A, B)
end
end
end
end

for matrixtype1 in (:SymTridiagonal,) # matrixtype1 is the sparser matrix type
for matrixtype2 in (:Tridiagonal, :Matrix) # matrixtype2 is the denser matrix type
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
function ($op)(A::Bidiagonal, B::$matrixtype)
if A.uplo == $uplo
($op)(convert($converttype, A), B)
else
($op).(A, B)
end
end
end
end
end

for matrixtype1 in (:Diagonal, :Bidiagonal) # matrixtype1 is the sparser matrix type
for matrixtype2 in (:SymTridiagonal,) # matrixtype2 is the denser matrix type
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
end
end
end
# specialized +/- for structured matrices. If these are removed, it falls
# back to broadcasting which has ~2-10x speed regressions.
# For the other structure matrix pairs, broadcasting works well.

for matrixtype1 in (:Diagonal,)
for (matrixtype2,matrixtype3) in ((:UpperTriangular,:UpperTriangular),
(:UnitUpperTriangular,:UpperTriangular),
(:LowerTriangular,:LowerTriangular),
(:UnitLowerTriangular,:LowerTriangular))
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(($matrixtype3)(A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, ($matrixtype3)(B))
end
end
end
for matrixtype in (:SymTridiagonal,:Tridiagonal,:Bidiagonal,:Matrix)
@eval begin
($op)(A::AbstractTriangular, B::($matrixtype)) = ($op)(copyto!(similar(parent(A)), A), B)
($op)(A::($matrixtype), B::AbstractTriangular) = ($op)(A, copyto!(similar(parent(B)), B))
end
end
# For structured matrix types with different non-zero diagonals the underlying
# representations must be promoted to the same type.
# For example, in Diagonal + Bidiagonal only the main diagonal is touched so
# the off diagonal could be a different type after the operation resulting in
# an error. See issue #28994

function (+)(A::Bidiagonal, B::Diagonal)
newdv = A.dv + B.diag
Bidiagonal(newdv, typeof(newdv)(A.ev), A.uplo)
end

function (-)(A::Bidiagonal, B::Diagonal)
newdv = A.dv - B.diag
Bidiagonal(newdv, typeof(newdv)(A.ev), A.uplo)
end

function (+)(A::Diagonal, B::Bidiagonal)
newdv = A.diag + B.dv
Bidiagonal(newdv, typeof(newdv)(B.ev), B.uplo)
end

function (-)(A::Diagonal, B::Bidiagonal)
newdv = A.diag-B.dv
Bidiagonal(newdv, typeof(newdv)(-B.ev), B.uplo)
end

function (+)(A::Diagonal, B::SymTridiagonal)
newdv = A.diag+B.dv
SymTridiagonal(A.diag+B.dv, typeof(newdv)(B.ev))
end

function (-)(A::Diagonal, B::SymTridiagonal)
newdv = A.diag-B.dv
SymTridiagonal(newdv, typeof(newdv)(-B.ev))
end

function (+)(A::SymTridiagonal, B::Diagonal)
newdv = A.dv+B.diag
SymTridiagonal(newdv, typeof(newdv)(A.ev))
end

function (-)(A::SymTridiagonal, B::Diagonal)
newdv = A.dv-B.diag
SymTridiagonal(newdv, typeof(newdv)(A.ev))
end

# this set doesn't have the aforementioned problem

+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev)
-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.ev, A.d-B.dv, A.du-B.ev)
+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev+B.dl, A.dv+B.d, A.ev+B.du)
-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev-B.dl, A.dv-B.d, A.ev-B.du)


function (+)(A::Diagonal, B::Tridiagonal)
newdv = A.diag+B.d
Tridiagonal(typeof(newdv)(B.dl), newdv, typeof(newdv)(B.du))
end

function (-)(A::Diagonal, B::Tridiagonal)
newdv = A.diag-B.d
Tridiagonal(typeof(newdv)(-B.dl), newdv, typeof(newdv)(-B.du))
end

function (+)(A::Tridiagonal, B::Diagonal)
newdv = A.d+B.diag
Tridiagonal(typeof(newdv)(A.dl), newdv, typeof(newdv)(A.du))
end

function (-)(A::Tridiagonal, B::Diagonal)
newdv = A.d-B.diag
Tridiagonal(typeof(newdv)(A.dl), newdv, typeof(newdv)(A.du))
end

function (+)(A::Bidiagonal, B::Tridiagonal)
newdv = A.dv+B.d
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.dl), newdv, A.ev+B.du) : (A.ev+B.dl, newdv, typeof(newdv)(B.du)))...)
end

function (-)(A::Bidiagonal, B::Tridiagonal)
newdv = A.dv-B.d
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.dl), newdv, A.ev-B.du) : (A.ev-B.dl, newdv, typeof(newdv)(-B.du)))...)
end

function (+)(A::Tridiagonal, B::Bidiagonal)
newdv = A.d+B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.dl), newdv, A.du+B.ev) : (A.dl+B.ev, newdv, typeof(newdv)(A.du)))...)
end

function (-)(A::Tridiagonal, B::Bidiagonal)
newdv = A.d-B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.dl), newdv, A.du-B.ev) : (A.dl-B.ev, newdv, typeof(newdv)(A.du)))...)
end

function (+)(A::Bidiagonal, B::SymTridiagonal)
newdv = A.dv+B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), A.dv+B.dv, A.ev+B.ev) : (A.ev+B.ev, A.dv+B.dv, typeof(newdv)(B.ev)))...)
end

function (-)(A::Bidiagonal, B::SymTridiagonal)
newdv = A.dv-B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(-B.ev)))...)
end

function (+)(A::SymTridiagonal, B::Bidiagonal)
newdv = A.dv+B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev+B.ev) : (A.ev+B.ev, newdv, typeof(newdv)(A.ev)))...)
end

function (-)(A::SymTridiagonal, B::Bidiagonal)
newdv = A.dv-B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(A.ev)))...)
end

# fixing uniform scaling problems from #28994
# {<:Number} is required due to the test case from PR #27289 where eltype is a matrix.

function (+)(A::Tridiagonal{<:Number}, B::UniformScaling)
newd = A.d .+ B.λ
Tridiagonal(typeof(newd)(A.dl), newd, typeof(newd)(A.du))
end

function (+)(A::SymTridiagonal{<:Number}, B::UniformScaling)
newdv = A.dv .+ B.λ
SymTridiagonal(newdv, typeof(newdv)(A.ev))
end

function (+)(A::Bidiagonal{<:Number}, B::UniformScaling)
newdv = A.dv .+ B.λ
Bidiagonal(newdv, typeof(newdv)(A.ev), A.uplo)
end

function (+)(A::Diagonal{<:Number}, B::UniformScaling)
Diagonal(A.diag .+ B.λ)
end

function (+)(A::UniformScaling, B::Tridiagonal{<:Number})
newd = A.λ .+ B.d
Tridiagonal(typeof(newd)(B.dl), newd, typeof(newd)(B.du))
end

function (+)(A::UniformScaling, B::SymTridiagonal{<:Number})
newdv = A.λ .+ B.dv
SymTridiagonal(newdv, typeof(newdv)(B.ev))
end

function (+)(A::UniformScaling, B::Bidiagonal{<:Number})
newdv = A.λ .+ B.dv
Bidiagonal(newdv, typeof(newdv)(B.ev), B.uplo)
end

function (+)(A::UniformScaling, B::Diagonal{<:Number})
Diagonal(A.λ .+ B.diag)
end

function (-)(A::UniformScaling, B::Tridiagonal{<:Number})
newd = A.λ .- B.d
Tridiagonal(typeof(newd)(-B.dl), newd, typeof(newd)(-B.du))
end

function (-)(A::UniformScaling, B::SymTridiagonal{<:Number})
newdv = A.λ .- B.dv
SymTridiagonal(newdv, typeof(newdv)(-B.ev))
end

function (-)(A::UniformScaling, B::Bidiagonal{<:Number})
newdv = A.λ .- B.dv
Bidiagonal(newdv, typeof(newdv)(-B.ev), B.uplo)
end

function (-)(A::UniformScaling, B::Diagonal{<:Number})
Diagonal(A.λ .- B.diag)
end

rmul!(A::AbstractTriangular, adjB::Adjoint{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) =
Expand Down
1 change: 0 additions & 1 deletion stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1599,7 +1599,6 @@ rdiv!(A::LowerTriangular, transB::Transpose{<:Any,<:Union{UpperTriangular,UnitUp

## Some Triangular-Triangular cases. We might want to write tailored methods
## for these cases, but I'm not sure it is worth it.
(*)(A::Union{Tridiagonal,SymTridiagonal}, B::AbstractTriangular) = rmul!(Matrix(A), B)

for (f, f2!) in ((:*, :lmul!), (:\, :ldiv!))
@eval begin
Expand Down
Loading

0 comments on commit 469fa36

Please sign in to comment.