Skip to content

Commit

Permalink
Use dispatch instead of adding Scalar* solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Sep 21, 2020
1 parent b8e36e7 commit dc869a9
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 20 deletions.
1 change: 0 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ module NonlinearSolve

# DiffEq styled algorithms
export Bisection, Falsi, NewtonRaphson
export ScalarBisection, ScalarNewton

export reinit!
end # module
19 changes: 2 additions & 17 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
"""
ScalarNewton
Fast Newton Raphson for scalar problems.
"""
struct ScalarNewton <: AbstractNonlinearSolveAlgorithm end

function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarNewton, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) where {uType}
function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
T = typeof(x)
Expand All @@ -26,15 +19,7 @@ function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarNewton,
return oftype(x, NaN)
end

"""
ScalarBisection
Fast Bisection for scalar problems. Note that it doesn't returns exact solution, but returns
the best left limit of the exact solution.
"""
struct ScalarBisection <: AbstractNonlinearSolveAlgorithm end

function DiffEqBase.solve(prob::NonlinearProblem{uType, false}, ::ScalarBisection, args...; maxiters = 1000, kwargs...) where {uType}
function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::Bisection, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
immutable = (prob.u0 isa StaticArray || prob.u0 isa Number),
kwargs...
) where {uType, iip}

if !(prob.u0 isa Tuple)
error("You need to pass a tuple of u0 in bracketing algorithms.")
end
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

function benchmark_scalar(f, u0)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, ScalarNewton()))
sol = (solve(probN, NewtonRaphson()))
end

f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0]
Expand Down

0 comments on commit dc869a9

Please sign in to comment.