From 41f4fecb1047354ad1f01b8930de0e497cd7953a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 8 Dec 2023 11:28:27 -0500
Subject: [PATCH] Correctness fix in LM

---
 Project.toml             |  2 +-
 src/levenberg.jl         | 57 ++++++++++++++++++++++++++++++++++------
 src/utils.jl             | 27 -------------------
 test/23_test_problems.jl |  7 ++---
 4 files changed, 52 insertions(+), 41 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9891f3cfb..d334fd8ba 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "NonlinearSolve"
 uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
 authors = ["SciML"]
-version = "3.0.0"
+version = "3.0.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/levenberg.jl b/src/levenberg.jl
index 305613df2..95daa3084 100644
--- a/src/levenberg.jl
+++ b/src/levenberg.jl
@@ -10,10 +10,24 @@
 An advanced Levenberg-Marquardt implementation with the improvements suggested in the
 [paper](https://arxiv.org/abs/1201.5885) "Improvements to the Levenberg-Marquardt
 algorithm for nonlinear least-squares minimization". Designed for large-scale and
 numerically-difficult nonlinear systems.
 
-If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
-routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
-"Chapter 10: Implementation of the Levenberg-Marquardt Method" of
-["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
+### How to Choose the Linear Solver?
+
+There are two ways to perform the LM step:
+
+ 1. Solve `(JᵀJ + λDᵀD) δx = Jᵀf` directly using a linear solver.
+ 2. Solve `Jδx = f` and `√λ⋅D δx = 0` simultaneously as a least-squares problem (taking
+    the normal equations of this system recovers the first form).
+
+The second form tends to be more robust and can be solved with any least-squares solver.
+If no `linsolve` is provided, or a least-squares solver is provided, we solve the second
+form. In most cases, however, this loses structure (e.g. sparsity) in `J`. Whatever you
+do, do not specify solvers like `linsolve = NormalCholeskyFactorization()` or any other
+solver that converts the system to normal form before solving: these don't use the cache
+efficiently, and we already support the normal form natively.
+
+Note that the first form leads to a symmetric positive definite system, so more efficient
+solvers like `linsolve = CholeskyFactorization()` can be used. If you know the problem is
+very well conditioned, you might want to solve the normal form directly.
 ### Keyword Arguments
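[Editor's note: the equivalence claimed in the docstring above can be checked numerically. The sketch below uses made-up `J`, `f`, `D`, and `λ` (not the package's internals) and shows why the hunks further down take `sqrt.(λ .* DᵀD)` when building the stacked matrix: the cache stores `DᵀD = D²`, so the augmented block must be `√λ⋅D = sqrt.(λ .* DᵀD)`.]

```julia
using LinearAlgebra

n, m = 6, 4
J = randn(n, m)               # hypothetical Jacobian
f = randn(n)                  # hypothetical residual
D = Diagonal(rand(m) .+ 1)    # damping diagonal, so DᵀD = D^2
λ = 0.1                       # damping parameter

# Form 1: damped normal equations, symmetric positive definite
δx1 = cholesky(Symmetric(J' * J + λ * D' * D)) \ (J' * f)

# Form 2: augmented least-squares problem min ‖[J; √λ⋅D] δx - [f; 0]‖²
# Note the square root: since the cache stores DᵀD, the block is sqrt.(λ .* DᵀD)
A = [J; Matrix(sqrt(λ) * D)]
b = [f; zeros(m)]
δx2 = A \ b                   # QR-based least-squares solve

@assert δx1 ≈ δx2             # both forms give the same LM step
```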
@@ -168,7 +182,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
     T = eltype(u)
     fu = evaluate_f(prob, u)
 
-    fastls = !__needs_square_A(alg, u0)
+    fastls = prob isa NonlinearProblem && !__needs_square_A(alg, u0)
 
     if !fastls
         uf, linsolve, J, fu_cache, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
@@ -253,9 +267,9 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
     if fastls
         if setindex_trait(cache.mat_tmp) === CanSetindex()
             copyto!(@view(cache.mat_tmp[1:length(cache.fu), :]), cache.J)
-            cache.mat_tmp[(length(cache.fu) + 1):end, :] .= cache.λ .* cache.DᵀD
+            cache.mat_tmp[(length(cache.fu) + 1):end, :] .= sqrt.(cache.λ .* cache.DᵀD)
         else
-            cache.mat_tmp = _vcat(cache.J, cache.λ .* cache.DᵀD)
+            cache.mat_tmp = _vcat(cache.J, sqrt.(cache.λ .* cache.DᵀD))
         end
         if setindex_trait(cache.rhs_tmp) === CanSetindex()
             cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu)
@@ -283,7 +297,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
         evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
 
         # The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
-        # NOTE: Don't pass `A`` in again, since we want to reuse the previous solve
+        # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
         @bb cache.Jv = cache.J × vec(cache.v)
         Jv = _restructure(cache.fu_cache_2, cache.Jv)
         @bb @. cache.fu_cache_2 = (2 / cache.h) * ((cache.fu_cache_2 - cache.fu) / cache.h - Jv)
@@ -337,6 +351,33 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
     return nothing
 end
 
+@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x)
+@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
+    if setindex_trait(y.diag) === CanSetindex()
+        @. y.diag = max(y.diag, x)
+        return y
+    else
+        return Diagonal(max.(y.diag, x))
+    end
+end
+@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
+    if setindex_trait(y.diag) === CanSetindex()
+        if fast_scalar_indexing(y.diag)
+            @inbounds for i in axes(x, 1)
+                y.diag[i] = max(y.diag[i], x[i, i])
+            end
+            return y
+        else
+            idxs = diagind(x)
+            @.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
+            return y
+        end
+    else
+        idxs = diagind(x)
+        return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
+    end
+end
+
 function __reinit_internal!(cache::LevenbergMarquardtCache;
         termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...)
     abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol,
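[Editor's note: `__update_LM_diagonal!!` (moved above from `src/utils.jl`) never lets a damping-diagonal entry decrease: `yᵢ ← max(yᵢ, xᵢᵢ)`, where `x` is presumably `JᵀJ` at the call site. A hypothetical standalone sketch of the same rule follows; `ismutable` stands in for the package's `setindex_trait` machinery, and the `!!` suffix signals "in place if possible, otherwise rebuilt".]

```julia
using LinearAlgebra

# Raise each damping diagonal entry to at least the matching diagonal entry
# of the input matrix, mutating only when the storage supports setindex!.
function update_lm_diagonal!!(y::Diagonal, x::AbstractMatrix)
    if ismutable(y.diag)  # stand-in for setindex_trait(y.diag) === CanSetindex()
        @inbounds for i in axes(x, 1)
            y.diag[i] = max(y.diag[i], x[i, i])
        end
        return y
    else
        return Diagonal(max.(y.diag, diag(x)))
    end
end

J = [1.0 0.0; 0.0 2.0; 1.0 1.0]
DᵀD = update_lm_diagonal!!(Diagonal(ones(2)), J' * J)  # diag becomes [2.0, 5.0]
```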
diff --git a/src/utils.jl b/src/utils.jl
index d312c21c9..075a1857a 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -442,33 +442,6 @@ function __sum_JᵀJ!!(y, J)
     end
 end
 
-@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x)
-@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
-    if setindex_trait(y.diag) === CanSetindex()
-        @. y.diag = max(y.diag, x)
-        return y
-    else
-        return Diagonal(max.(y.diag, x))
-    end
-end
-@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
-    if setindex_trait(y.diag) === CanSetindex()
-        if fast_scalar_indexing(y.diag)
-            @inbounds for i in axes(x, 1)
-                y.diag[i] = max(y.diag[i], x[i, i])
-            end
-            return y
-        else
-            idxs = diagind(x)
-            @.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
-            return y
-        end
-    else
-        idxs = diagind(x)
-        return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
-    end
-end
-
 # Alpha for Initial Jacobian Guess
 # The values are somewhat different from SciPy, these were tuned to the 23 test problems
 @inline function __initial_inv_alpha(α::Number, u, fu, norm::F) where {F}
diff --git a/test/23_test_problems.jl b/test/23_test_problems.jl
index 591083937..641d273cb 100644
--- a/test/23_test_problems.jl
+++ b/test/23_test_problems.jl
@@ -39,7 +39,6 @@ end
 @testset "NewtonRaphson 23 Test Problems" begin
     alg_ops = (NewtonRaphson(),)
 
-    # dictionary with indices of test problems where method does not converge to small residual
     broken_tests = Dict(alg => Int[] for alg in alg_ops)
     broken_tests[alg_ops[1]] = [1, 6]
 
@@ -54,7 +53,6 @@ end
         TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin),
         TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.NLsolve))
 
-    # dictionary with indices of test problems where method does not converge to small residual
     broken_tests = Dict(alg => Int[] for alg in alg_ops)
     broken_tests[alg_ops[1]] = [6, 11, 21]
     broken_tests[alg_ops[2]] = [6, 11, 21]
@@ -70,10 +68,9 @@ end
     alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.1),
         LevenbergMarquardt(; linsolve = CholeskyFactorization()))
 
-    # dictionary with indices of test problems where method does not converge to small residual
     broken_tests = Dict(alg => Int[] for alg in alg_ops)
-    broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
-    broken_tests[alg_ops[2]] = [3, 6, 11, 17, 21]
+    broken_tests[alg_ops[1]] = [6, 11, 21]
+    broken_tests[alg_ops[2]] = [6, 11, 21]
     broken_tests[alg_ops[3]] = [6, 11, 21]
 
     test_on_library(problems, dicts, alg_ops, broken_tests)
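[Editor's note: a hedged end-to-end sketch of the two solver configurations the updated tests exercise. The toy problem is illustrative only, not one of the 23 test problems, and assumes the public NonlinearSolve/LinearSolve APIs as of v3.0.x.]

```julia
using NonlinearSolve, LinearSolve

f(u, p) = u .* u .- p                       # toy square system with root √p
prob = NonlinearProblem(f, [1.0, 1.0], 2.0)

# Default configuration: solves the augmented least-squares form of the LM step
sol = solve(prob, LevenbergMarquardt())

# Normal-equations form with an SPD-aware factorization (see the docstring above)
sol_chol = solve(prob, LevenbergMarquardt(; linsolve = CholeskyFactorization()))

@assert isapprox(sol.u, fill(sqrt(2.0), 2); atol = 1e-6)
@assert isapprox(sol_chol.u, fill(sqrt(2.0), 2); atol = 1e-6)
```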