Skip to content

Commit

Permalink
Use a different termination norm for NLLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 31, 2024
1 parent e231d64 commit 63bd4a9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.8.4"
version = "3.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -63,7 +63,7 @@ BandedMatrices = "1.4"
BenchmarkTools = "1.4"
ConcreteStructs = "0.2.3"
CUDA = "5.1"
DiffEqBase = "6.146.0"
DiffEqBase = "6.149.0"
Enzyme = "0.11.15"
FastBroadcast = "0.2.8"
FastClosures = "0.3"
Expand Down
2 changes: 1 addition & 1 deletion src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ function SciMLBase.__init(
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)

abstol, reltol, termination_cache = init_termination_cache(
abstol, reltol, fu, u, termination_condition)
prob, abstol, reltol, fu, u, termination_condition)
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

J = initialization_cache(nothing)
Expand Down
2 changes: 1 addition & 1 deletion src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function SciMLBase.__init(
linsolve = get_linear_solver(alg.descent)

abstol, reltol, termination_cache = init_termination_cache(
abstol, reltol, fu, u, termination_condition)
prob, abstol, reltol, fu, u, termination_condition)
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

jac_cache = JacobianCache(
Expand Down
2 changes: 1 addition & 1 deletion src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)

abstol, reltol, tc_cache = init_termination_cache(
abstol, reltol, fu, u_cache, termination_condition)
prob, abstol, reltol, fu, u_cache, termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)

if alg.σ_1 === nothing
Expand Down
24 changes: 19 additions & 5 deletions src/internal/termination.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
function init_termination_cache(abstol, reltol, du, u, ::Nothing)
return init_termination_cache(
abstol, reltol, du, u, AbsSafeBestTerminationMode(; max_stalled_steps = 32))
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(prob, abstol, reltol, du, u,
AbsSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32))
end
function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
tc_cache = init(du, u, tc; abstol, reltol, use_deprecated_retcodes = Val(false))
function init_termination_cache(
prob::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(prob, abstol, reltol, du, u,
AbsSafeBestTerminationMode(Base.Fix2(norm, 2); max_stalled_steps = 32))
end

function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
internalnorm = ifelse(
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
else
tc
end
tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false))
return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache
end

Expand Down

0 comments on commit 63bd4a9

Please sign in to comment.