Skip to content

Commit

Permalink
refactored to purge name space of SparseArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
KlausC committed Aug 8, 2018
1 parent 70c8b73 commit 4307da6
Showing 1 changed file with 88 additions and 93 deletions.
181 changes: 88 additions & 93 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,11 @@ function dot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S
end

## triangular sparse handling
abstract type UnitDiagonal end
struct UnitDiagonalYes <:UnitDiagonal end
struct UnitDiagonalNo <:UnitDiagonal end
abstract type AdjointElement end
struct AdjointElementYes <:AdjointElement end
struct AdjointElementNo <:AdjointElement end

possible_adjoint(adj::AdjointElementYes, a ) = adjoint(a)
possible_adjoint(adj::AdjointElementNo, a ) = a
AdjointElement(::Adjoint) = AdjointElementYes()
AdjointElement(::Any) = AdjointElementNo()
UnitDiagonal(::UnitUpperTriangular) = UnitDiagonalYes()
UnitDiagonal(::UnitLowerTriangular) = UnitDiagonalYes()
UnitDiagonal(::Any) = UnitDiagonalNo()

possible_adjoint(adj::Bool, a::Real ) = a
possible_adjoint(adj::Bool, a ) = adj ? adjoint(a) : a

const UnitDiagonalTriangular = Union{UnitUpperTriangular,UnitLowerTriangular}

