Skip to content

Commit

Permalink
Merge pull request #25 from JuliaComputing/scimlbase
Browse files Browse the repository at this point in the history
Extend SciMLBase
  • Loading branch information
ChrisRackauckas authored Jan 27, 2021
2 parents 3398cc4 + 33fa361 commit 71162f9
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 38 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand Down
8 changes: 2 additions & 6 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ using Setfield
using StaticArrays
using RecursiveArrayTools

@reexport using SciMLBase

abstract type AbstractNonlinearProblem{uType,isinplace} end
abstract type AbstractNonlinearSolveAlgorithm end
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
Expand All @@ -27,10 +29,4 @@ include("scalar.jl")
# DiffEq styled algorithms
export Bisection, Falsi, NewtonRaphson

export NonlinearProblem

export solve, init, solve!

export reinit!

end # module
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD}
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD}
diff_type::DT
linsolve::L
end
Expand Down
14 changes: 7 additions & 7 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
T = typeof(x)
Expand Down Expand Up @@ -48,28 +48,28 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
end
function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
end

# avoid ambiguities
for Alg in [Bisection]
@eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
end
@eval function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
end
end

function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)
Expand Down Expand Up @@ -116,7 +116,7 @@ function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kw
return BracketingSolution(left, right, MAXITERS_EXCEED)
end

function solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)
Expand Down
16 changes: 8 additions & 8 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function solve(prob::NonlinearProblem,
alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...)
function SciMLBase.solve(prob::NonlinearProblem,
alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...)
solver = init(prob, alg, args...; kwargs...)
sol = solve!(solver)
return sol
end

function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
alias_u0 = false,
maxiters = 1000,
kwargs...
Expand All @@ -33,7 +33,7 @@ function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorit
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
end

function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
alias_u0 = false,
maxiters = 1000,
tol = 1e-6,
Expand All @@ -58,7 +58,7 @@ function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm,
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
end

function solve!(solver::AbstractImmutableNonlinearSolver)
function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
solver = mic_check(solver)
while !solver.force_stop && solver.iter < solver.maxiters
solver = perform_step(solver, solver.alg, Val(solver.iip))
Expand Down Expand Up @@ -115,14 +115,14 @@ end
Reinitialize solver to the original starting conditions
"""
function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType}
function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType}
@. solver.u = prob.u0
@set! solver.iter = 1
@set! solver.force_stop = false
return solver
end

function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType}
function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType}
@set! solver.u = prob.u0
@set! solver.iter = 1
@set! solver.force_stop = false
Expand Down
18 changes: 2 additions & 16 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
struct NullParameters end

struct NonlinearProblem{uType,isinplace,P,F,K} <: AbstractNonlinearProblem{uType,isinplace}
f::F
u0::uType
p::P
kwargs::K
@add_kwonly function NonlinearProblem{iip}(f,u0,p=NullParameters();kwargs...) where iip
new{typeof(u0),iip,typeof(p),typeof(f),typeof(kwargs)}(f,u0,p,kwargs)
end
end

NonlinearProblem(f,u0,args...;kwargs...) = NonlinearProblem{isinplace(f, 3)}(f,u0,args...;kwargs...)

@enum Retcode::Int begin
DEFAULT
EXACT_SOLUTION_LEFT
Expand All @@ -37,7 +23,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
end

# function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
# typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
# end

Expand All @@ -58,7 +44,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
end

# function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
# typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache)
# end

Expand Down

0 comments on commit 71162f9

Please sign in to comment.