From 1530326c68a971ca3d967eece12d8e6e8a617f7b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 17 Dec 2023 13:19:46 -0500 Subject: [PATCH] Propagate stats from MINPACK --- Project.toml | 2 +- ext/NonlinearSolveMINPACKExt.jl | 21 +++++++++++++-------- ext/NonlinearSolveNLsolveExt.jl | 16 +++++++++------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index c884503cc..8aacc9aff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.1.0" +version = "3.1.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/NonlinearSolveMINPACKExt.jl b/ext/NonlinearSolveMINPACKExt.jl index a205bdb0d..b86d78199 100644 --- a/ext/NonlinearSolveMINPACKExt.jl +++ b/ext/NonlinearSolveMINPACKExt.jl @@ -1,13 +1,14 @@ module NonlinearSolveMINPACKExt -using NonlinearSolve, SciMLBase +using NonlinearSolve, DiffEqBase, SciMLBase using MINPACK function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...; abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) where {uType, iip} - @assert termination_condition===nothing "CMINPACK does not support termination conditions!" + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!" if prob.u0 isa Number u0 = [prob.u0] @@ -57,11 +58,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, return Cint(0) end end - original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method, - iterations = maxiters, kwargs...) + original = MINPACK.fsolve(f!, g!, vec(u0), m; tol = abstol, show_trace, tracing, + method, iterations = maxiters, kwargs...) else - original = MINPACK.fsolve(f!, u0, m; tol = abstol, show_trace, tracing, method, - iterations = maxiters, kwargs...) + original = MINPACK.fsolve(f!, vec(u0), m; tol = abstol, show_trace, tracing, + method, iterations = maxiters, kwargs...) end u = reshape(original.x, size(u)) @@ -69,10 +70,14 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, # retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure # MINPACK lies about convergence? or maybe uses some other criteria? # We just check for absolute tolerance on the residual - objective = NonlinearSolve.DEFAULT_NORM(resid) + objective = maximum(abs, resid) retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure) - return SciMLBase.build_solution(prob, alg, u, resid; retcode, original) + # These are only meaningful if `tracing = true` + stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls, + original.trace.g_calls, original.trace.g_calls, -1) + + return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original) end end diff --git a/ext/NonlinearSolveNLsolveExt.jl b/ext/NonlinearSolveNLsolveExt.jl index 3651c33dc..1b8d7e3f1 100644 --- a/ext/NonlinearSolveNLsolveExt.jl +++ b/ext/NonlinearSolveNLsolveExt.jl @@ -5,7 +5,8 @@ import UnPack: @unpack function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6, maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) - @assert termination_condition===nothing "NLsolveJL does not support termination conditions!" + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!" if typeof(prob.u0) <: Number u0 = [prob.u0] @@ -59,19 +60,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst end if prob.f.jac_prototype !== nothing J = zero(prob.f.jac_prototype) - df = OnceDifferentiable(f!, g!, u0, resid, J) + df = OnceDifferentiable(f!, g!, vec(u0), vec(resid), J) else - df = OnceDifferentiable(f!, g!, u0, resid) + df = OnceDifferentiable(f!, g!, vec(u0), vec(resid)) end else - df = OnceDifferentiable(f!, u0, resid; autodiff) + df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff) end - original = nlsolve(df, u0; ftol = abstol, iterations = maxiters, method, store_trace, - extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) + original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method, + store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, + show_trace) u = reshape(original.zero, size(u0)) - f!(resid, u) + f!(vec(resid), vec(u)) retcode = original.x_converged || original.f_converged ? ReturnCode.Success : ReturnCode.Failure stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,