Skip to content

Commit

Permalink
Merge pull request #526 from j-fu/abstractsparsefactorization
Browse files Browse the repository at this point in the history
Introduce AbstractSparseFactorization and AbstractDenseFactorization
  • Loading branch information
ChrisRackauckas authored Aug 19, 2024
2 parents 43fc8d3 + 8646b3d commit 1c30db0
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 20 deletions.
3 changes: 3 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ using SciMLBase: _unwrap_val

abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
abstract type AbstractSparseFactorization <: AbstractFactorization end
abstract type AbstractDenseFactorization <: AbstractFactorization end
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end

# Traits

needs_concrete_A(alg::AbstractFactorization) = true
needs_concrete_A(alg::AbstractSparseFactorization) = true
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

Expand Down
3 changes: 3 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ default_alias_b(::Any, ::Any, ::Any) = false
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true

default_alias_A(::AbstractSparseFactorization, ::Any, ::Any) = true
default_alias_b(::AbstractSparseFactorization, ::Any, ::Any) = true

DEFAULT_PRECS(A, p) = IdentityOperator(size(A)[1]), IdentityOperator(size(A)[2])

function __init_u0_from_Ab(A, b)
Expand Down
39 changes: 20 additions & 19 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Julia's built in `lu`. Equivalent to calling `lu!(A)`
- pivot: The choice of pivoting. Defaults to `LinearAlgebra.RowMaximum()`. The other choice is
`LinearAlgebra.NoPivot()`.
"""
Base.@kwdef struct LUFactorization{P} <: AbstractFactorization
Base.@kwdef struct LUFactorization{P} <: AbstractDenseFactorization
pivot::P = LinearAlgebra.RowMaximum()
reuse_symbolic::Bool = true
check_pattern::Bool = true # Check factorization re-use
Expand All @@ -70,7 +70,7 @@ Has low overhead and is good for small matrices.
- pivot: The choice of pivoting. Defaults to `LinearAlgebra.RowMaximum()`. The other choice is
`LinearAlgebra.NoPivot()`.
"""
struct GenericLUFactorization{P} <: AbstractFactorization
struct GenericLUFactorization{P} <: AbstractDenseFactorization
pivot::P
end

Expand Down Expand Up @@ -177,7 +177,7 @@ Julia's built in `qr`. Equivalent to calling `qr!(A)`.
- On CuMatrix, it will use a CUDA-accelerated QR from CuSolver.
- On BandedMatrix and BlockBandedMatrix, it will use a banded QR.
"""
struct QRFactorization{P} <: AbstractFactorization
struct QRFactorization{P} <: AbstractDenseFactorization
pivot::P
blocksize::Int
inplace::Bool
Expand Down Expand Up @@ -260,7 +260,7 @@ Julia's built in `cholesky`. Equivalent to calling `cholesky!(A)`.
- shift: the shift argument in CHOLMOD. Only used for sparse matrices.
- perm: the perm argument in CHOLMOD. Only used for sparse matrices.
"""
struct CholeskyFactorization{P, P2} <: AbstractFactorization
struct CholeskyFactorization{P, P2} <: AbstractDenseFactorization
pivot::P
tol::Int
shift::Float64
Expand Down Expand Up @@ -319,7 +319,7 @@ end

## LDLtFactorization

struct LDLtFactorization{T} <: AbstractFactorization
struct LDLtFactorization{T} <: AbstractDenseFactorization
shift::Float64
perm::T
end
Expand Down Expand Up @@ -361,7 +361,7 @@ Julia's built in `svd`. Equivalent to `svd!(A)`.
which by default is OpenBLAS but will use MKL if the user does `using MKL` in their
system.
"""
struct SVDFactorization{A} <: AbstractFactorization
struct SVDFactorization{A} <: AbstractDenseFactorization
full::Bool
alg::A
end
Expand Down Expand Up @@ -410,7 +410,7 @@ Only for Symmetric matrices.
- rook: whether to perform rook pivoting. Defaults to false.
"""
Base.@kwdef struct BunchKaufmanFactorization <: AbstractFactorization
Base.@kwdef struct BunchKaufmanFactorization <: AbstractDenseFactorization
rook::Bool = false
end

