Skip to content

Commit

Permalink
Trust Region Methods for NLLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 8, 2023
1 parent af5db27 commit cb21f74
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 28 deletions.
25 changes: 16 additions & 9 deletions docs/src/api/nonlinearsolve.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,38 @@

These are the native solvers of NonlinearSolve.jl.

## Core Nonlinear Solvers
## Nonlinear Solvers

```@docs
NewtonRaphson
TrustRegion
PseudoTransient
DFSane
Broyden
Klement
```

## Polyalgorithms
## Nonlinear Least Squares Solvers

```@docs
NonlinearSolvePolyAlgorithm
FastShortcutNonlinearPolyalg
FastShortcutNLLSPolyalg
RobustMultiNewton
GaussNewton
```

## Nonlinear Least Squares Solvers
## Both Nonlinear & Nonlinear Least Squares Solvers

These solvers can be used for both nonlinear and nonlinear least squares problems.

```@docs
TrustRegion
LevenbergMarquardt
GaussNewton
```

## Polyalgorithms

```@docs
NonlinearSolvePolyAlgorithm
FastShortcutNonlinearPolyalg
FastShortcutNLLSPolyalg
RobustMultiNewton
```

## Radius Update Schemes for Trust Region (RadiusUpdateSchemes)
Expand Down
2 changes: 2 additions & 0 deletions docs/src/solvers/NonlinearLeastSquaresSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ falls back to a more robust algorithm (`LevenbergMarquardt`).
handling of sparse matrices via colored automatic differentiation and preconditioned
linear solvers. Designed for large-scale and numerically-difficult nonlinear least
squares problems.
- `TrustRegion()`: A Newton Trust Region dogleg method with swappable linear solvers and
autodiff methods for high performance on large and sparse systems.

### SimpleNonlinearSolve.jl

Expand Down
31 changes: 22 additions & 9 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,30 +213,43 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
end

# vjp (gradient) operator, with a scalar fallback
function __jacvec(uf, u; autodiff, kwargs...)
if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff)
function __gradient_operator(uf, u; autodiff, kwargs...)
if !(autodiff isa AutoFiniteDiff || autodiff isa AutoZygote)

Check warning on line 217 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L216-L217

Added lines #L216 - L217 were not covered by tests
_ad = autodiff
autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
number_ad = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),

Check warning on line 219 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L219

Added line #L219 was not covered by tests
AutoFiniteDiff())
@warn "$(_ad) not supported for JacVec. Using $(autodiff) instead."
if u isa Number
autodiff = number_ad

Check warning on line 222 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L221-L222

Added lines #L221 - L222 were not covered by tests
else
if isinplace(uf)
autodiff = AutoFiniteDiff()

Check warning on line 225 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L224-L225

Added lines #L224 - L225 were not covered by tests
else
autodiff = ifelse(is_extension_loaded(Val{:Zygote}()), AutoZygote(),

Check warning on line 227 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L227

Added line #L227 was not covered by tests
AutoFiniteDiff())
end
end
if _ad !== nothing && _ad !== autodiff
@warn "$(_ad) not supported for VecJac. Using $(autodiff) instead."

Check warning on line 232 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L231-L232

Added lines #L231 - L232 were not covered by tests
end
end
return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...)
return u isa Number ? GradientScalar(uf, u, autodiff) :

Check warning on line 235 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L235

Added line #L235 was not covered by tests
VecJac(uf, u; autodiff, kwargs...)
end

@concrete mutable struct JVPScalar
@concrete mutable struct GradientScalar
uf
u
autodiff
end

function Base.:*(jvp::JVPScalar, v::Number)
function Base.:*(jvp::GradientScalar, v::Number)

Check warning on line 245 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L245

Added line #L245 was not covered by tests
if jvp.autodiff isa AutoForwardDiff
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, one(v)))

Check warning on line 248 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L248

Added line #L248 was not covered by tests
return ForwardDiff.extract_derivative(T, out)
elseif jvp.autodiff isa AutoFiniteDiff
J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype)
return J * v
return J

Check warning on line 252 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L252

Added line #L252 was not covered by tests
else
error("Only ForwardDiff & FiniteDiff is currently supported.")
end
Expand Down
22 changes: 12 additions & 10 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,14 @@ end
p3
p4
ϵ
jvp_operator # For Yuan
vjp_operator # For Yuan
stats::NLStats
tc_cache
trace
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...;
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg_::TrustRegion, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
Expand Down Expand Up @@ -317,7 +318,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
p3 = convert(floatType, 0.0)
p4 = convert(floatType, 0.0)
ϵ = convert(floatType, 1.0e-8)
jvp_operator = nothing
vjp_operator = nothing
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
p1 = convert(floatType, 0.5)
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
Expand All @@ -336,8 +337,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
p1 = convert(floatType, 2.0) # μ
p2 = convert(floatType, 1 / 6) # c5
p3 = convert(floatType, 6.0) # c6
jvp_operator = __jacvec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
@bb Jᵀf = jvp_operator × fu
vjp_operator = __gradient_operator(uf, u; fu,

Check warning on line 340 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L340

Added line #L340 was not covered by tests
autodiff = __get_nonsparse_ad(alg.vjp_autodiff))
@bb Jᵀf = vjp_operator × fu

Check warning on line 342 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L342

Added line #L342 was not covered by tests
initial_trust_radius = convert(trustType, p1 * internalnorm(Jᵀf))
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
step_threshold = convert(trustType, 0.0001)
Expand Down Expand Up @@ -366,7 +368,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, jvp_operator,
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, vjp_operator,
NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end

Expand Down Expand Up @@ -479,7 +481,7 @@ function trust_region_step!(cache::TrustRegionCache)
cache.shrink_counter = 0
end

@bb cache.Jᵀf = cache.jvp_operator × vec(cache.fu)
@bb cache.Jᵀf = cache.vjp_operator × vec(cache.fu)

Check warning on line 484 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L484

Added line #L484 was not covered by tests
cache.trust_r = cache.p1 * cache.internalnorm(cache.Jᵀf)

cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true)
Expand Down Expand Up @@ -567,10 +569,10 @@ end

# FIXME: Reinit `JᵀJ` operator if `p` is changed
function __reinit_internal!(cache::TrustRegionCache; kwargs...)
if cache.jvp_operator !== nothing
cache.jvp_operator = __jacvec(cache.uf, cache.u; cache.fu,
if cache.vjp_operator !== nothing
cache.vjp_operator = __gradient_operator(cache.uf, cache.u; cache.fu,

Check warning on line 573 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L572-L573

Added lines #L572 - L573 were not covered by tests
autodiff = __get_nonsparse_ad(cache.alg.ad))
@bb cache.Jᵀf = cache.jvp_operator × cache.fu
@bb cache.Jᵀf = cache.vjp_operator × cache.fu

Check warning on line 575 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L575

Added line #L575 was not covered by tests
end
cache.loss = __trust_region_loss(cache, cache.fu)
cache.loss_new = cache.loss
Expand Down
6 changes: 6 additions & 0 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]

solvers = []
for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES()]
vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] :
Expand All @@ -46,6 +47,11 @@ append!(solvers,
LeastSquaresOptimJL(:dogleg),
nothing,
])
for radius_update_scheme in [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright,
RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
push!(solvers, TrustRegion(; radius_update_scheme))
end

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down

0 comments on commit cb21f74

Please sign in to comment.