Skip to content

Commit

Permalink
Correctness fix in LM
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 9, 2023
1 parent e3c929a commit 41f4fec
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.0.0"
version = "3.0.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
57 changes: 49 additions & 8 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,24 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
algorithm for nonlinear least-squares minimization". Designed for large-scale and
numerically-difficult nonlinear systems.
If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
"Chapter 10: Implementation of the Levenberg-Marquardt Method" of
["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
### How to Choose the Linear Solver?
There are 2 ways to perform the LM Step
1. Solve `(JᵀJ + λDᵀD) δx = Jᵀf` directly using a linear solver
2. Solve `Jδx = f` and `√λ⋅D δx = 0` simultaneously as a single stacked least squares
problem (forming the normal equations of this stacked system recovers the first form)
The second form tends to be more robust and can be solved using any Least Squares Solver.
If no `linsolve` or a least squares solver is provided, then we will solve the 2nd form.
However, in most cases, this means losing structure in `J` which is not ideal. Note that
whatever you do, do not specify solvers like `linsolve = NormalCholeskyFactorization()` or
any such solver which converts the equation to normal form before solving. These don't use
cache efficiently and we already support the normal form natively.
Additionally, note that the first form leads to a positive definite system, so we can use
more efficient solvers like `linsolve = CholeskyFactorization()`. If you know that the
problem is very well conditioned, then you might want to solve the normal form directly.
### Keyword Arguments
Expand Down Expand Up @@ -168,7 +182,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
T = eltype(u)
fu = evaluate_f(prob, u)

fastls = !__needs_square_A(alg, u0)
fastls = prob isa NonlinearProblem && !__needs_square_A(alg, u0)

if !fastls
uf, linsolve, J, fu_cache, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
Expand Down Expand Up @@ -253,9 +267,9 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
if fastls
if setindex_trait(cache.mat_tmp) === CanSetindex()
copyto!(@view(cache.mat_tmp[1:length(cache.fu), :]), cache.J)
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= cache.λ .* cache.DᵀD
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= sqrt.(cache.λ .* cache.DᵀD)
else
cache.mat_tmp = _vcat(cache.J, cache.λ .* cache.DᵀD)
cache.mat_tmp = _vcat(cache.J, sqrt.(cache.λ .* cache.DᵀD))
end
if setindex_trait(cache.rhs_tmp) === CanSetindex()
cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu)
Expand Down Expand Up @@ -283,7 +297,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))

# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
# NOTE: Don't pass `A`` in again, since we want to reuse the previous solve
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
@bb cache.Jv = cache.J × vec(cache.v)
Jv = _restructure(cache.fu_cache_2, cache.Jv)
@bb @. cache.fu_cache_2 = (2 / cache.h) * ((cache.fu_cache_2 - cache.fu) / cache.h - Jv)
Expand Down Expand Up @@ -337,6 +351,33 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
return nothing
end

# Scalar case: the updated value is simply the larger of the current `y` and candidate `x`.
@inline function __update_LM_diagonal!!(y::Number, x::Number)
    return max(y, x)
end
# Elementwise maximum of the stored diagonal `y.diag` and the vector `x`.
# Returns `y` itself after an in-place update when `setindex_trait(y.diag)` reports
# `CanSetindex()`; otherwise returns a newly constructed `Diagonal`.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
    d = y.diag
    # Storage that cannot be written to: build a fresh `Diagonal` instead of mutating.
    setindex_trait(d) === CanSetindex() || return Diagonal(max.(d, x))
    d .= max.(d, x)
    return y
end
# Elementwise maximum of the stored diagonal `y.diag` and the main diagonal of the
# matrix `x`.
#
# Returns `y` itself after an in-place update when `setindex_trait(y.diag)` reports
# `CanSetindex()`; otherwise returns a newly constructed `Diagonal`.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
    if setindex_trait(y.diag) === CanSetindex()
        if fast_scalar_indexing(y.diag)
            # Plain loop over the diagonal entries `x[i, i]`.
            # NOTE(review): `@inbounds` relies on `x` being square and matching the
            # length of `y.diag` — confirm against callers.
            @inbounds for i in axes(x, 1)
                y.diag[i] = max(y.diag[i], x[i, i])
            end
            return y
        else
            # Broadcast over a view of the diagonal instead of indexing element by
            # element (presumably for array types where scalar indexing is slow).
            idxs = diagind(x)
            @.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
            return y
        end
    else
        # Storage that cannot be written to: build a fresh `Diagonal`.
        idxs = diagind(x)
        return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
    end
end

function __reinit_internal!(cache::LevenbergMarquardtCache;
termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...)
abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol,
Expand Down
27 changes: 0 additions & 27 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,33 +442,6 @@ function __sum_JᵀJ!!(y, J)
end
end

# Scalar case: keep whichever of the current value `y` and the candidate `x` is larger.
@inline function __update_LM_diagonal!!(y::Number, x::Number)
    return max(y, x)
end
# Elementwise maximum of the stored diagonal `y.diag` and the vector `x`.
# Returns `y` itself after an in-place update when `setindex_trait(y.diag)` reports
# `CanSetindex()`; otherwise returns a newly constructed `Diagonal`.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
    d = y.diag
    # Storage that cannot be written to: build a fresh `Diagonal` instead of mutating.
    setindex_trait(d) === CanSetindex() || return Diagonal(max.(d, x))
    d .= max.(d, x)
    return y
end
# Elementwise maximum of the stored diagonal `y.diag` and the main diagonal of the
# matrix `x`. Returns `y` itself after an in-place update when `setindex_trait(y.diag)`
# reports `CanSetindex()`; otherwise returns a newly constructed `Diagonal`.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
    d = y.diag
    if setindex_trait(d) !== CanSetindex()
        # Storage that cannot be written to: build a fresh `Diagonal`.
        idxs = diagind(x)
        return Diagonal(@.. broadcast=false max(d, @view(x[idxs])))
    end
    if fast_scalar_indexing(d)
        # Plain loop over the diagonal entries `x[i, i]`.
        @inbounds for i in axes(x, 1)
            d[i] = max(d[i], x[i, i])
        end
    else
        # Broadcast over a view of the diagonal rather than indexing per element.
        idxs = diagind(x)
        @.. broadcast=false d=max(d, @view(x[idxs]))
    end
    return y
end

# Alpha for Initial Jacobian Guess
# The values are somewhat different from SciPy, these were tuned to the 23 test problems
@inline function __initial_inv_alpha::Number, u, fu, norm::F) where {F}
Expand Down
7 changes: 2 additions & 5 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ end
@testset "NewtonRaphson 23 Test Problems" begin
alg_ops = (NewtonRaphson(),)

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 6]

Expand All @@ -54,7 +53,6 @@ end
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin),
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.NLsolve))

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [6, 11, 21]
broken_tests[alg_ops[2]] = [6, 11, 21]
Expand All @@ -70,10 +68,9 @@ end
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.1),
LevenbergMarquardt(; linsolve = CholeskyFactorization()))

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
broken_tests[alg_ops[2]] = [3, 6, 11, 17, 21]
broken_tests[alg_ops[1]] = [6, 11, 21]
broken_tests[alg_ops[2]] = [6, 11, 21]
broken_tests[alg_ops[3]] = [6, 11, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
Expand Down

0 comments on commit 41f4fec

Please sign in to comment.