From 469fa360b88c68a1967b24c17ddfe015ccc1f623 Mon Sep 17 00:00:00 2001 From: Marco Date: Tue, 11 Dec 2018 21:06:54 +0900 Subject: [PATCH] [WIP] Optimizing {+,-,*} for structured matrices (#28883) * added sparse multiplication and division for triangular matrices. Fix #28451 * merge with master * merge with master 2 * fixed symtridiagonal + bidiagonal * improved find diagonal part * refactored to purge name space of SparseArrays * additional test cases and bug fix * specializing some structured matrix operations * added constructors for Triangular(::Diagonal). Removed redundant code from binops of special.jl so that broadcasting takes over. Cleaned up some of the tests for special.jl * fix whitespace * actually fixed whitespace * fixed a typo in Diagonal*Bi/Tridiag. Changed the multiplication methods to more explicit constructors so that matrices with BigFloat dont error * fixed bidiag+/-diag speed regression * fixed +/- regressions for the other structured matrix types (bidiag, tridiag, symtridiag, diag) * Revert "merged with master" This reverts commit 3a589088a848d7c3f90f77413654bfb2a3ed11fc, reversing changes made to 0facd1db5ca30a0c1e93d3299f9dc634ef50e4b4. * Removing the speedups for sparse matrix multiplication and division. These should go in another PR so this one can be merged more quickly. Revert "added sparse multiplication and division for triangular matrices. Fix #28451" This reverts commit 11c1d1d477635042cb41fb13125acc8301ca887d. * Revert "additional test cases and bug fix" This reverts commit 21592db0ed0dec1750db5a153a761133c0d2dd9e. * reverting sparse changes * removing extra whitespace and comments * fixing BiTriSym*BiTriSym sparse eltype * fixing the cases where we have two structured matrices and the resulting diagonals are of different types. This still fails when the representation is a range and we get a step size of 0 * Fixes the issue where we try to add structured matrices and one has an eltype <: AbstractArray See PR 27289 * remove adjoint and transpose methods that I never changed * fixing tridiagonal constructor to save time/memory * fixing bidiag * diag return type * adding multiplication to binops tests --- stdlib/LinearAlgebra/src/bidiag.jl | 75 +++++++- stdlib/LinearAlgebra/src/special.jl | 242 ++++++++++++++++++++---- stdlib/LinearAlgebra/src/triangular.jl | 1 - stdlib/LinearAlgebra/test/special.jl | 46 +++++ stdlib/SparseArrays/src/SparseArrays.jl | 7 + 5 files changed, 326 insertions(+), 45 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 24c3a015dd3c9..9f4282a645163 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -306,7 +306,8 @@ function +(A::Bidiagonal, B::Bidiagonal) if A.uplo == B.uplo Bidiagonal(A.dv+B.dv, A.ev+B.ev, A.uplo) else - Tridiagonal((A.uplo == 'U' ? (B.ev,A.dv+B.dv,A.ev) : (A.ev,A.dv+B.dv,B.ev))...) + newdv = A.dv+B.dv + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), newdv, typeof(newdv)(A.ev)) : (typeof(newdv)(A.ev), newdv, typeof(newdv)(B.ev)))...) end end @@ -314,7 +315,8 @@ function -(A::Bidiagonal, B::Bidiagonal) if A.uplo == B.uplo Bidiagonal(A.dv-B.dv, A.ev-B.ev, A.uplo) else - Tridiagonal((A.uplo == 'U' ? (-B.ev,A.dv-B.dv,A.ev) : (A.ev,A.dv-B.dv,-B.ev))...) + newdv = A.dv-B.dv + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, typeof(newdv)(A.ev)) : (typeof(newdv)(A.ev), newdv, typeof(newdv)(-B.ev)))...) end end @@ -489,9 +491,72 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) end const SpecialMatrix = Union{Bidiagonal,SymTridiagonal,Tridiagonal} -# to avoid ambiguity warning, but shouldn't be necessary -*(A::AbstractTriangular, B::SpecialMatrix) = Array(A) * Array(B) -*(A::SpecialMatrix, B::SpecialMatrix) = Array(A) * Array(B) + +function *(A::AbstractTriangular, B::Union{SymTridiagonal, Tridiagonal}) + TS = promote_op(matprod, eltype(A), eltype(B)) + A_mul_B_td!(zeros(TS, size(A)...), A, B) +end + +function *(A::UpperTriangular, B::Bidiagonal) + TS = promote_op(matprod, eltype(A), eltype(B)) + if B.uplo == 'U' + A_mul_B_td!(UpperTriangular(zeros(TS, size(A)...)), A, B) + else + A_mul_B_td!(zeros(TS, size(A)...), A, B) + end +end + +function *(A::LowerTriangular, B::Bidiagonal) + TS = promote_op(matprod, eltype(A), eltype(B)) + if B.uplo == 'L' + A_mul_B_td!(LowerTriangular(zeros(TS, size(A)...)), A, B) + else + A_mul_B_td!(zeros(TS, size(A)...), A, B) + end +end + +function *(A::Union{SymTridiagonal, Tridiagonal}, B::AbstractTriangular) + TS = promote_op(matprod, eltype(A), eltype(B)) + A_mul_B_td!(zeros(TS, size(A)...), A, B) +end + +function *(A::Bidiagonal, B::UpperTriangular) + TS = promote_op(matprod, eltype(A), eltype(B)) + if A.uplo == 'U' + A_mul_B_td!(UpperTriangular(zeros(TS, size(A)...)), A, B) + else + A_mul_B_td!(zeros(TS, size(A)...), A, B) + end +end + +function *(A::Bidiagonal, B::LowerTriangular) + TS = promote_op(matprod, eltype(A), eltype(B)) + if A.uplo == 'L' + A_mul_B_td!(LowerTriangular(zeros(TS, size(A)...)), A, B) + else + A_mul_B_td!(zeros(TS, size(A)...), A, B) + end +end + +function *(A::Bidiagonal, B::Diagonal) + TS = promote_op(matprod, eltype(A), eltype(B)) + A_mul_B_td!(similar(A, TS), A, B) +end + +function *(A::Diagonal, B::BiTri) + TS = promote_op(matprod, eltype(A), eltype(B)) + A_mul_B_td!(similar(B, TS), A, B) +end + +function *(A::Diagonal, B::SymTridiagonal) + TS = promote_op(matprod, eltype(A), eltype(B)) + A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B) +end + +function *(A::SymTridiagonal, B::Diagonal) + TS = promote_op(matprod, eltype(A), eltype(B)) + A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B) +end #Generic multiplication *(A::Bidiagonal{T}, B::AbstractVector{T}) where {T} = *(Array(A), B) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 54fd00c402746..a7c4c40b20582 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -56,9 +56,15 @@ SymTridiagonal(A::AbstractTriangular) = SymTridiagonal(Tridiagonal(A)) Tridiagonal(A::AbstractTriangular) = isbanded(A, -1, 1) ? Tridiagonal(diag(A, -1), diag(A, 0), diag(A, 1)) : # is tridiagonal throw(ArgumentError("matrix cannot be represented as Tridiagonal")) - +UpperTriangular(A::Bidiagonal) = + A.uplo == 'U' ? UpperTriangular{eltype(A), typeof(A)}(A) : + throw(ArgumentError("matrix cannot be represented as UpperTriangular")) +LowerTriangular(A::Bidiagonal) = + A.uplo == 'L' ? LowerTriangular{eltype(A), typeof(A)}(A) : + throw(ArgumentError("matrix cannot be represented as LowerTriangular")) const ConvertibleSpecialMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,AbstractTriangular} +const PossibleTriangularMatrix = Union{Diagonal, Bidiagonal, AbstractTriangular} convert(T::Type{<:Diagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m) convert(T::Type{<:SymTridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m) @@ -67,6 +73,9 @@ convert(T::Type{<:Tridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : convert(T::Type{<:LowerTriangular}, m::Union{LowerTriangular,UnitLowerTriangular}) = m isa T ? m : T(m) convert(T::Type{<:UpperTriangular}, m::Union{UpperTriangular,UnitUpperTriangular}) = m isa T ? m : T(m) +convert(T::Type{<:LowerTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m) +convert(T::Type{<:UpperTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m) + # Constructs two method definitions taking into account (assumed) commutativity # e.g. @commutative f(x::S, y::T) where {S,T} = x+y is the same is defining # f(x::S, y::T) where {S,T} = x+y @@ -80,51 +89,206 @@ macro commutative(myexpr) end for op in (:+, :-) - SpecialMatrices = [:Diagonal, :Bidiagonal, :Tridiagonal, :Matrix] - for (idx, matrixtype1) in enumerate(SpecialMatrices) # matrixtype1 is the sparser matrix type - for matrixtype2 in SpecialMatrices[idx+1:end] # matrixtype2 is the denser matrix type - @eval begin # TODO quite a few of these conversions are NOT defined - ($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B) - ($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B)) + for (matrixtype, uplo, converttype) in ((:UpperTriangular, 'U', :UpperTriangular), + (:UnitUpperTriangular, 'U', :UpperTriangular), + (:LowerTriangular, 'L', :LowerTriangular), + (:UnitLowerTriangular, 'L', :LowerTriangular)) + @eval begin + function ($op)(A::$matrixtype, B::Bidiagonal) + if B.uplo == $uplo + ($op)(A, convert($converttype, B)) + else + ($op).(A, B) + end end - end - end - for matrixtype1 in (:SymTridiagonal,) # matrixtype1 is the sparser matrix type - for matrixtype2 in (:Tridiagonal, :Matrix) # matrixtype2 is the denser matrix type - @eval begin - ($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B) - ($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B)) + function ($op)(A::Bidiagonal, B::$matrixtype) + if A.uplo == $uplo + ($op)(convert($converttype, A), B) + else + ($op).(A, B) + end end end end +end - for matrixtype1 in (:Diagonal, :Bidiagonal) # matrixtype1 is the sparser matrix type - for matrixtype2 in (:SymTridiagonal,) # matrixtype2 is the denser matrix type - @eval begin - ($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B) - ($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B)) - end - end - end +# specialized +/- for structured matrices. If these are removed, it falls +# back to broadcasting which has ~2-10x speed regressions. +# For the other structure matrix pairs, broadcasting works well. - for matrixtype1 in (:Diagonal,) - for (matrixtype2,matrixtype3) in ((:UpperTriangular,:UpperTriangular), - (:UnitUpperTriangular,:UpperTriangular), - (:LowerTriangular,:LowerTriangular), - (:UnitLowerTriangular,:LowerTriangular)) - @eval begin - ($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(($matrixtype3)(A), B) - ($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, ($matrixtype3)(B)) - end - end - end - for matrixtype in (:SymTridiagonal,:Tridiagonal,:Bidiagonal,:Matrix) - @eval begin - ($op)(A::AbstractTriangular, B::($matrixtype)) = ($op)(copyto!(similar(parent(A)), A), B) - ($op)(A::($matrixtype), B::AbstractTriangular) = ($op)(A, copyto!(similar(parent(B)), B)) - end - end +# For structured matrix types with different non-zero diagonals the underlying +# representations must be promoted to the same type. +# For example, in Diagonal + Bidiagonal only the main diagonal is touched so +# the off diagonal could be a different type after the operation resulting in +# an error. See issue #28994 + +function (+)(A::Bidiagonal, B::Diagonal) + newdv = A.dv + B.diag + Bidiagonal(newdv, typeof(newdv)(A.ev), A.uplo) +end + +function (-)(A::Bidiagonal, B::Diagonal) + newdv = A.dv - B.diag + Bidiagonal(newdv, typeof(newdv)(A.ev), A.uplo) +end + +function (+)(A::Diagonal, B::Bidiagonal) + newdv = A.diag + B.dv + Bidiagonal(newdv, typeof(newdv)(B.ev), B.uplo) +end + +function (-)(A::Diagonal, B::Bidiagonal) + newdv = A.diag-B.dv + Bidiagonal(newdv, typeof(newdv)(-B.ev), B.uplo) +end + +function (+)(A::Diagonal, B::SymTridiagonal) + newdv = A.diag+B.dv + SymTridiagonal(A.diag+B.dv, typeof(newdv)(B.ev)) +end + +function (-)(A::Diagonal, B::SymTridiagonal) + newdv = A.diag-B.dv + SymTridiagonal(newdv, typeof(newdv)(-B.ev)) +end + +function (+)(A::SymTridiagonal, B::Diagonal) + newdv = A.dv+B.diag + SymTridiagonal(newdv, typeof(newdv)(A.ev)) +end + +function (-)(A::SymTridiagonal, B::Diagonal) + newdv = A.dv-B.diag + SymTridiagonal(newdv, typeof(newdv)(A.ev)) +end + +# this set doesn't have the aforementioned problem + ++(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev) +-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.ev, A.d-B.dv, A.du-B.ev) ++(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev+B.dl, A.dv+B.d, A.ev+B.du) +-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev-B.dl, A.dv-B.d, A.ev-B.du) + + +function (+)(A::Diagonal, B::Tridiagonal) + newdv = A.diag+B.d + Tridiagonal(typeof(newdv)(B.dl), newdv, typeof(newdv)(B.du)) +end + +function (-)(A::Diagonal, B::Tridiagonal) + newdv = A.diag-B.d + Tridiagonal(typeof(newdv)(-B.dl), newdv, typeof(newdv)(-B.du)) +end + +function (+)(A::Tridiagonal, B::Diagonal) + newdv = A.d+B.diag + Tridiagonal(typeof(newdv)(A.dl), newdv, typeof(newdv)(A.du)) +end + +function (-)(A::Tridiagonal, B::Diagonal) + newdv = A.d-B.diag + Tridiagonal(typeof(newdv)(A.dl), newdv, typeof(newdv)(A.du)) +end + +function (+)(A::Bidiagonal, B::Tridiagonal) + newdv = A.dv+B.d + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.dl), newdv, A.ev+B.du) : (A.ev+B.dl, newdv, typeof(newdv)(B.du)))...) +end + +function (-)(A::Bidiagonal, B::Tridiagonal) + newdv = A.dv-B.d + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.dl), newdv, A.ev-B.du) : (A.ev-B.dl, newdv, typeof(newdv)(-B.du)))...) +end + +function (+)(A::Tridiagonal, B::Bidiagonal) + newdv = A.d+B.dv + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.dl), newdv, A.du+B.ev) : (A.dl+B.ev, newdv, typeof(newdv)(A.du)))...) +end + +function (-)(A::Tridiagonal, B::Bidiagonal) + newdv = A.d-B.dv + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.dl), newdv, A.du-B.ev) : (A.dl-B.ev, newdv, typeof(newdv)(A.du)))...) +end + +function (+)(A::Bidiagonal, B::SymTridiagonal) + newdv = A.dv+B.dv + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), A.dv+B.dv, A.ev+B.ev) : (A.ev+B.ev, A.dv+B.dv, typeof(newdv)(B.ev)))...) +end + +function (-)(A::Bidiagonal, B::SymTridiagonal) + newdv = A.dv-B.dv + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(-B.ev)))...) +end + +function (+)(A::SymTridiagonal, B::Bidiagonal) + newdv = A.dv+B.dv + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev+B.ev) : (A.ev+B.ev, newdv, typeof(newdv)(A.ev)))...) +end + +function (-)(A::SymTridiagonal, B::Bidiagonal) + newdv = A.dv-B.dv + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(A.ev)))...) +end + +# fixing uniform scaling problems from #28994 +# {<:Number} is required due to the test case from PR #27289 where eltype is a matrix. + +function (+)(A::Tridiagonal{<:Number}, B::UniformScaling) + newd = A.d .+ B.λ + Tridiagonal(typeof(newd)(A.dl), newd, typeof(newd)(A.du)) +end + +function (+)(A::SymTridiagonal{<:Number}, B::UniformScaling) + newdv = A.dv .+ B.λ + SymTridiagonal(newdv, typeof(newdv)(A.ev)) +end + +function (+)(A::Bidiagonal{<:Number}, B::UniformScaling) + newdv = A.dv .+ B.λ + Bidiagonal(newdv, typeof(newdv)(A.ev), A.uplo) +end + +function (+)(A::Diagonal{<:Number}, B::UniformScaling) + Diagonal(A.diag .+ B.λ) +end + +function (+)(A::UniformScaling, B::Tridiagonal{<:Number}) + newd = A.λ .+ B.d + Tridiagonal(typeof(newd)(B.dl), newd, typeof(newd)(B.du)) +end + +function (+)(A::UniformScaling, B::SymTridiagonal{<:Number}) + newdv = A.λ .+ B.dv + SymTridiagonal(newdv, typeof(newdv)(B.ev)) +end + +function (+)(A::UniformScaling, B::Bidiagonal{<:Number}) + newdv = A.λ .+ B.dv + Bidiagonal(newdv, typeof(newdv)(B.ev), B.uplo) +end + +function (+)(A::UniformScaling, B::Diagonal{<:Number}) + Diagonal(A.λ .+ B.diag) +end + +function (-)(A::UniformScaling, B::Tridiagonal{<:Number}) + newd = A.λ .- B.d + Tridiagonal(typeof(newd)(-B.dl), newd, typeof(newd)(-B.du)) +end + +function (-)(A::UniformScaling, B::SymTridiagonal{<:Number}) + newdv = A.λ .- B.dv + SymTridiagonal(newdv, typeof(newdv)(-B.ev)) +end + +function (-)(A::UniformScaling, B::Bidiagonal{<:Number}) + newdv = A.λ .- B.dv + Bidiagonal(newdv, typeof(newdv)(-B.ev), B.uplo) +end + +function (-)(A::UniformScaling, B::Diagonal{<:Number}) + Diagonal(A.λ .- B.diag) end rmul!(A::AbstractTriangular, adjB::Adjoint{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) = diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index c8933618ec2ba..5627cbb332d5e 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -1599,7 +1599,6 @@ rdiv!(A::LowerTriangular, transB::Transpose{<:Any,<:Union{UpperTriangular,UnitUp ## Some Triangular-Triangular cases. We might want to write tailored methods ## for these cases, but I'm not sure it is worth it. -(*)(A::Union{Tridiagonal,SymTridiagonal}, B::AbstractTriangular) = rmul!(Matrix(A), B) for (f, f2!) in ((:*, :lmul!), (:\, :ldiv!)) @eval begin diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index 2092122c33458..90cfe24da894a 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -109,6 +109,52 @@ end @test Matrix(convert(Spectype,A) - D) ≈ Matrix(A - D) end end + + UpTri = UpperTriangular(rand(20,20)) + LoTri = LowerTriangular(rand(20,20)) + Diag = Diagonal(rand(20,20)) + Tridiag = Tridiagonal(rand(20, 20)) + UpBi = Bidiagonal(rand(20,20), :U) + LoBi = Bidiagonal(rand(20,20), :L) + Sym = SymTridiagonal(rand(20), rand(19)) + Dense = rand(20, 20) + mats = [UpTri, LoTri, Diag, Tridiag, UpBi, LoBi, Sym, Dense] + + for op in (+,-,*) + for A in mats + for B in mats + @test (op)(A, B) ≈ (op)(Matrix(A), Matrix(B)) ≈ Matrix((op)(A, B)) + end + end + end +end + +@testset "+ and - among structured matrices with different container types" begin + diag = 1:5 + offdiag = 1:4 + uniformscalingmats = [UniformScaling(3), UniformScaling(1.0), UniformScaling(3//5), UniformScaling(Complex{Float64}(1.3, 3.5))] + mats = [Diagonal(diag), Bidiagonal(diag, offdiag, 'U'), Bidiagonal(diag, offdiag, 'L'), Tridiagonal(offdiag, diag, offdiag), SymTridiagonal(diag, offdiag)] + for T in [ComplexF64, Int64, Rational{Int64}, Float64] + push!(mats, Diagonal(Vector{T}(diag))) + push!(mats, Bidiagonal(Vector{T}(diag), Vector{T}(offdiag), 'U')) + push!(mats, Bidiagonal(Vector{T}(diag), Vector{T}(offdiag), 'L')) + push!(mats, Tridiagonal(Vector{T}(offdiag), Vector{T}(diag), Vector{T}(offdiag))) + push!(mats, SymTridiagonal(Vector{T}(diag), Vector{T}(offdiag))) + end + + for op in (+,*) # to do: fix when operation is - and the matrix has a range as the underlying representation and we get a step size of 0. + for A in mats + for B in mats + @test (op)(A, B) ≈ (op)(Matrix(A), Matrix(B)) ≈ Matrix((op)(A, B)) + end + end + + for A in mats + for B in uniformscalingmats + @test (op)(A, B) ≈ (op)(Matrix(A), B) ≈ Matrix((op)(A, B)) + end + end + end end @testset "Triangular Types and QR" begin diff --git a/stdlib/SparseArrays/src/SparseArrays.jl b/stdlib/SparseArrays/src/SparseArrays.jl index 724b9865782a6..81f28b150cd4b 100644 --- a/stdlib/SparseArrays/src/SparseArrays.jl +++ b/stdlib/SparseArrays/src/SparseArrays.jl @@ -51,4 +51,11 @@ similar(D::Diagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzero similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) +matprod(x, y) = x*y + x*y +const BiTriSym = Union{Bidiagonal,SymTridiagonal,Tridiagonal} +function *(A::BiTriSym, B::BiTriSym) + TS = promote_op(matprod, eltype(A), eltype(B)) + mul!(similar(A, TS, size(A)...), A, B) +end + end