diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index b470a8d92789c..86ab4b78f50f0 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -235,20 +235,11 @@ function dot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S end ## triangular sparse handling -abstract type UnitDiagonal end -struct UnitDiagonalYes <:UnitDiagonal end -struct UnitDiagonalNo <:UnitDiagonal end -abstract type AdjointElement end -struct AdjointElementYes <:AdjointElement end -struct AdjointElementNo <:AdjointElement end - -possible_adjoint(adj::AdjointElementYes, a ) = adjoint(a) -possible_adjoint(adj::AdjointElementNo, a ) = a -AdjointElement(::Adjoint) = AdjointElementYes() -AdjointElement(::Any) = AdjointElementNo() -UnitDiagonal(::UnitUpperTriangular) = UnitDiagonalYes() -UnitDiagonal(::UnitLowerTriangular) = UnitDiagonalYes() -UnitDiagonal(::Any) = UnitDiagonalNo() + +possible_adjoint(adj::Bool, a::Real ) = a +possible_adjoint(adj::Bool, a ) = adj ? adjoint(a) : a + +const UnitDiagonalTriangular = Union{UnitUpperTriangular,UnitLowerTriangular} const LowerTriangularPlain{T} = Union{ LowerTriangular{T,<:SparseMatrixCSCUnion{T}}, @@ -280,8 +271,21 @@ const TriangularSparse{T} = Union{ LowerTriangularSparse{T}, UpperTriangularSparse{T}} where T ## triangular multipliers -function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal) -# forward substitution for UpperTriangular SparseCSC matrices +function lmul!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T + @assert !has_offset_axes(A, B) + nrowB, ncolB = size(B, 1), size(B, 2) + ncol = LinearAlgebra.checksquare(A) + if nrowB != ncol + throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows")) + end + _lmul!(A, B) +end + +# forward multiplication for UpperTriangular SparseCSC matrices +function _lmul!(U::UpperTriangularPlain, B::StridedVecOrMat) + A = U.data + unit = U isa UnitDiagonalTriangular + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -292,7 +296,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag for j = 1:nrowB i1 = ia[j] i2 = ia[j + 1] - 1 - done = unit isa UnitDiagonalYes + done = unit bj = B[joff + j] for ii = i1:i2 @@ -301,7 +305,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag if jai < j B[joff + jai] += aii * bj elseif jai == j - if unit isa UnitDiagonalNo + if !unit B[joff + j] *= aii done = true end @@ -318,8 +322,11 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag B end -function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal) -# backward substitution for LowerTriangular SparseCSC matrices +# backward multiplication for LowerTriangular SparseCSC matrices +function _lmul!(L::LowerTriangularPlain, B::StridedVecOrMat) + A = L.data + unit = L isa UnitDiagonalTriangular + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -330,7 +337,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag for j = nrowB:-1:1 i1 = ia[j] i2 = ia[j + 1] - 1 - done = unit isa UnitDiagonalYes + done = unit bj = B[joff + j] for ii = i2:-1:i1 @@ -339,7 +346,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag if jai > j B[joff + jai] += aii * bj elseif jai == j - if unit isa UnitDiagonalNo + if !unit B[joff + j] *= aii done = true end @@ -356,8 +363,12 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag B end -function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement) -# forward substitution for adjoint and transpose of LowerTriangular CSC matrices +# forward multiplication for adjoint and transpose of LowerTriangular CSC matrices +function _lmul!(U::UpperTriangularWrapped, B::StridedVecOrMat) + A = U.parent.data + unit = U.parent isa UnitDiagonalTriangular + adj = U isa Adjoint + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -370,7 +381,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia i1 = ia[j] i2 = ia[j + 1] - 1 akku = Z - j0 = unit isa UnitDiagonalNo ? j : j + 1 + j0 = !unit ? j : j + 1 # loop through column j of A - only structural non-zeros for ii = i2:-1:i1 @@ -382,7 +393,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia break end end - if unit isa UnitDiagonalYes + if unit akku += B[joff + j] end B[joff + j] = akku @@ -392,8 +403,12 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia B end -function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement) -# multiply with adjoint and transpose of LowerTriangular CSC matrices +# backward multiplication with adjoint and transpose of LowerTriangular CSC matrices +function _lmul!(L::LowerTriangularWrapped, B::StridedVecOrMat) + A = L.parent.data + unit = L.parent isa UnitDiagonalTriangular + adj = L isa Adjoint + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -406,7 +421,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia i1 = ia[j] i2 = ia[j + 1] - 1 akku = Z - j0 = unit isa UnitDiagonalNo ? j : j - 1 + j0 = !unit ? j : j - 1 # loop through column j of A - only structural non-zeros for ii = i1:i2 @@ -418,7 +433,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia break end end - if unit isa UnitDiagonalYes + if unit akku += B[joff + j] end B[joff + j] = akku @@ -428,31 +443,22 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia B end -function lmul!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T +## triangular solvers +function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T @assert !has_offset_axes(A, B) nrowB, ncolB = size(B, 1), size(B, 2) ncol = LinearAlgebra.checksquare(A) if nrowB != ncol throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows")) end - _lmul!(A, B) + _ldiv!(A, B) end -_lmul!(L::LowerTriangularPlain, B::StridedVecOrMat) = - bwdTriMul!(L.data, B, UnitDiagonal(L)) - -_lmul!(L::LowerTriangularWrapped, B::StridedVecOrMat) = - _bwdTriMul!(L.parent.data, B, UnitDiagonal(L.parent), AdjointElement(L)) - -_lmul!(U::UpperTriangularPlain, B::StridedVecOrMat) = - fwdTriMul!(U.data, B, UnitDiagonal(U)) - -_lmul!(U::UpperTriangularWrapped, B::StridedVecOrMat) = - _fwdTriMul!(U.parent.data, B, UnitDiagonal(U.parent), AdjointElement(U)) - -## triangular solvers -function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal) # forward substitution for LowerTriangular CSC matrices +function _ldiv!(L::LowerTriangularPlain, B::StridedVecOrMat) + A = L.data + unit = L isa UnitDiagonalTriangular + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -472,12 +478,12 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi bj = B[joff + j] # check for zero pivot and divide with pivot if jai == j - if unit isa UnitDiagonalNo + if !unit bj /= aa[ii] B[joff + j] = bj end ii += 1 - elseif unit isa UnitDiagonalNo + elseif !unit throw(LinearAlgebra.SingularException(j)) end @@ -491,8 +497,11 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi B end -function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal) # backward substitution for UpperTriangular CSC matrices +function _ldiv!(U::UpperTriangularPlain, B::StridedVecOrMat) + A = U.data + unit = U isa UnitDiagonalTriangular + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -512,12 +521,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi bj = B[joff + j] # check for zero pivot and divide with pivot if jai == j - if unit isa UnitDiagonalNo + if !unit bj /= aa[ii] B[joff + j] = bj end ii -= 1 - elseif unit isa UnitDiagonalNo + elseif !unit throw(LinearAlgebra.SingularException(j)) end @@ -531,8 +540,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi B end -function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement) # forward substitution for adjoint and transpose of UpperTriangular CSC matrices +function _ldiv!(L::LowerTriangularWrapped, B::StridedVecOrMat) + A = L.parent.data + unit = L.parent isa UnitDiagonalTriangular + adj = L isa Adjoint + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -547,14 +560,14 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD done = false # loop through column j of A - only structural non-zeros - for ip = i1:i2 - i = ja[ip] - if i < j - aai = possible_adjoint(adj, aa[ip]) - akku -= B[joff + i] * aai - elseif i == j - if unit isa UnitDiagonalNo - aai = possible_adjoint(adj, aa[ip]) + for ii = i1:i2 + jai = ja[ii] + if jai < j + aai = possible_adjoint(adj, aa[ii]) + akku -= B[joff + jai] * aai + elseif jai == j + if !unit + aai = possible_adjoint(adj, aa[ii]) akku /= aai end done = true @@ -563,7 +576,7 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD break end end - if !done && unit isa UnitDiagonalNo + if !done && !unit throw(LinearAlgebra.SingularException(j)) end B[joff + j] = akku @@ -573,8 +586,12 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD B end -function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement) # backward substitution for adjoint and transpose of LowerTriangular CSC matrices +function _ldiv!(U::UpperTriangularWrapped, B::StridedVecOrMat) + A = U.parent.data + unit = U.parent isa UnitDiagonalTriangular + adj = U isa Adjoint + nrowB, ncolB = size(B, 1), size(B, 2) aa = getnzval(A) ja = getrowval(A) @@ -589,14 +606,14 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD done = false # loop through column j of A - only structural non-zeros - for ip = i2:-1:i1 - i = ja[ip] - if i > j - aai = possible_adjoint(adj, aa[ip]) - akku -= B[joff + i] * aai - elseif i == j - if unit isa UnitDiagonalNo - aai = possible_adjoint(adj, aa[ip]) + for ii = i2:-1:i1 + jai = ja[ii] + if jai > j + aai = possible_adjoint(adj, aa[ii]) + akku -= B[joff + jai] * aai + elseif jai == j + if !unit + aai = possible_adjoint(adj, aa[ii]) akku /= aai end done = true @@ -605,7 +622,7 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD break end end - if !done && unit isa UnitDiagonalNo + if !done && !unit throw(LinearAlgebra.SingularException(j)) end B[joff + j] = akku @@ -615,28 +632,6 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD B end -function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T - @assert !has_offset_axes(A, B) - nrowB, ncolB = size(B, 1), size(B, 2) - ncol = LinearAlgebra.checksquare(A) - if nrowB != ncol - throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows")) - end - _ldiv!(A, B) -end - -_ldiv!(L::LowerTriangularPlain, B::StridedVecOrMat) = - fwdTriSolve!(L.data, B, UnitDiagonal(L)) - -_ldiv!(L::LowerTriangularWrapped, B::StridedVecOrMat) = - _fwdTriSolve!(L.parent.data, B, UnitDiagonal(L.parent), AdjointElement(L)) - -_ldiv!(U::UpperTriangularPlain, B::StridedVecOrMat) = - bwdTriSolve!(U.data, B, UnitDiagonal(U)) - -_ldiv!(U::UpperTriangularWrapped, B::StridedVecOrMat) = - _bwdTriSolve!(U.parent.data, B, UnitDiagonal(U.parent), AdjointElement(U)) - (\)(L::TriangularSparse, B::SparseMatrixCSC) = ldiv!(L, Array(B)) (*)(L::TriangularSparse, B::SparseMatrixCSC) = lmul!(L, Array(B))