Expand Down Expand Up @@ -464,7 +464,7 @@ factorization API. Quoting from Base:
- fact_alg: the factorization algorithm to use. Defaults to `LinearAlgebra.factorize`, but can be
swapped to choices like `lu`, `qr`
"""
struct GenericFactorization{F} <: AbstractFactorization
struct GenericFactorization{F} <: AbstractDenseFactorization
fact_alg::F
end

Expand Down Expand Up @@ -781,7 +781,7 @@ patterns with “more structure”.
`A` has the same sparsity pattern as the previous `A`. If this algorithm is to
be used in a context where that assumption does not hold, set `reuse_symbolic=false`.
"""
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
Base.@kwdef struct UMFPACKFactorization <: AbstractSparseFactorization
reuse_symbolic::Bool = true
check_pattern::Bool = true # Check factorization re-use
end
Expand Down Expand Up @@ -860,7 +860,7 @@ A fast sparse LU-factorization which specializes on sparsity patterns with “le
`A` has the same sparsity pattern as the previous `A`. If this algorithm is to
be used in a context where that assumption does not hold, set `reuse_symbolic=false`.
"""
Base.@kwdef struct KLUFactorization <: AbstractFactorization
Base.@kwdef struct KLUFactorization <: AbstractSparseFactorization
reuse_symbolic::Bool = true
check_pattern::Bool = true
end
Expand Down Expand Up @@ -941,7 +941,7 @@ Only supports sparse matrices.
- shift: the shift argument in CHOLMOD.
- perm: the perm argument in CHOLMOD
"""
Base.@kwdef struct CHOLMODFactorization{T} <: AbstractFactorization
Base.@kwdef struct CHOLMODFactorization{T} <: AbstractSparseFactorization
shift::Float64 = 0.0
perm::T = nothing
end
Expand Down Expand Up @@ -993,7 +993,7 @@ implementation, usually outperforming OpenBLAS and MKL for smaller matrices
(<500x500), but currently optimized only for Base `Array` with `Float32` or `Float64`.
Additional optimization for complex matrices is in the works.
"""
struct RFLUFactorization{P, T} <: AbstractFactorization
struct RFLUFactorization{P, T} <: AbstractDenseFactorization
RFLUFactorization(::Val{P}, ::Val{T}) where {P, T} = new{P, T}()
end

Expand Down Expand Up @@ -1064,7 +1064,7 @@ be applied to well-conditioned matrices.
- pivot: Defaults to RowMaximum(), but can be NoPivot()
"""
struct NormalCholeskyFactorization{P} <: AbstractFactorization
struct NormalCholeskyFactorization{P} <: AbstractDenseFactorization
pivot::P
end

Expand Down Expand Up @@ -1152,7 +1152,7 @@ be applied to well-conditioned matrices.
- rook: whether to perform rook pivoting. Defaults to false.
"""
struct NormalBunchKaufmanFactorization <: AbstractFactorization
struct NormalBunchKaufmanFactorization <: AbstractDenseFactorization
rook::Bool
end

Expand Down Expand Up @@ -1189,7 +1189,7 @@ end
A special implementation only for solving `Diagonal` matrices fast.
"""
struct DiagonalFactorization <: AbstractFactorization end
struct DiagonalFactorization <: AbstractDenseFactorization end

function init_cacheval(alg::DiagonalFactorization, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -1225,7 +1225,7 @@ end
The FastLapackInterface.jl version of the LU factorization. Notably,
this version does not allow for choice of pivoting method.
"""
struct FastLUFactorization <: AbstractFactorization end
struct FastLUFactorization <: AbstractDenseFactorization end

function init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand Down Expand Up @@ -1255,7 +1255,7 @@ end
The FastLapackInterface.jl version of the QR factorization.
"""
struct FastQRFactorization{P} <: AbstractFactorization
struct FastQRFactorization{P} <: AbstractDenseFactorization
pivot::P
blocksize::Int
end
Expand Down Expand Up @@ -1329,7 +1329,7 @@ dispatch to route around standard BLAS routines in the case e.g. of arbitrary-pr
floating point numbers or ForwardDiff.Dual.
This e.g. allows for Automatic Differentiation (AD) of a sparse-matrix solve.
"""
Base.@kwdef struct SparspakFactorization <: AbstractFactorization
Base.@kwdef struct SparspakFactorization <: AbstractSparseFactorization
reuse_symbolic::Bool = true
end

Expand Down Expand Up @@ -1388,7 +1388,8 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

for alg in InteractiveUtils.subtypes(AbstractFactorization)
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
InteractiveUtils.subtypes(AbstractSparseFactorization))
@eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down
27 changes: 27 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,36 @@ end
N = 10_000
A = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1))
u0 = ones(size(A, 2))

b = A * u0
B = MySparseMatrixCSC(A)
pr = LinearProblem(B, b)

# test default algorithn
@time "solve MySparseMatrixCSC" u=solve(pr)
@test norm(u - u0, Inf) < 1.0e-13

# test Krylov algorithm with reinit!
pr = LinearProblem(B, b)
solver=KrylovJL_CG()
cache=init(pr,solver,maxiters=1000,reltol=1.0e-10)
u=solve!(cache)
A1 = spdiagm(1 => -ones(N - 1), 0 => fill(100.0, N), -1 => -ones(N - 1))
b1=A1*u0
B1= MySparseMatrixCSC(A1)
@test norm(u - u0, Inf) < 1.0e-8
reinit!(cache; A=B1, b=b1)
u=solve!(cache)
@test norm(u - u0, Inf) < 1.0e-8

# test factorization with reinit!
pr = LinearProblem(B, b)
solver=SparspakFactorization()
cache=init(pr,solver)
u=solve!(cache)
@test norm(u - u0, Inf) < 1.0e-8
reinit!(cache; A=B1, b=b1)
u=solve!(cache)
@test norm(u - u0, Inf) < 1.0e-8

end
3 changes: 2 additions & 1 deletion test/resolve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test
using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization

for alg in subtypes(LinearSolve.AbstractFactorization)
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),InteractiveUtils.subtypes(AbstractSparseFactorization))
@show alg
if !(alg in [
DiagonalFactorization,
Expand Down

0 comments on commit 1c30db0

Please sign in to comment.