From 4562e108fbc72215fa3886fb1607e0ccf0fa82ed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 31 Mar 2024 00:02:38 -0400 Subject: [PATCH] Use a different termination norm for NLLS --- Project.toml | 4 ++-- src/core/approximate_jacobian.jl | 2 +- src/core/generalized_first_order.jl | 2 +- src/core/spectral_methods.jl | 2 +- src/internal/termination.jl | 24 +++++++++++++++++++----- 5 files changed, 24 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 31d3b3e54..9411cc234 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.8.4" +version = "3.9.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -63,7 +63,7 @@ BandedMatrices = "1.4" BenchmarkTools = "1.4" ConcreteStructs = "0.2.3" CUDA = "5.1" -DiffEqBase = "6.146.0" +DiffEqBase = "6.149.0" Enzyme = "0.11.15" FastBroadcast = "0.2.8" FastClosures = "0.3" diff --git a/src/core/approximate_jacobian.jl b/src/core/approximate_jacobian.jl index 4204b2db7..afdd445bb 100644 --- a/src/core/approximate_jacobian.jl +++ b/src/core/approximate_jacobian.jl @@ -167,7 +167,7 @@ function SciMLBase.__init( prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm) abstol, reltol, termination_cache = init_termination_cache( - abstol, reltol, fu, u, termination_condition) + prob, abstol, reltol, fu, u, termination_condition) linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs) J = initialization_cache(nothing) diff --git a/src/core/generalized_first_order.jl b/src/core/generalized_first_order.jl index 753b7682f..883609006 100644 --- a/src/core/generalized_first_order.jl +++ b/src/core/generalized_first_order.jl @@ -156,7 +156,7 @@ function SciMLBase.__init( linsolve = get_linear_solver(alg.descent) abstol, reltol, termination_cache = init_termination_cache( - abstol, reltol, fu, u, termination_condition) + prob, abstol, reltol, fu, u, termination_condition) linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs) jac_cache = JacobianCache( diff --git a/src/core/spectral_methods.jl b/src/core/spectral_methods.jl index a4e824744..a2966c765 100644 --- a/src/core/spectral_methods.jl +++ b/src/core/spectral_methods.jl @@ -133,7 +133,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...) abstol, reltol, tc_cache = init_termination_cache( - abstol, reltol, fu, u_cache, termination_condition) + prob, abstol, reltol, fu, u_cache, termination_condition) trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...) if alg.σ_1 === nothing diff --git a/src/internal/termination.jl b/src/internal/termination.jl index e55e344f7..4f17cd593 100644 --- a/src/internal/termination.jl +++ b/src/internal/termination.jl @@ -1,9 +1,23 @@ -function init_termination_cache(abstol, reltol, du, u, ::Nothing) - return init_termination_cache( - abstol, reltol, du, u, AbsSafeBestTerminationMode(; max_stalled_steps = 32)) +function init_termination_cache(::NonlinearProblem, abstol, reltol, du, u, ::Nothing) + return init_termination_cache(prob, abstol, reltol, du, u, + AbsSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32)) end -function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) - tc_cache = init(du, u, tc; abstol, reltol, use_deprecated_retcodes = Val(false)) +function init_termination_cache( + ::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing) + return init_termination_cache(prob, abstol, reltol, du, u, + AbsSafeBestTerminationMode(Base.Fix2(norm, 2); max_stalled_steps = 32)) +end + +function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) + tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing + internalnorm = ifelse( + prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2)) + DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm) + else + tc + end + tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false)) return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache end