Skip to content

Commit

Permalink
Reuse LU Factorization to check for singular matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 22, 2023
1 parent 00852f0 commit d2251ba
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 29 deletions.
3 changes: 3 additions & 0 deletions docs/src/api/nonlinearsolve.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ These are the native solvers of NonlinearSolve.jl.
NewtonRaphson
TrustRegion
PseudoTransient
DFSane
GeneralBroyden
GeneralKlement
```

## Polyalgorithms
Expand Down
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ArrayInterface: restructure
import ForwardDiff

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable, issingular
import ConcreteStructs: @concrete
import EnumX: @enumx
import ForwardDiff: Dual
Expand Down
1 change: 0 additions & 1 deletion src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ See also the implementation in [SimpleNonlinearSolve.jl](https://github.com/SciM
- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
algorithm. Defaults to `1000`.
"""

struct DFSane{T, F} <: AbstractNonlinearSolveAlgorithm
σ_min::T
σ_max::T
Expand Down
63 changes: 43 additions & 20 deletions src/klement.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
"""
GeneralKlement(; max_resets = 5, linsolve = LUFactorization(),
linesearch = LineSearch(), precs = DEFAULT_PRECS)
An implementation of `Klement` with line search, preconditioning and customizable linear
solves.
## Keyword Arguments
- `max_resets`: the maximum number of resets to perform. Defaults to `5`.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
"""
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
max_resets::Int
linsolve
precs
linesearch
singular_tolerance
end

function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing)
function GeneralKlement(; max_resets::Int = 5, linsolve = LUFactorization(),

Check warning on line 30 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L30

Added line #L30 was not covered by tests
linesearch = LineSearch(), precs = DEFAULT_PRECS)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance)
return GeneralKlement(max_resets, linsolve, precs, linesearch)

Check warning on line 33 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
end

@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand All @@ -27,7 +48,6 @@ end
Jᵀ²du
Jdu
resets
singular_tolerance
force_stop
maxiters::Int
internalnorm
Expand Down Expand Up @@ -60,11 +80,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlemen
linsolve_kwargs...)
end

singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) :
eltype(u)(alg.singular_tolerance)

return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,

Check warning on line 83 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L83

Added line #L83 was not covered by tests
J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false,
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end
Expand All @@ -73,21 +90,23 @@ function perform_step!(cache::GeneralKlementCache{true})
@unpack u, fu, f, p, alg, J, linsolve, du = cache
T = eltype(J)

Check warning on line 91 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L89-L91

Added lines #L89 - L91 were not covered by tests

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)

Check warning on line 93 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L93

Added line #L93 was not covered by tests

if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 99 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L95-L99

Added lines #L95 - L99 were not covered by tests
end
fact_done = false
fill!(J, zero(T))
J[diagind(J)] .= T(1)
cache.resets += 1

Check warning on line 104 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L101-L104

Added lines #L101 - L104 were not covered by tests
end

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),

Check warning on line 108 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L108

Added line #L108 was not covered by tests
b = -_vec(fu), linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 110 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L110

Added line #L110 was not covered by tests

# Line Search
Expand All @@ -108,7 +127,8 @@ function perform_step!(cache::GeneralKlementCache{true})
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
mul!(cache.Jdu, J, _vec(du))
cache.fu .= cache.fu2 .- cache.fu
cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
cache.fu .= _restructure(cache.fu,

Check warning on line 130 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L125-L130

Added lines #L125 - L130 were not covered by tests
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(T)))
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
cache.J_cache .*= J
mul!(cache.J_cache2, cache.J_cache, J)
Expand All @@ -123,23 +143,25 @@ function perform_step!(cache::GeneralKlementCache{false})
@unpack fu, f, p, alg, J, linsolve = cache
T = eltype(J)

Check warning on line 144 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L142-L144

Added lines #L142 - L144 were not covered by tests

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)

Check warning on line 146 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L146

Added line #L146 was not covered by tests

if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 152 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L148-L152

Added lines #L148 - L152 were not covered by tests
end
cache.J = __init_identity_jacobian(u, fu)
fact_done = false
cache.J = __init_identity_jacobian(cache.u, fu)
cache.resets += 1

Check warning on line 156 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L154-L156

Added lines #L154 - L156 were not covered by tests
end

# u = u - J \ fu
if linsolve === nothing
cache.du = -fu / cache.J

Check warning on line 161 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L160-L161

Added lines #L160 - L161 were not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),

Check warning on line 163 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L163

Added line #L163 was not covered by tests
b = -_vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 165 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L165

Added line #L165 was not covered by tests
end

Expand All @@ -161,7 +183,8 @@ function perform_step!(cache::GeneralKlementCache{false})
cache.Jᵀ²du = cache.J_cache * cache.Jdu
cache.Jdu = J * _vec(cache.du)
cache.fu = cache.fu2 .- cache.fu
cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
cache.fu = _restructure(cache.fu,

Check warning on line 186 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L181-L186

Added lines #L181 - L186 were not covered by tests
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(T)))
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
cache.J = J .+ cache.J_cache

Check warning on line 189 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L188-L189

Added lines #L188 - L189 were not covered by tests

Expand Down
1 change: 0 additions & 1 deletion src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
end

"""
```julia
TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple,
max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1,
Expand Down
25 changes: 25 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,28 @@ function __init_identity_jacobian(u::StaticArray, fu)
return convert(MArray{Tuple{length(fu), length(u)}},

Check warning on line 221 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L220-L221

Added lines #L220 - L221 were not covered by tests
Matrix{eltype(u)}(I, length(fu), length(u)))
end

