Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trust Region Methods for Nonlinear Least Squares Problem #311

Merged
merged 1 commit into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions docs/src/api/nonlinearsolve.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,38 @@

These are the native solvers of NonlinearSolve.jl.

## Core Nonlinear Solvers
## Nonlinear Solvers

```@docs
NewtonRaphson
TrustRegion
PseudoTransient
DFSane
Broyden
Klement
```

## Polyalgorithms
## Nonlinear Least Squares Solvers

```@docs
NonlinearSolvePolyAlgorithm
FastShortcutNonlinearPolyalg
FastShortcutNLLSPolyalg
RobustMultiNewton
GaussNewton
```

## Nonlinear Least Squares Solvers
## Both Nonlinear & Nonlinear Least Squares Solvers

These solvers can be used for both nonlinear and nonlinear least squares problems.

```@docs
TrustRegion
LevenbergMarquardt
GaussNewton
```

## Polyalgorithms

```@docs
NonlinearSolvePolyAlgorithm
FastShortcutNonlinearPolyalg
FastShortcutNLLSPolyalg
RobustMultiNewton
```

## Radius Update Schemes for Trust Region (RadiusUpdateSchemes)
Expand Down
2 changes: 2 additions & 0 deletions docs/src/solvers/NonlinearLeastSquaresSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ falls back to a more robust algorithm (`LevenbergMarquardt`).
handling of sparse matrices via colored automatic differentiation and preconditioned
linear solvers. Designed for large-scale and numerically-difficult nonlinear least
squares problems.
- `TrustRegion()`: A Newton Trust Region dogleg method with swappable nonlinear solvers and
autodiff methods for high performance on large and sparse systems.

### SimpleNonlinearSolve.jl

Expand Down
31 changes: 22 additions & 9 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,30 +213,43 @@
end

# jvp fallback scalar
function __jacvec(uf, u; autodiff, kwargs...)
if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff)
function __gradient_operator(uf, u; autodiff, kwargs...)
if !(autodiff isa AutoFiniteDiff || autodiff isa AutoZygote)
_ad = autodiff
autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
number_ad = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
AutoFiniteDiff())
@warn "$(_ad) not supported for JacVec. Using $(autodiff) instead."
if u isa Number
autodiff = number_ad
else
if isinplace(uf)
autodiff = AutoFiniteDiff()
else
autodiff = ifelse(is_extension_loaded(Val{:Zygote}()), AutoZygote(),
AutoFiniteDiff())
end
end
if _ad !== nothing && _ad !== autodiff
@warn "$(_ad) not supported for VecJac. Using $(autodiff) instead."

Check warning on line 232 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L232

Added line #L232 was not covered by tests
end
end
return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...)
return u isa Number ? GradientScalar(uf, u, autodiff) :
VecJac(uf, u; autodiff, kwargs...)
end

@concrete mutable struct JVPScalar
@concrete mutable struct GradientScalar
uf
u
autodiff
end

function Base.:*(jvp::JVPScalar, v::Number)
function Base.:*(jvp::GradientScalar, v::Number)
if jvp.autodiff isa AutoForwardDiff
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, one(v)))
return ForwardDiff.extract_derivative(T, out)
elseif jvp.autodiff isa AutoFiniteDiff
J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype)
return J * v
return J

Check warning on line 252 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L252

Added line #L252 was not covered by tests
else
error("Only ForwardDiff & FiniteDiff is currently supported.")
end
Expand Down
6 changes: 3 additions & 3 deletions src/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ end

## Arguments

- `freq`: Sets both `print_frequency` and `store_frequency` to `freq`.
- `freq`: Sets both `print_frequency` and `store_frequency` to `freq`.

## Keyword Arguments

- `print_frequency`: Print the trace every `print_frequency` iterations if
- `print_frequency`: Print the trace every `print_frequency` iterations if
`show_trace == Val(true)`.
- `store_frequency`: Store the trace every `store_frequency` iterations if
- `store_frequency`: Store the trace every `store_frequency` iterations if
`store_trace == Val(true)`.
"""
@kwdef struct TraceAll <: AbstractNonlinearSolveTraceLevel
Expand Down
22 changes: 12 additions & 10 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,14 @@
p3
p4
ϵ
jvp_operator # For Yuan
vjp_operator # For Yuan
stats::NLStats
tc_cache
trace
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...;
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg_::TrustRegion, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
Expand Down Expand Up @@ -317,7 +318,7 @@
p3 = convert(floatType, 0.0)
p4 = convert(floatType, 0.0)
ϵ = convert(floatType, 1.0e-8)
jvp_operator = nothing
vjp_operator = nothing
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
p1 = convert(floatType, 0.5)
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
Expand All @@ -336,8 +337,9 @@
p1 = convert(floatType, 2.0) # μ
p2 = convert(floatType, 1 / 6) # c5
p3 = convert(floatType, 6.0) # c6
jvp_operator = __jacvec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
@bb Jᵀf = jvp_operator × fu
vjp_operator = __gradient_operator(uf, u; fu,
autodiff = __get_nonsparse_ad(alg.vjp_autodiff))
@bb Jᵀf = vjp_operator × fu
initial_trust_radius = convert(trustType, p1 * internalnorm(Jᵀf))
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
step_threshold = convert(trustType, 0.0001)
Expand Down Expand Up @@ -366,7 +368,7 @@
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, jvp_operator,
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, vjp_operator,
NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end

Expand Down Expand Up @@ -479,7 +481,7 @@
cache.shrink_counter = 0
end

@bb cache.Jᵀf = cache.jvp_operator × vec(cache.fu)
@bb cache.Jᵀf = cache.vjp_operator × vec(cache.fu)
cache.trust_r = cache.p1 * cache.internalnorm(cache.Jᵀf)

cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true)
Expand Down Expand Up @@ -567,10 +569,10 @@

# FIXME: Reinit `JᵀJ` operator if `p` is changed
function __reinit_internal!(cache::TrustRegionCache; kwargs...)
if cache.jvp_operator !== nothing
cache.jvp_operator = __jacvec(cache.uf, cache.u; cache.fu,
if cache.vjp_operator !== nothing
cache.vjp_operator = __gradient_operator(cache.uf, cache.u; cache.fu,

Check warning on line 573 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L573

Added line #L573 was not covered by tests
autodiff = __get_nonsparse_ad(cache.alg.ad))
@bb cache.Jᵀf = cache.jvp_operator × cache.fu
@bb cache.Jᵀf = cache.vjp_operator × cache.fu

Check warning on line 575 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L575

Added line #L575 was not covered by tests
end
cache.loss = __trust_region_loss(cache, cache.fu)
cache.loss_new = cache.loss
Expand Down
6 changes: 6 additions & 0 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]

solvers = []
for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES()]
vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] :
Expand All @@ -46,6 +47,11 @@ append!(solvers,
LeastSquaresOptimJL(:dogleg),
nothing,
])
for radius_update_scheme in [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright,
RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
push!(solvers, TrustRegion(; radius_update_scheme))
end

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down