From af5db27745b43787dbfb049926fa84def90943e0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 8 Dec 2023 11:28:27 -0500 Subject: [PATCH] Correctness fix in LM --- Project.toml | 2 +- src/levenberg.jl | 57 +++++++++++++++++++++++++++++++++++++++++------- src/utils.jl | 27 ----------------------- 3 files changed, 50 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index 9891f3cfb..d334fd8ba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.0.0" +version = "3.0.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/levenberg.jl b/src/levenberg.jl index 305613df2..95daa3084 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -10,10 +10,24 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i algorithm for nonlinear least-squares minimization". Designed for large-scale and numerically-difficult nonlinear systems. -If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient -routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see -"Chapter 10: Implementation of the Levenberg-Marquardt Method" of -["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5). +### How to Choose the Linear Solver? + +There are 2 ways to perform the LM Step + + 1. Solve `(JᵀJ + λDᵀD) δx = Jᵀf` directly using a linear solver + 2. Solve for `Jδx = f` and `√λ⋅D δx = 0` simultaneously (to derive this simply compute the + normal form for this) + +The second form tends to be more robust and can be solved using any Least Squares Solver. +If no `linsolve` or a least squares solver is provided, then we will solve the 2nd form. +However, in most cases, this means losing structure in `J` which is not ideal. Note that +whatever you do, do not specify solvers like `linsolve = NormalCholeskyFactorization()` or +any such solver which converts the equation to normal form before solving. These don't use +cache efficiently and we already support the normal form natively. + +Additionally, note that the first form leads to a positive definite system, so we can use +more efficient solvers like `linsolve = CholeskyFactorization()`. If you know that the +problem is very well conditioned, then you might want to solve the normal form directly. ### Keyword Arguments @@ -168,7 +182,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, T = eltype(u) fu = evaluate_f(prob, u) - fastls = !__needs_square_A(alg, u0) + fastls = prob isa NonlinearProblem && !__needs_square_A(alg, u0) if !fastls uf, linsolve, J, fu_cache, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, @@ -253,9 +267,9 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip, if fastls if setindex_trait(cache.mat_tmp) === CanSetindex() copyto!(@view(cache.mat_tmp[1:length(cache.fu), :]), cache.J) - cache.mat_tmp[(length(cache.fu) + 1):end, :] .= cache.λ .* cache.DᵀD + cache.mat_tmp[(length(cache.fu) + 1):end, :] .= sqrt.(cache.λ .* cache.DᵀD) else - cache.mat_tmp = _vcat(cache.J, cache.λ .* cache.DᵀD) + cache.mat_tmp = _vcat(cache.J, sqrt.(cache.λ .* cache.DᵀD)) end if setindex_trait(cache.rhs_tmp) === CanSetindex() cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu) @@ -283,7 +297,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip, evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2)) # The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp - # NOTE: Don't pass `A`` in again, since we want to reuse the previous solve + # NOTE: Don't pass `A` in again, since we want to reuse the previous solve @bb cache.Jv = cache.J × vec(cache.v) Jv = _restructure(cache.fu_cache_2, cache.Jv) @bb @. cache.fu_cache_2 = (2 / cache.h) * ((cache.fu_cache_2 - cache.fu) / cache.h - Jv) @@ -337,6 +351,33 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip, return nothing end +@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x) +@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector) + if setindex_trait(y.diag) === CanSetindex() + @. y.diag = max(y.diag, x) + return y + else + return Diagonal(max.(y.diag, x)) + end +end +@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix) + if setindex_trait(y.diag) === CanSetindex() + if fast_scalar_indexing(y.diag) + @inbounds for i in axes(x, 1) + y.diag[i] = max(y.diag[i], x[i, i]) + end + return y + else + idxs = diagind(x) + @.. broadcast=false y.diag=max(y.diag, @view(x[idxs])) + return y + end + else + idxs = diagind(x) + return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs]))) + end +end + function __reinit_internal!(cache::LevenbergMarquardtCache; termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...) abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol, diff --git a/src/utils.jl b/src/utils.jl index d312c21c9..075a1857a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -442,33 +442,6 @@ function __sum_JᵀJ!!(y, J) end end -@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x) -@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector) - if setindex_trait(y.diag) === CanSetindex() - @. y.diag = max(y.diag, x) - return y - else - return Diagonal(max.(y.diag, x)) - end -end -@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix) - if setindex_trait(y.diag) === CanSetindex() - if fast_scalar_indexing(y.diag) - @inbounds for i in axes(x, 1) - y.diag[i] = max(y.diag[i], x[i, i]) - end - return y - else - idxs = diagind(x) - @.. broadcast=false y.diag=max(y.diag, @view(x[idxs])) - return y - end - else - idxs = diagind(x) - return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs]))) - end -end - # Alpha for Initial Jacobian Guess # The values are somewhat different from SciPy, these were tuned to the 23 test problems @inline function __initial_inv_alpha(α::Number, u, fu, norm::F) where {F}