Skip to content

Commit

Permalink
Correctness fix in LM
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 8, 2023
1 parent b7a8bcb commit af5db27
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.0.0"
version = "3.0.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
57 changes: 49 additions & 8 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,24 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
algorithm for nonlinear least-squares minimization". Designed for large-scale and
numerically-difficult nonlinear systems.
If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
"Chapter 10: Implementation of the Levenberg-Marquardt Method" of
["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
### How to Choose the Linear Solver?
There are 2 ways to perform the LM Step
1. Solve `(JᵀJ + λDᵀD) δx = Jᵀf` directly using a linear solver
2. Solve `Jδx = f` and `√λ⋅D δx = 0` simultaneously as one stacked least squares problem
(taking the normal equations of this stacked system recovers the first form)
The second form tends to be more robust and can be solved using any Least Squares Solver.
If no `linsolve` is provided, or if a least squares solver is provided, then we will solve the 2nd form.
However, in most cases, this means losing structure in `J` which is not ideal. Note that
whatever you do, do not specify solvers like `linsolve = NormalCholeskyFactorization()` or
any such solver which converts the equation to normal form before solving. These don't use
cache efficiently and we already support the normal form natively.
Additionally, note that the first form leads to a positive definite system, so we can use
more efficient solvers like `linsolve = CholeskyFactorization()`. If you know that the
problem is very well conditioned, then you might want to solve the normal form directly.
### Keyword Arguments
Expand Down Expand Up @@ -168,7 +182,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
T = eltype(u)
fu = evaluate_f(prob, u)

fastls = !__needs_square_A(alg, u0)
fastls = prob isa NonlinearProblem && !__needs_square_A(alg, u0)

if !fastls
uf, linsolve, J, fu_cache, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
Expand Down Expand Up @@ -253,9 +267,9 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
if fastls
if setindex_trait(cache.mat_tmp) === CanSetindex()
copyto!(@view(cache.mat_tmp[1:length(cache.fu), :]), cache.J)
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= cache.λ .* cache.DᵀD
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= sqrt.(cache.λ .* cache.DᵀD)
else
cache.mat_tmp = _vcat(cache.J, cache.λ .* cache.DᵀD)
cache.mat_tmp = _vcat(cache.J, sqrt.(cache.λ .* cache.DᵀD))

Check warning on line 272 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L272

Added line #L272 was not covered by tests
end
if setindex_trait(cache.rhs_tmp) === CanSetindex()
cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu)
Expand Down Expand Up @@ -283,7 +297,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))

# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
# NOTE: Don't pass `A`` in again, since we want to reuse the previous solve
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
@bb cache.Jv = cache.J × vec(cache.v)
Jv = _restructure(cache.fu_cache_2, cache.Jv)
@bb @. cache.fu_cache_2 = (2 / cache.h) * ((cache.fu_cache_2 - cache.fu) / cache.h - Jv)
Expand Down Expand Up @@ -337,6 +351,33 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
return nothing
end

# Scalar case of the LM diagonal update: the new value is the larger of the
# current value `y` and the candidate `x`. Nothing is mutated; the result is
# returned.
@inline function __update_LM_diagonal!!(y::Number, x::Number)
    return max(y, x)
end
# Vector case of the LM diagonal update: every entry of `y.diag` becomes the
# elementwise maximum of itself and the matching entry of `x`.
#
# The `!!` convention applies: `y` is updated in place when its storage can be
# written to, otherwise a new `Diagonal` is built. Callers must always use the
# returned value.
#
# NOTE(review): the coverage report attached to this commit flagged this method
# as not exercised by tests.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
    # Storage that cannot be written to: return a freshly built `Diagonal`.
    setindex_trait(y.diag) === CanSetindex() || return Diagonal(max.(y.diag, x))
    y.diag .= max.(y.diag, x)
    return y
