From a737b3eab7092c32e9eee7694b59e2c579123e83 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Thu, 20 Oct 2022 15:04:16 -0400 Subject: [PATCH 1/2] Change NonlinearSolve enums to SciMLBase enums --- src/bisection.jl | 2 +- src/falsi.jl | 4 ++-- src/scalar.jl | 22 +++++++++++----------- src/solve.jl | 14 +++++++------- src/types.jl | 12 ++---------- test/basictests.jl | 6 +++--- 6 files changed, 26 insertions(+), 34 deletions(-) diff --git a/src/bisection.jl b/src/bisection.jl index 8270ba9a3..ad3870760 100644 --- a/src/bisection.jl +++ b/src/bisection.jl @@ -35,7 +35,7 @@ function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache) if left == mid || right == mid @set! solver.force_stop = true - @set! solver.retcode = FLOATING_POINT_LIMIT + @set! solver.retcode = ReturnCode.Success return solver end diff --git a/src/falsi.jl b/src/falsi.jl index f184088cb..a9b9d80d5 100644 --- a/src/falsi.jl +++ b/src/falsi.jl @@ -21,7 +21,7 @@ function perform_step(solver, alg::Falsi, cache) if right == mid || right == mid @set! solver.force_stop = true - @set! solver.retcode = FLOATING_POINT_LIMIT + @set! solver.retcode = ReturnCode.Success return solver end @@ -32,7 +32,7 @@ function perform_step(solver, alg::Falsi, cache) @set! solver.force_stop = true @set! solver.left = mid @set! solver.fl = fm - @set! solver.retcode = EXACT_SOLUTION_LEFT + @set! solver.retcode = ReturnCode.Success else if sign(fm) == sign(fl) @set! solver.left = mid diff --git a/src/scalar.jl b/src/scalar.jl index b81d259bf..afd2b2874 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -26,15 +26,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}}, fx) end iszero(fx) && - return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT)) + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default) Δx = dfx \ fx x -= Δx if isapprox(x, xo, atol = atol, rtol = rtol) - return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT)) + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default) end xo = x end - return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(MAXITERS_EXCEED)) + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) end function scalar_nlsolve_ad(prob, alg, args...; kwargs...) @@ -109,7 +109,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite if iszero(fl) return SciMLBase.build_solution(prob, alg, left, fl; - retcode = Symbol(EXACT_SOLUTION_LEFT), left = left, + retcode = ReturnCode.Success, left = left, right = right) end @@ -119,7 +119,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite mid = (left + right) / 2 (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; - retcode = Symbol(FLOATING_POINT_LIMIT), + retcode = ReturnCode.Success, left = left, right = right) fm = f(mid) if iszero(fm) @@ -141,7 +141,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite mid = (left + right) / 2 (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; - retcode = Symbol(FLOATING_POINT_LIMIT), + retcode = ReturnCode.Success, left = left, right = right) fm = f(mid) if iszero(fm) @@ -154,7 +154,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite i += 1 end - return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED), + return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, left = left, right = right) end @@ -166,7 +166,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = if iszero(fl) return SciMLBase.build_solution(prob, alg, left, fl; - retcode = Symbol(EXACT_SOLUTION_LEFT), left = left, + retcode = ReturnCode.Success, left = left, right = right) end @@ -175,7 +175,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = while i < maxiters if nextfloat_tdir(left, prob.u0...) == right return SciMLBase.build_solution(prob, alg, left, fl; - retcode = Symbol(FLOATING_POINT_LIMIT), + retcode = ReturnCode.Success, left = left, right = right) end mid = (fr * left - fl * right) / (fr - fl) @@ -205,7 +205,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = mid = (left + right) / 2 (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; - retcode = Symbol(FLOATING_POINT_LIMIT), + retcode = ReturnCode.Success, left = left, right = right) fm = f(mid) if iszero(fm) @@ -221,6 +221,6 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = i += 1 end - return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED), + return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, left = left, right = right) end diff --git a/src/solve.jl b/src/solve.jl index 702cdd87f..30e1f14ea 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -29,7 +29,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, fr = f(right, p) cache = alg_cache(alg, left, right, p, Val(iip)) return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, - DEFAULT, cache, iip, prob) + ReturnCode.Default, cache, iip, prob) end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewto end cache = alg_cache(alg, f, u, p, Val(iip)) return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, - DEFAULT, tol, cache, iip, prob) + Retcode.Default, tol, cache, iip, prob) end function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @@ -64,14 +64,14 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @set! solver.iter += 1 end if solver.iter == solver.maxiters - @set! solver.retcode = MAXITERS_EXCEED + @set! solver.retcode = ReturnCode.MaxIters end if typeof(solver) <: NewtonImmutableSolver SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu; - retcode = Symbol(solver.retcode)) + retcode = solver.retcode) else SciMLBase.build_solution(solver.prob, solver.alg, solver.left, solver.fl; - retcode = Symbol(solver.retcode), left = solver.left, + retcode = solver.retcode, left = solver.left, right = solver.right) end end @@ -89,10 +89,10 @@ function mic_check(solver::BracketingImmutableSolver) (flr > fzero) && error("Non bracketing interval passed in bracketing method.") if fl == fzero @set! solver.force_stop = true - @set! solver.retcode = EXACT_SOLUTION_LEFT + @set! solver.retcode = Retcode.Success elseif fr == fzero @set! solver.force_stop = true - @set! solver.retcode = EXACT_SOLUTION_RIGHT + @set! solver.retcode = Retcode.Success end solver end diff --git a/src/types.jl b/src/types.jl index c272e09f7..99ecadba2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,11 +1,3 @@ -@enum Retcode::Int begin - DEFAULT - EXACT_SOLUTION_LEFT - EXACT_SOLUTION_RIGHT - MAXITERS_EXCEED - FLOATING_POINT_LIMIT -end - struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType } <: AbstractImmutableNonlinearSolver iter::Int @@ -18,7 +10,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp p::pType force_stop::Bool maxiters::Int - retcode::Retcode + retcode::SciMLBase.ReturnCode.T cache::cacheType iip::Bool prob::probType @@ -40,7 +32,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT force_stop::Bool maxiters::Int internalnorm::INType - retcode::Retcode + retcode::SciMLBase.ReturnCode.T tol::tolType cache::cacheType iip::Bool diff --git a/test/basictests.jl b/test/basictests.jl index e37293b62..98ee8627b 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -30,13 +30,13 @@ end const csu0 = 1.0 sol = benchmark_immutable(ff, cu0) -@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) +@test sol.retcode === ReturnCode.Default @test all(sol.u .* sol.u .- 2 .< 1e-9) sol = benchmark_mutable(ff, cu0) -@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) +@test sol.retcode === ReturnCode.Default @test all(sol.u .* sol.u .- 2 .< 1e-9) sol = benchmark_scalar(sf, csu0) -@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) +@test sol.retcode === ReturnCode.Default @test sol.u * sol.u - 2 < 1e-9 @test (@ballocated benchmark_immutable(ff, cu0)) == 0 From 2e26ff2f8874494c5144e5b805a6e588bd3a7f8b Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Mon, 24 Oct 2022 14:47:20 -0400 Subject: [PATCH 2/2] Change ReturnCodes from SciMLBase --- src/bisection.jl | 2 +- src/falsi.jl | 4 ++-- src/scalar.jl | 12 ++++++------ src/solve.jl | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/bisection.jl b/src/bisection.jl index ad3870760..7268da9f1 100644 --- a/src/bisection.jl +++ b/src/bisection.jl @@ -35,7 +35,7 @@ function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache) if left == mid || right == mid @set! solver.force_stop = true - @set! solver.retcode = ReturnCode.Success + @set! solver.retcode = ReturnCode.FloatingPointLimit return solver end diff --git a/src/falsi.jl b/src/falsi.jl index a9b9d80d5..f3420cc63 100644 --- a/src/falsi.jl +++ b/src/falsi.jl @@ -21,7 +21,7 @@ function perform_step(solver, alg::Falsi, cache) if right == mid || right == mid @set! solver.force_stop = true - @set! solver.retcode = ReturnCode.Success + @set! solver.retcode = ReturnCode.FloatingPointLimit return solver end @@ -32,7 +32,7 @@ function perform_step(solver, alg::Falsi, cache) @set! solver.force_stop = true @set! solver.left = mid @set! solver.fl = fm - @set! solver.retcode = ReturnCode.Success + @set! solver.retcode = ReturnCode.ExactSolutionLeft else if sign(fm) == sign(fl) @set! solver.left = mid diff --git a/src/scalar.jl b/src/scalar.jl index afd2b2874..07f2c392b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -109,7 +109,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite if iszero(fl) return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.Success, left = left, + retcode = ReturnCode.ExactSolutionLeft, left = left, right = right) end @@ -119,7 +119,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite mid = (left + right) / 2 (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.Success, + retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) if iszero(fm) @@ -141,7 +141,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite mid = (left + right) / 2 (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.Success, + retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) if iszero(fm) @@ -166,7 +166,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = if iszero(fl) return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.Success, left = left, + retcode = ReturnCode.ExactSolutionLeft, left = left, right = right) end @@ -175,7 +175,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = while i < maxiters if nextfloat_tdir(left, prob.u0...) == right return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.Success, + retcode = ReturnCode.FloatingPointLimit, left = left, right = right) end mid = (fr * left - fl * right) / (fr - fl) @@ -205,7 +205,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = mid = (left + right) / 2 (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.Success, + retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) if iszero(fm) diff --git a/src/solve.jl b/src/solve.jl index 30e1f14ea..930ce9c7e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -89,10 +89,10 @@ function mic_check(solver::BracketingImmutableSolver) (flr > fzero) && error("Non bracketing interval passed in bracketing method.") if fl == fzero @set! solver.force_stop = true - @set! solver.retcode = Retcode.Success + @set! solver.retcode = Retcode.ExactSolutionLeft elseif fr == fzero @set! solver.force_stop = true - @set! solver.retcode = Retcode.Success + @set! solver.retcode = Retcode.ExactionSolutionRight end solver end