Skip to content

Commit

Permalink
Merge pull request #92 from utkarsh530/u/nlgpufix
Browse files Browse the repository at this point in the history
Change NonlinearSolve enums to SciMLBase enums
  • Loading branch information
ChrisRackauckas authored Oct 25, 2022
2 parents 58f23c2 + 2e26ff2 commit 54c7da2
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache)

if left == mid || right == mid
@set! solver.force_stop = true
@set! solver.retcode = FLOATING_POINT_LIMIT
@set! solver.retcode = ReturnCode.FloatingPointLimit
return solver
end

Expand Down
4 changes: 2 additions & 2 deletions src/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function perform_step(solver, alg::Falsi, cache)

if right == mid || right == mid
@set! solver.force_stop = true
@set! solver.retcode = FLOATING_POINT_LIMIT
@set! solver.retcode = ReturnCode.FloatingPointLimit
return solver
end

Expand All @@ -32,7 +32,7 @@ function perform_step(solver, alg::Falsi, cache)
@set! solver.force_stop = true
@set! solver.left = mid
@set! solver.fl = fm
@set! solver.retcode = EXACT_SOLUTION_LEFT
@set! solver.retcode = ReturnCode.ExactSolutionLeft
else
if sign(fm) == sign(fl)
@set! solver.left = mid
Expand Down
22 changes: 11 additions & 11 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}},
fx)
end
iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT))
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default)
Δx = dfx \ fx
x -= Δx
if isapprox(x, xo, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT))
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default)
end
xo = x
end
return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(MAXITERS_EXCEED))
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
Expand Down Expand Up @@ -109,7 +109,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite

if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = Symbol(EXACT_SOLUTION_LEFT), left = left,
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
end

Expand All @@ -119,7 +119,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = Symbol(FLOATING_POINT_LIMIT),
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
Expand All @@ -141,7 +141,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = Symbol(FLOATING_POINT_LIMIT),
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
Expand All @@ -154,7 +154,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
i += 1
end

return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED),
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
left = left, right = right)
end

Expand All @@ -166,7 +166,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =

if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = Symbol(EXACT_SOLUTION_LEFT), left = left,
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
end

Expand All @@ -175,7 +175,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
while i < maxiters
if nextfloat_tdir(left, prob.u0...) == right
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = Symbol(FLOATING_POINT_LIMIT),
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
end
mid = (fr * left - fl * right) / (fr - fl)
Expand Down Expand Up @@ -205,7 +205,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = Symbol(FLOATING_POINT_LIMIT),
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
Expand All @@ -221,6 +221,6 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
i += 1
end

return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED),
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
left = left, right = right)
end
14 changes: 7 additions & 7 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip},
fr = f(right, p)
cache = alg_cache(alg, left, right, p, Val(iip))
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters,
DEFAULT, cache, iip, prob)
ReturnCode.Default, cache, iip, prob)
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm,
Expand All @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewto
end
cache = alg_cache(alg, f, u, p, Val(iip))
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm,
DEFAULT, tol, cache, iip, prob)
Retcode.Default, tol, cache, iip, prob)
end

function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
Expand All @@ -64,14 +64,14 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
@set! solver.iter += 1
end
if solver.iter == solver.maxiters
@set! solver.retcode = MAXITERS_EXCEED
@set! solver.retcode = ReturnCode.MaxIters
end
if typeof(solver) <: NewtonImmutableSolver
SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;
retcode = Symbol(solver.retcode))
retcode = solver.retcode)
else
SciMLBase.build_solution(solver.prob, solver.alg, solver.left, solver.fl;
retcode = Symbol(solver.retcode), left = solver.left,
retcode = solver.retcode, left = solver.left,
right = solver.right)
end
end
Expand All @@ -89,10 +89,10 @@ function mic_check(solver::BracketingImmutableSolver)
(flr > fzero) && error("Non bracketing interval passed in bracketing method.")
if fl == fzero
@set! solver.force_stop = true
@set! solver.retcode = EXACT_SOLUTION_LEFT
@set! solver.retcode = Retcode.ExactSolutionLeft
elseif fr == fzero
@set! solver.force_stop = true
@set! solver.retcode = EXACT_SOLUTION_RIGHT
@set! solver.retcode = Retcode.ExactionSolutionRight
end
solver
end
Expand Down
12 changes: 2 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
@enum Retcode::Int begin
DEFAULT
EXACT_SOLUTION_LEFT
EXACT_SOLUTION_RIGHT
MAXITERS_EXCEED
FLOATING_POINT_LIMIT
end

struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType
} <: AbstractImmutableNonlinearSolver
iter::Int
Expand All @@ -18,7 +10,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
p::pType
force_stop::Bool
maxiters::Int
retcode::Retcode
retcode::SciMLBase.ReturnCode.T
cache::cacheType
iip::Bool
prob::probType
Expand All @@ -40,7 +32,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
force_stop::Bool
maxiters::Int
internalnorm::INType
retcode::Retcode
retcode::SciMLBase.ReturnCode.T
tol::tolType
cache::cacheType
iip::Bool
Expand Down
6 changes: 3 additions & 3 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ end
const csu0 = 1.0

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test sol.retcode === ReturnCode.Default
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_mutable(ff, cu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test sol.retcode === ReturnCode.Default
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test sol.retcode === ReturnCode.Default
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_immutable(ff, cu0)) == 0
Expand Down

0 comments on commit 54c7da2

Please sign in to comment.