end
# Matrix case of the LM diagonal update: every entry of `y.diag` becomes the
# maximum of itself and the matching main-diagonal entry of `x`.
#
# The `!!` convention applies: `y` is updated in place when its storage can be
# written to, otherwise a new `Diagonal` is built. Callers must always use the
# returned value.
#
# Throws `DimensionMismatch` on the scalar-indexing path when `x` has more rows
# than columns or more rows than `y.diag` has entries.
#
# NOTE(review): the coverage report attached to this commit flagged every branch
# of this method as not exercised by tests.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
    if setindex_trait(y.diag) === CanSetindex()
        if fast_scalar_indexing(y.diag)
            # The loop below runs under `@inbounds` and touches `y.diag[i]` and
            # `x[i, i]` for every row index of `x`, so the shapes must be checked
            # first (assumes 1-based indexing). Every input that was previously
            # in-bounds is still accepted.
            nrows = size(x, 1)
            if nrows > size(x, 2) || nrows > length(y.diag)
                throw(DimensionMismatch(
                    "cannot update a diagonal of length $(length(y.diag)) from " *
                    "the diagonal of a matrix of size $(size(x))"))
            end
            @inbounds for i in axes(x, 1)
                y.diag[i] = max(y.diag[i], x[i, i])
            end
            return y
        else
            # Scalar indexing is slow for this storage: gather the diagonal of
            # `x` through its linear indices and update in a single pass.
            idxs = diagind(x)
            @.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
            return y
        end
    else
        # Storage that cannot be written to: return a freshly built `Diagonal`.
        idxs = diagind(x)
        return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
    end
end

function __reinit_internal!(cache::LevenbergMarquardtCache;
termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...)
abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol,
Expand Down
27 changes: 0 additions & 27 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,33 +442,6 @@ function __sum_JᵀJ!!(y, J)
end
end

# Scalar case of the LM diagonal update: the new value is the larger of the
# current value `y` and the candidate `x`. Nothing is mutated; the result is
# returned.
@inline function __update_LM_diagonal!!(y::Number, x::Number)
    return max(y, x)
end
# Vector case of the LM diagonal update: every entry of `y.diag` becomes the
# elementwise maximum of itself and the matching entry of `x`.
#
# The `!!` convention applies: `y` is updated in place when its storage can be
# written to, otherwise a new `Diagonal` is built. Callers must always use the
# returned value.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
    # Storage that cannot be written to: return a freshly built `Diagonal`.
    setindex_trait(y.diag) === CanSetindex() || return Diagonal(max.(y.diag, x))
    y.diag .= max.(y.diag, x)
    return y
end
# Matrix case of the LM diagonal update: every entry of `y.diag` becomes the
# maximum of itself and the matching main-diagonal entry of `x`.
#
# The `!!` convention applies: `y` is updated in place when its storage can be
# written to, otherwise a new `Diagonal` is built. Callers must always use the
# returned value.
#
# Throws `DimensionMismatch` on the scalar-indexing path when `x` has more rows
# than columns or more rows than `y.diag` has entries.
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
    if setindex_trait(y.diag) === CanSetindex()
        if fast_scalar_indexing(y.diag)
            # The loop below runs under `@inbounds` and touches `y.diag[i]` and
            # `x[i, i]` for every row index of `x`, so the shapes must be checked
            # first (assumes 1-based indexing). Every input that was previously
            # in-bounds is still accepted.
            nrows = size(x, 1)
            if nrows > size(x, 2) || nrows > length(y.diag)
                throw(DimensionMismatch(
                    "cannot update a diagonal of length $(length(y.diag)) from " *
                    "the diagonal of a matrix of size $(size(x))"))
            end
            @inbounds for i in axes(x, 1)
                y.diag[i] = max(y.diag[i], x[i, i])
            end
            return y
        else
            # Scalar indexing is slow for this storage: gather the diagonal of
            # `x` through its linear indices and update in a single pass.
            idxs = diagind(x)
            @.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
            return y
        end
    else
        # Storage that cannot be written to: return a freshly built `Diagonal`.
        idxs = diagind(x)
        return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
    end
end

# Alpha for Initial Jacobian Guess
# The values are somewhat different from SciPy, these were tuned to the 23 test problems
@inline function __initial_inv_alpha::Number, u, fu, norm::F) where {F}
Expand Down

0 comments on commit af5db27

Please sign in to comment.