const LowerTriangularPlain{T} = Union{
LowerTriangular{T,<:SparseMatrixCSCUnion{T}},
Expand Down Expand Up @@ -280,8 +271,21 @@ const TriangularSparse{T} = Union{
LowerTriangularSparse{T}, UpperTriangularSparse{T}} where T

## triangular multipliers
function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
# forward substitution for UpperTriangular SparseCSC matrices
function lmul!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
@assert !has_offset_axes(A, B)
nrowB, ncolB = size(B, 1), size(B, 2)
ncol = LinearAlgebra.checksquare(A)
if nrowB != ncol
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
end
_lmul!(A, B)
end

# forward multiplication for UpperTriangular SparseCSC matrices
function _lmul!(U::UpperTriangularPlain, B::StridedVecOrMat)
A = U.data
unit = U isa UnitDiagonalTriangular

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -292,7 +296,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
for j = 1:nrowB
i1 = ia[j]
i2 = ia[j + 1] - 1
done = unit isa UnitDiagonalYes
done = unit

bj = B[joff + j]
for ii = i1:i2
Expand All @@ -301,7 +305,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
if jai < j
B[joff + jai] += aii * bj
elseif jai == j
if unit isa UnitDiagonalNo
if !unit
B[joff + j] *= aii
done = true
end
Expand All @@ -318,8 +322,11 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
B
end

function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
# backward substitution for LowerTriangular SparseCSC matrices
# backward multiplication for LowerTriangular SparseCSC matrices
function _lmul!(L::LowerTriangularPlain, B::StridedVecOrMat)
A = L.data
unit = L isa UnitDiagonalTriangular

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -330,7 +337,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
for j = nrowB:-1:1
i1 = ia[j]
i2 = ia[j + 1] - 1
done = unit isa UnitDiagonalYes
done = unit

bj = B[joff + j]
for ii = i2:-1:i1
Expand All @@ -339,7 +346,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
if jai > j
B[joff + jai] += aii * bj
elseif jai == j
if unit isa UnitDiagonalNo
if !unit
B[joff + j] *= aii
done = true
end
Expand All @@ -356,8 +363,12 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
B
end

function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
# forward substitution for adjoint and transpose of LowerTriangular CSC matrices
# forward multiplication for adjoint and transpose of LowerTriangular CSC matrices
function _lmul!(U::UpperTriangularWrapped, B::StridedVecOrMat)
A = U.parent.data
unit = U.parent isa UnitDiagonalTriangular
adj = U isa Adjoint

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -370,7 +381,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
i1 = ia[j]
i2 = ia[j + 1] - 1
akku = Z
j0 = unit isa UnitDiagonalNo ? j : j + 1
j0 = !unit ? j : j + 1

# loop through column j of A - only structural non-zeros
for ii = i2:-1:i1
Expand All @@ -382,7 +393,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
break
end
end
if unit isa UnitDiagonalYes
if unit
akku += B[joff + j]
end
B[joff + j] = akku
Expand All @@ -392,8 +403,12 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
B
end

function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
# multiply with adjoint and transpose of LowerTriangular CSC matrices
# backward multiplication with adjoint and transpose of LowerTriangular CSC matrices
function _lmul!(L::LowerTriangularWrapped, B::StridedVecOrMat)
A = L.parent.data
unit = L.parent isa UnitDiagonalTriangular
adj = L isa Adjoint

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -406,7 +421,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
i1 = ia[j]
i2 = ia[j + 1] - 1
akku = Z
j0 = unit isa UnitDiagonalNo ? j : j - 1
j0 = !unit ? j : j - 1

# loop through column j of A - only structural non-zeros
for ii = i1:i2
Expand All @@ -418,7 +433,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
break
end
end
if unit isa UnitDiagonalYes
if unit
akku += B[joff + j]
end
B[joff + j] = akku
Expand All @@ -428,31 +443,22 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
B
end

function lmul!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
## triangular solvers
function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
@assert !has_offset_axes(A, B)
nrowB, ncolB = size(B, 1), size(B, 2)
ncol = LinearAlgebra.checksquare(A)
if nrowB != ncol
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
end
_lmul!(A, B)
_ldiv!(A, B)
end

_lmul!(L::LowerTriangularPlain, B::StridedVecOrMat) =
bwdTriMul!(L.data, B, UnitDiagonal(L))

_lmul!(L::LowerTriangularWrapped, B::StridedVecOrMat) =
_bwdTriMul!(L.parent.data, B, UnitDiagonal(L.parent), AdjointElement(L))

_lmul!(U::UpperTriangularPlain, B::StridedVecOrMat) =
fwdTriMul!(U.data, B, UnitDiagonal(U))

_lmul!(U::UpperTriangularWrapped, B::StridedVecOrMat) =
_fwdTriMul!(U.parent.data, B, UnitDiagonal(U.parent), AdjointElement(U))

## triangular solvers
function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
# forward substitution for LowerTriangular CSC matrices
function _ldiv!(L::LowerTriangularPlain, B::StridedVecOrMat)
A = L.data
unit = L isa UnitDiagonalTriangular

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -472,12 +478,12 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
bj = B[joff + j]
# check for zero pivot and divide with pivot
if jai == j
if unit isa UnitDiagonalNo
if !unit
bj /= aa[ii]
B[joff + j] = bj
end
ii += 1
elseif unit isa UnitDiagonalNo
elseif !unit
throw(LinearAlgebra.SingularException(j))
end

Expand All @@ -491,8 +497,11 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
B
end

function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
# backward substitution for UpperTriangular CSC matrices
function _ldiv!(U::UpperTriangularPlain, B::StridedVecOrMat)
A = U.data
unit = U isa UnitDiagonalTriangular

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -512,12 +521,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
bj = B[joff + j]
# check for zero pivot and divide with pivot
if jai == j
if unit isa UnitDiagonalNo
if !unit
bj /= aa[ii]
B[joff + j] = bj
end
ii -= 1
elseif unit isa UnitDiagonalNo
elseif !unit
throw(LinearAlgebra.SingularException(j))
end

Expand All @@ -531,8 +540,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
B
end

function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
# forward substitution for adjoint and transpose of UpperTriangular CSC matrices
function _ldiv!(L::LowerTriangularWrapped, B::StridedVecOrMat)
A = L.parent.data
unit = L.parent isa UnitDiagonalTriangular
adj = L isa Adjoint

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -547,14 +560,14 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
done = false

# loop through column j of A - only structural non-zeros
for ip = i1:i2
i = ja[ip]
if i < j
aai = possible_adjoint(adj, aa[ip])
akku -= B[joff + i] * aai
elseif i == j
if unit isa UnitDiagonalNo
aai = possible_adjoint(adj, aa[ip])
for ii = i1:i2
jai = ja[ii]
if jai < j
aai = possible_adjoint(adj, aa[ii])
akku -= B[joff + jai] * aai
elseif jai == j
if !unit
aai = possible_adjoint(adj, aa[ii])
akku /= aai
end
done = true
Expand All @@ -563,7 +576,7 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
break
end
end
if !done && unit isa UnitDiagonalNo
if !done && !unit
throw(LinearAlgebra.SingularException(j))
end
B[joff + j] = akku
Expand All @@ -573,8 +586,12 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
B
end

function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
# backward substitution for adjoint and transpose of LowerTriangular CSC matrices
function _ldiv!(U::UpperTriangularWrapped, B::StridedVecOrMat)
A = U.parent.data
unit = U.parent isa UnitDiagonalTriangular
adj = U isa Adjoint

nrowB, ncolB = size(B, 1), size(B, 2)
aa = getnzval(A)
ja = getrowval(A)
Expand All @@ -589,14 +606,14 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
done = false

# loop through column j of A - only structural non-zeros
for ip = i2:-1:i1
i = ja[ip]
if i > j
aai = possible_adjoint(adj, aa[ip])
akku -= B[joff + i] * aai
elseif i == j
if unit isa UnitDiagonalNo
aai = possible_adjoint(adj, aa[ip])
for ii = i2:-1:i1
jai = ja[ii]
if jai > j
aai = possible_adjoint(adj, aa[ii])
akku -= B[joff + jai] * aai
elseif jai == j
if !unit
aai = possible_adjoint(adj, aa[ii])
akku /= aai
end
done = true
Expand All @@ -605,7 +622,7 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
break
end
end
if !done && unit isa UnitDiagonalNo
if !done && !unit
throw(LinearAlgebra.SingularException(j))
end
B[joff + j] = akku
Expand All @@ -615,28 +632,6 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
B
end

function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
@assert !has_offset_axes(A, B)
nrowB, ncolB = size(B, 1), size(B, 2)
ncol = LinearAlgebra.checksquare(A)
if nrowB != ncol
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
end
_ldiv!(A, B)
end

_ldiv!(L::LowerTriangularPlain, B::StridedVecOrMat) =
fwdTriSolve!(L.data, B, UnitDiagonal(L))

_ldiv!(L::LowerTriangularWrapped, B::StridedVecOrMat) =
_fwdTriSolve!(L.parent.data, B, UnitDiagonal(L.parent), AdjointElement(L))

_ldiv!(U::UpperTriangularPlain, B::StridedVecOrMat) =
bwdTriSolve!(U.data, B, UnitDiagonal(U))

_ldiv!(U::UpperTriangularWrapped, B::StridedVecOrMat) =
_bwdTriSolve!(U.parent.data, B, UnitDiagonal(U.parent), AdjointElement(U))

(\)(L::TriangularSparse, B::SparseMatrixCSC) = ldiv!(L, Array(B))
(*)(L::TriangularSparse, B::SparseMatrixCSC) = lmul!(L, Array(B))

Expand Down

0 comments on commit 4307da6

Please sign in to comment.