Skip to content

Commit

Permalink
fully fix v1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 31, 2023
1 parent 0709d5d commit 73dfa2c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 45 deletions.
46 changes: 10 additions & 36 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ EnumX.@enumx DefaultAlgorithmChoice begin
UMFPACKFactorization
KrylovJL_GMRES
GenericLUFactorization
RowMaximumGenericLUFactorization
RFLUFactorization
LDLtFactorization
BunchKaufmanFactorization
Expand All @@ -23,7 +22,7 @@ struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
end

mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17}
T13, T14, T15, T16}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -33,14 +32,13 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
UMFPACKFactorization::T7
KrylovJL_GMRES::T8
GenericLUFactorization::T9
RowMaximumGenericLUFactorization::T10
RFLUFactorization::T11
LDLtFactorization::T12
BunchKaufmanFactorization::T13
CHOLMODFactorization::T14
SVDFactorization::T15
CholeskyFactorization::T16
NormalCholeskyFactorization::T17
RFLUFactorization::T10
LDLtFactorization::T11
BunchKaufmanFactorization::T12
CHOLMODFactorization::T13
SVDFactorization::T14
CholeskyFactorization::T15
NormalCholeskyFactorization::T16
end

# Legacy fallback
Expand Down Expand Up @@ -182,23 +180,15 @@ function defaultalg(A, b, assump::OperatorAssumptions)
(__conditioning(assump) === OperatorCondition.IllConditioned ||
__conditioning(assump) === OperatorCondition.WellConditioned)
if length(b) <= 10
if __conditioning(assump) === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.RowMaximumGenericLUFactorization
else
DefaultAlgorithmChoice.GenericLUFactorization
end
DefaultAlgorithmChoice.GenericLUFactorization
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
else
if __conditioning(assump) === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.RowMaximumGenericLUFactorization
else
DefaultAlgorithmChoice.GenericLUFactorization
end
DefaultAlgorithmChoice.GenericLUFactorization
end
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
Expand Down Expand Up @@ -234,12 +224,6 @@ end
function algchoice_to_alg(alg::Symbol)
if alg === :SVDFactorization
SVDFactorization(false, LinearAlgebra.QRIteration())
elseif alg === :RowMaximumGenericLUFactorization
@static if VERSION < v"1.7beta"
GenericLUFactorization(Val(true))
else
GenericLUFactorization(RowMaximum())
end
elseif alg === :LDLtFactorization
LDLtFactorization()
elseif alg === :LUFactorization
Expand Down Expand Up @@ -361,16 +345,6 @@ end
function defaultalg_symbol(::Type{T}) where {T}
Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end])
end

@static if VERSION < v"1.7beta"
function defaultalg_symbol(::Type{<:GenericLUFactorization{false}})
:RowMaximumGenericLUFactorization
end
else
function defaultalg_symbol(::Type{<:GenericLUFactorization{LinearAlgebra.RowMaximum}})
:RowMaximumGenericLUFactorization
end
end
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization

"""
Expand Down
64 changes: 63 additions & 1 deletion src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
nothing
end

@static if VERSION < v"1.7-"
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
A::Union{Diagonal,SymTridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
nothing
end
end

## QRFactorization

struct QRFactorization{P} <: AbstractFactorization
Expand Down Expand Up @@ -149,6 +158,15 @@ function init_cacheval(alg::QRFactorization, A::AbstractSciMLOperator, b, u, Pl,
nothing
end

@static if VERSION < v"1.7-"
function init_cacheval(alg::QRFactorization,
A::Union{Diagonal,SymTridiagonal,Tridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
nothing
end
end

## CholeskyFactorization

struct CholeskyFactorization{P, P2} <: AbstractFactorization
Expand Down Expand Up @@ -213,6 +231,14 @@ function init_cacheval(alg::CholeskyFactorization, A::AbstractSciMLOperator, b,
nothing
end

@static if VERSION < v"1.7beta"
function init_cacheval(alg::CholeskyFactorization, A::Union{SymTridiagonal,Tridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
nothing
end
end

## LDLtFactorization

struct LDLtFactorization{T} <: AbstractFactorization
Expand Down Expand Up @@ -281,6 +307,15 @@ function init_cacheval(alg::SVDFactorization, A, b, u, Pl, Pr,
nothing
end

@static if VERSION < v"1.7-"
function init_cacheval(alg::SVDFactorization,
A::Union{Diagonal,SymTridiagonal,Tridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
nothing
end
end

## BunchKaufmanFactorization

Base.@kwdef struct BunchKaufmanFactorization <: AbstractFactorization
Expand Down Expand Up @@ -744,24 +779,42 @@ end

function init_cacheval(alg::RFLUFactorization, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
@show size(A), length(b)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv
end

function init_cacheval(alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
PREALLOCATED_LU, ipiv
end

function init_cacheval(alg::RFLUFactorization,
A::Union{AbstractSparseArray, AbstractSciMLOperator}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end

@static if VERSION < v"1.7-"
function init_cacheval(alg::RFLUFactorization,
A::Union{Diagonal,SymTridiagonal,Tridiagonal}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end
end

function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
kwargs...) where {P, T}
A = cache.A
A = convert(AbstractMatrix, A)
fact, ipiv = get_cacheval(cache, :RFLUFactorization)
if cache.isfresh
if length(ipiv) != min(size(A)...)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T))
cache.cacheval = (fact, ipiv)
cache.isfresh = false
Expand Down Expand Up @@ -820,6 +873,15 @@ function init_cacheval(alg::NormalCholeskyFactorization,
nothing
end

@static if VERSION < v"1.7-"
function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{Tridiagonal, SymTridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
nothing
end
end

function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
kwargs...)
A = cache.A
Expand Down
18 changes: 10 additions & 8 deletions test/default_algs.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using LinearSolve, LinearAlgebra, SparseArrays, Test
@test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
LinearSolve.DefaultAlgorithmChoice.RowMaximumGenericLUFactorization
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
@test LinearSolve.defaultalg(nothing, zeros(50)).alg ===
LinearSolve.DefaultAlgorithmChoice.RFLUFactorization
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.RowMaximumGenericLUFactorization
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
@test LinearSolve.defaultalg(LinearAlgebra.Diagonal(zeros(5)), zeros(5)).alg ===
LinearSolve.DefaultAlgorithmChoice.DiagonalFactorization

Expand All @@ -17,9 +17,11 @@ using LinearSolve, LinearAlgebra, SparseArrays, Test
@test LinearSolve.defaultalg(sprand(11000, 11000, 0.001), zeros(11000)).alg ===
LinearSolve.DefaultAlgorithmChoice.UMFPACKFactorization

# Test inference
A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
@inferred solve(prob)
@inferred init(prob, nothing)
@static if VERSION >= v"v1.7-"
# Test inference
A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
@inferred solve(prob)
@inferred init(prob, nothing)
end

0 comments on commit 73dfa2c

Please sign in to comment.