Skip to content

Commit

Permalink
Add gauss newton
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2023
1 parent 0ffbf49 commit 3328077
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ pages = ["index.md",
"basics/FAQ.md"],
"Solver Summaries and Recommendations" => Any["solvers/NonlinearSystemSolvers.md",
"solvers/BracketingSolvers.md",
"solvers/SteadyStateSolvers.md"],
"solvers/SteadyStateSolvers.md",
"solvers/NonlinearLeastSquaresSolvers.md"],
"Detailed Solver APIs" => Any["api/nonlinearsolve.md",
"api/simplenonlinearsolve.md",
"api/minpack.md",
Expand Down
2 changes: 2 additions & 0 deletions docs/src/api/nonlinearsolve.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ These are the native solvers of NonlinearSolve.jl.
```@docs
NewtonRaphson
TrustRegion
LevenbergMarquardt
GaussNewton
```

## Radius Update Schemes for Trust Region (RadiusUpdateSchemes)
Expand Down
28 changes: 28 additions & 0 deletions docs/src/solvers/NonlinearLeastSquaresSolvers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Nonlinear Least Squares Solvers

`solve(prob::NonlinearLeastSquaresProblem, alg; kwargs...)`

Solves the nonlinear least squares problem defined by `prob` using the algorithm
`alg`. If no algorithm is given, a default algorithm will be chosen.

## Recommended Methods

`LevenbergMarquardt` is a good choice for most problems.

## Full List of Methods

- `LevenbergMarquardt()`: An advanced Levenberg-Marquardt implementation with the
improvements suggested in the [paper](https://arxiv.org/abs/1201.5885) "Improvements to
the Levenberg-Marquardt algorithm for nonlinear least-squares minimization". Designed for
large-scale and numerically-difficult nonlinear systems.
- `GaussNewton()`: An advanced GaussNewton implementation with support for efficient
handling of sparse matrices via colored automatic differentiation and preconditioned
linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares
problems.

## Example usage

```julia
using NonlinearSolve
sol = solve(prob, LevenbergMarquardt())
```
4 changes: 4 additions & 0 deletions docs/src/solvers/NonlinearSystemSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ features, but have a bit of overhead on very small problems.
methods for high performance on large and sparse systems.
- `TrustRegion()`: A Newton Trust Region dogleg method with swappable nonlinear solvers and
autodiff methods for high performance on large and sparse systems.
- `LevenbergMarquardt()`: An advanced Levenberg-Marquardt implementation with the
improvements suggested in the [paper](https://arxiv.org/abs/1201.5885) "Improvements to
the Levenberg-Marquardt algorithm for nonlinear least-squares minimization". Designed for
large-scale and numerically-difficult nonlinear systems.

### SimpleNonlinearSolve.jl

Expand Down
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
include("levenberg.jl")
include("gaussnewton.jl")
include("jacobian.jl")
include("ad.jl")

Expand Down Expand Up @@ -93,7 +94,7 @@ end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton

export LineSearch

Expand Down
165 changes: 165 additions & 0 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.
!!! note
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
is a more general extension of `Gauss-Newton` Method.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
then the Jacobian will not be constructed and instead direct Jacobian-vector products
`J*v` are computed using forward-mode automatic differentiation or finite differencing
tricks (without ever constructing the Jacobian). However, if the Jacobian is still needed,
for example for a preconditioner, `concrete_jac = true` can be passed in order to force
the construction of the Jacobian.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
!!! warning
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
construction. This will be fixed in the near future.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
fu1
fu2
fu_new
du
p
uf
linsolve
J
JᵀJ
Jᵀf
jac_cache
force_stop
maxiters::Int
internalnorm
retcode::ReturnCode.T
abstol
prob
stats::NLStats
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = f(u, p)
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))

JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
Jᵀf = zero(u)

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
mul!(JᵀJ, J', J)
mul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(cache.force_stop = true)
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache

cache.J = jacobian!!(cache.J, cache)
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(cache.force_stop = true)
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
12 changes: 6 additions & 6 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

# sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8)
# @test SciMLBase.successful_retcode(sol)
# @test norm(sol.resid) < 1e-6
sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6

# sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8)
# @test SciMLBase.successful_retcode(sol)
# @test norm(sol.resid) < 1e-6
sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6

sol = solve(prob_oop, LevenbergMarquardt(); maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
Expand Down

0 comments on commit 3328077

Please sign in to comment.