Skip to content

Commit

Permalink
Propagate stats from MINPACK
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 17, 2023
1 parent 4602f34 commit 1530326
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.1.0"
version = "3.1.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
21 changes: 13 additions & 8 deletions ext/NonlinearSolveMINPACKExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
module NonlinearSolveMINPACKExt

using NonlinearSolve, SciMLBase
using NonlinearSolve, DiffEqBase, SciMLBase
using MINPACK

function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false,
termination_condition = nothing, kwargs...) where {uType, iip}
@assert termination_condition===nothing "CMINPACK does not support termination conditions!"
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"

if prob.u0 isa Number
u0 = [prob.u0]
Expand Down Expand Up @@ -57,22 +58,26 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
return Cint(0)
end
end
original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method,
iterations = maxiters, kwargs...)
original = MINPACK.fsolve(f!, g!, vec(u0), m; tol = abstol, show_trace, tracing,
method, iterations = maxiters, kwargs...)
else
original = MINPACK.fsolve(f!, u0, m; tol = abstol, show_trace, tracing, method,
iterations = maxiters, kwargs...)
original = MINPACK.fsolve(f!, vec(u0), m; tol = abstol, show_trace, tracing,
method, iterations = maxiters, kwargs...)
end

u = reshape(original.x, size(u))
resid = original.f
# retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure
# MINPACK lies about convergence? or maybe uses some other criteria?
# We just check for absolute tolerance on the residual
objective = NonlinearSolve.DEFAULT_NORM(resid)
objective = maximum(abs, resid)
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)

return SciMLBase.build_solution(prob, alg, u, resid; retcode, original)
# These are only meaningful if `tracing = true`
stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls,
original.trace.g_calls, original.trace.g_calls, -1)

return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
end

end
16 changes: 9 additions & 7 deletions ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import UnPack: @unpack

function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...)
@assert termination_condition===nothing "NLsolveJL does not support termination conditions!"
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"

if typeof(prob.u0) <: Number
u0 = [prob.u0]
Expand Down Expand Up @@ -59,19 +60,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
end
if prob.f.jac_prototype !== nothing
J = zero(prob.f.jac_prototype)
df = OnceDifferentiable(f!, g!, u0, resid, J)
df = OnceDifferentiable(f!, g!, vec(u0), vec(resid), J)

Check warning on line 63 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L63

Added line #L63 was not covered by tests
else
df = OnceDifferentiable(f!, g!, u0, resid)
df = OnceDifferentiable(f!, g!, vec(u0), vec(resid))
end
else
df = OnceDifferentiable(f!, u0, resid; autodiff)
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
end

original = nlsolve(df, u0; ftol = abstol, iterations = maxiters, method, store_trace,
extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace)
original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
show_trace)

u = reshape(original.zero, size(u0))
f!(resid, u)
f!(vec(resid), vec(u))
retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
ReturnCode.Failure
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
Expand Down

0 comments on commit 1530326

Please sign in to comment.