diff --git a/Project.toml b/Project.toml index c884503cc..8aacc9aff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.1.0" +version = "3.1.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/NonlinearSolveMINPACKExt.jl b/ext/NonlinearSolveMINPACKExt.jl index a205bdb0d..8d76e726e 100644 --- a/ext/NonlinearSolveMINPACKExt.jl +++ b/ext/NonlinearSolveMINPACKExt.jl @@ -1,13 +1,14 @@ module NonlinearSolveMINPACKExt -using NonlinearSolve, SciMLBase +using NonlinearSolve, DiffEqBase, SciMLBase using MINPACK function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...; abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) where {uType, iip} - @assert termination_condition===nothing "CMINPACK does not support termination conditions!" + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!" if prob.u0 isa Number u0 = [prob.u0] @@ -57,11 +58,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, return Cint(0) end end - original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method, - iterations = maxiters, kwargs...) + original = MINPACK.fsolve(f!, g!, _vec(u0), m; tol = abstol, show_trace, tracing, + method, iterations = maxiters, kwargs...) else - original = MINPACK.fsolve(f!, u0, m; tol = abstol, show_trace, tracing, method, - iterations = maxiters, kwargs...) + original = MINPACK.fsolve(f!, _vec(u0), m; tol = abstol, show_trace, tracing, + method, iterations = maxiters, kwargs...) end u = reshape(original.x, size(u)) @@ -69,10 +70,14 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, # retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure # MINPACK lies about convergence? or maybe uses some other criteria? # We just check for absolute tolerance on the residual - objective = NonlinearSolve.DEFAULT_NORM(resid) + objective = maximum(abs, resid) retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure) - return SciMLBase.build_solution(prob, alg, u, resid; retcode, original) + # These are only meaningful if `tracing = true` + stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls, + original.trace.g_calls, original.trace.g_calls, -1) + + return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original) end end diff --git a/ext/NonlinearSolveNLsolveExt.jl b/ext/NonlinearSolveNLsolveExt.jl index 3651c33dc..a216c01f7 100644 --- a/ext/NonlinearSolveNLsolveExt.jl +++ b/ext/NonlinearSolveNLsolveExt.jl @@ -5,7 +5,8 @@ import UnPack: @unpack function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6, maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) - @assert termination_condition===nothing "NLsolveJL does not support termination conditions!" + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!" if typeof(prob.u0) <: Number u0 = [prob.u0]