# Check Singular Matrix
_issingular(x::Number) = iszero(x)
@generated function _issingular(x::T) where {T}
hasmethod(issingular, Tuple{T}) && return :(issingular(x))
return :(__issingular(x))

Check warning on line 229 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L226-L229

Added lines #L226 - L229 were not covered by tests
end
__issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(T)))
__issingular(x) = false ## If SciMLOperator and such

Check warning on line 232 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L231-L232

Added lines #L231 - L232 were not covered by tests

# If factorization is LU then perform that and update the linsolve cache
# else check if the matrix is singular
function _try_factorize_and_check_singular!(linsolve, X)
if linsolve.cacheval isa LU

Check warning on line 237 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L236-L237

Added lines #L236 - L237 were not covered by tests
# LU Factorization was used
linsolve.A = X
linsolve.cacheval = LinearSolve.do_factorization(linsolve.alg, X, linsolve.b,

Check warning on line 240 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L239-L240

Added lines #L239 - L240 were not covered by tests
linsolve.u)
linsolve.isfresh = false

Check warning on line 242 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L242

Added line #L242 was not covered by tests

return !issuccess(linsolve.cacheval), true

Check warning on line 244 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L244

Added line #L244 was not covered by tests
end
return _issingular(X), false

Check warning on line 246 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L246

Added line #L246 was not covered by tests
end
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false

Check warning on line 248 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L248

Added line #L248 was not covered by tests
10 changes: 6 additions & 4 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,21 @@ end

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 3, 4, 5, 6, 8, 11, 12, 13, 14, 21]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 16, 21, 22]
broken_tests[alg_ops[3]] = [1, 2, 4, 5, 6, 8, 11, 12, 13, 14, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testset "GeneralKlement 23 Test Problems" begin
alg_ops = (GeneralKlement(),
GeneralKlement(; linesearch = BackTracking()))
GeneralKlement(; linesearch = BackTracking()),
GeneralKlement(; linesearch = HagerZhang()))

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 22]
broken_tests[alg_ops[1]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
broken_tests[alg_ops[3]] = [1, 2, 5, 6, 11, 12, 13, 22]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
4 changes: 2 additions & 2 deletions test/matrix_resizing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ vecprob = NonlinearProblem(ff, vec(u0), p)
prob = NonlinearProblem(ff, u0, p)

for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden())
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden(), GeneralKlement())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end

Expand All @@ -18,6 +18,6 @@ vecprob = NonlinearProblem(fiip, vec(u0), p)
prob = NonlinearProblem(fiip, u0, p)

for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden())
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden(), GeneralKlement())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end

0 comments on commit d2251ba

Please sign in to comment.