diff --git a/docs/src/api/nonlinearsolve.md b/docs/src/api/nonlinearsolve.md index 0eebda2b4..ad0fd3e14 100644 --- a/docs/src/api/nonlinearsolve.md +++ b/docs/src/api/nonlinearsolve.md @@ -2,31 +2,38 @@ These are the native solvers of NonlinearSolve.jl. -## Core Nonlinear Solvers +## Nonlinear Solvers ```@docs NewtonRaphson -TrustRegion PseudoTransient DFSane Broyden Klement ``` -## Polyalgorithms +## Nonlinear Least Squares Solvers ```@docs -NonlinearSolvePolyAlgorithm -FastShortcutNonlinearPolyalg -FastShortcutNLLSPolyalg -RobustMultiNewton +GaussNewton ``` -## Nonlinear Least Squares Solvers +## Both Nonlinear & Nonlinear Least Squares Solvers + +These solvers can be used for both nonlinear and nonlinear least squares problems. ```@docs +TrustRegion LevenbergMarquardt -GaussNewton +``` + +## Polyalgorithms + +```@docs +NonlinearSolvePolyAlgorithm +FastShortcutNonlinearPolyalg +FastShortcutNLLSPolyalg +RobustMultiNewton ``` ## Radius Update Schemes for Trust Region (RadiusUpdateSchemes) diff --git a/docs/src/solvers/NonlinearLeastSquaresSolvers.md b/docs/src/solvers/NonlinearLeastSquaresSolvers.md index 7adfd9508..720cdb7f8 100644 --- a/docs/src/solvers/NonlinearLeastSquaresSolvers.md +++ b/docs/src/solvers/NonlinearLeastSquaresSolvers.md @@ -23,6 +23,8 @@ falls back to a more robust algorithm (`LevenbergMarquardt`). handling of sparse matrices via colored automatic differentiation and preconditioned linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares problems. + - `TrustRegion()`: A Newton Trust Region dogleg method with swappable nonlinear solvers and + autodiff methods for high performance on large and sparse systems. ### SimpleNonlinearSolve.jl diff --git a/src/jacobian.jl b/src/jacobian.jl index 30e297dc8..a1870ffd1 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -213,30 +213,43 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) end # jvp fallback scalar -function __jacvec(uf, u; autodiff, kwargs...) - if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff) +function __gradient_operator(uf, u; autodiff, kwargs...) + if !(autodiff isa AutoFiniteDiff || autodiff isa AutoZygote) _ad = autodiff - autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(), + number_ad = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(), AutoFiniteDiff()) - @warn "$(_ad) not supported for JacVec. Using $(autodiff) instead." + if u isa Number + autodiff = number_ad + else + if isinplace(uf) + autodiff = AutoFiniteDiff() + else + autodiff = ifelse(is_extension_loaded(Val{:Zygote}()), AutoZygote(), + AutoFiniteDiff()) + end + end + if _ad !== nothing && _ad !== autodiff + @warn "$(_ad) not supported for VecJac. Using $(autodiff) instead." + end end - return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...) + return u isa Number ? GradientScalar(uf, u, autodiff) : + VecJac(uf, u; autodiff, kwargs...) end -@concrete mutable struct JVPScalar +@concrete mutable struct GradientScalar uf u autodiff end -function Base.:*(jvp::JVPScalar, v::Number) +function Base.:*(jvp::GradientScalar, v::Number) if jvp.autodiff isa AutoForwardDiff T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u))) - out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v)) + out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, one(v))) return ForwardDiff.extract_derivative(T, out) elseif jvp.autodiff isa AutoFiniteDiff J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype) - return J * v + return J else error("Only ForwardDiff & FiniteDiff is currently supported.") end diff --git a/src/trace.jl b/src/trace.jl index dcb7564f7..5a7c88342 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -60,13 +60,13 @@ end ## Arguments - - `freq`: Sets both `print_frequency` and `store_frequency` to `freq`. + - `freq`: Sets both `print_frequency` and `store_frequency` to `freq`. ## Keyword Arguments - - `print_frequency`: Print the trace every `print_frequency` iterations if + - `print_frequency`: Print the trace every `print_frequency` iterations if `show_trace == Val(true)`. - - `store_frequency`: Store the trace every `store_frequency` iterations if + - `store_frequency`: Store the trace every `store_frequency` iterations if `store_trace == Val(true)`. """ @kwdef struct TraceAll <: AbstractNonlinearSolveTraceLevel diff --git a/src/trustRegion.jl b/src/trustRegion.jl index e91bf8461..3312cbc63 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -247,13 +247,14 @@ end p3 p4 ϵ - jvp_operator # For Yuan + vjp_operator # For Yuan stats::NLStats tc_cache trace end -function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...; +function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, + NonlinearLeastSquaresProblem{uType, iip}}, alg_::TrustRegion, args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} @@ -317,7 +318,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, p3 = convert(floatType, 0.0) p4 = convert(floatType, 0.0) ϵ = convert(floatType, 1.0e-8) - jvp_operator = nothing + vjp_operator = nothing if radius_update_scheme === RadiusUpdateSchemes.NLsolve p1 = convert(floatType, 0.5) elseif radius_update_scheme === RadiusUpdateSchemes.Hei @@ -336,8 +337,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, p1 = convert(floatType, 2.0) # μ p2 = convert(floatType, 1 / 6) # c5 p3 = convert(floatType, 6.0) # c6 - jvp_operator = __jacvec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad)) - @bb Jᵀf = jvp_operator × fu + vjp_operator = __gradient_operator(uf, u; fu, + autodiff = __get_nonsparse_ad(alg.vjp_autodiff)) + @bb Jᵀf = vjp_operator × fu initial_trust_radius = convert(trustType, p1 * internalnorm(Jᵀf)) elseif radius_update_scheme === RadiusUpdateSchemes.Fan step_threshold = convert(trustType, 0.0001) @@ -366,7 +368,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold, shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new, - shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, jvp_operator, + shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, vjp_operator, NLStats(1, 0, 0, 0, 0), tc_cache, trace) end @@ -479,7 +481,7 @@ function trust_region_step!(cache::TrustRegionCache) cache.shrink_counter = 0 end - @bb cache.Jᵀf = cache.jvp_operator × vec(cache.fu) + @bb cache.Jᵀf = cache.vjp_operator × vec(cache.fu) cache.trust_r = cache.p1 * cache.internalnorm(cache.Jᵀf) cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true) @@ -567,10 +569,10 @@ end # FIXME: Reinit `JᵀJ` operator if `p` is changed function __reinit_internal!(cache::TrustRegionCache; kwargs...) - if cache.jvp_operator !== nothing - cache.jvp_operator = __jacvec(cache.uf, cache.u; cache.fu, + if cache.vjp_operator !== nothing + cache.vjp_operator = __gradient_operator(cache.uf, cache.u; cache.fu, autodiff = __get_nonsparse_ad(cache.alg.ad)) - @bb cache.Jᵀf = cache.jvp_operator × cache.fu + @bb cache.Jᵀf = cache.vjp_operator × cache.fu end cache.loss = __trust_region_loss(cache, cache.fu) cache.loss_new = cache.loss diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index f4b8233e0..e97da43eb 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -28,6 +28,7 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x) nlls_problems = [prob_oop, prob_iip] + solvers = [] for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES()] vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] : @@ -46,6 +47,11 @@ append!(solvers, LeastSquaresOptimJL(:dogleg), nothing, ]) +for radius_update_scheme in [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright, + RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan, + RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin] + push!(solvers, TrustRegion(; radius_update_scheme)) +end for prob in nlls_problems, solver in solvers @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)