From 33fa361eb073698414dfef54729200d46bcd3248 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 25 Jan 2021 23:17:38 -0500 Subject: [PATCH] Extend SciMLBase --- Project.toml | 1 + src/NonlinearSolve.jl | 8 ++------ src/raphson.jl | 2 +- src/scalar.jl | 14 +++++++------- src/solve.jl | 16 ++++++++-------- src/types.jl | 18 ++---------------- 6 files changed, 21 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index b0d38619b..6aa20eadb 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index ed5d0978f..c9247053e 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,6 +8,8 @@ using Setfield using StaticArrays using RecursiveArrayTools +@reexport using SciMLBase + abstract type AbstractNonlinearProblem{uType,isinplace} end abstract type AbstractNonlinearSolveAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end @@ -27,10 +29,4 @@ include("scalar.jl") # DiffEq styled algorithms export Bisection, Falsi, NewtonRaphson -export NonlinearProblem - -export solve, init, solve! - -export reinit! - end # module diff --git a/src/raphson.jl b/src/raphson.jl index 27cb5a630..2fe548f1b 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -1,4 +1,4 @@ -struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD} +struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD} diff_type::DT linsolve::L end diff --git a/src/scalar.jl b/src/scalar.jl index 66ab6d4b5..9b08d1b20 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,4 +1,4 @@ -function solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) T = typeof(x) @@ -48,28 +48,28 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...) return sol, partials end -function solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} +function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode) end -function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} +function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode) end # avoid ambiguities for Alg in [Bisection] - @eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} + @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) end - @eval function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} + @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) end end -function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.u0 fl, fr = f(left), f(right) @@ -116,7 +116,7 @@ function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kw return BracketingSolution(left, right, MAXITERS_EXCEED) end -function solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.u0 fl, fr = f(left), f(right) diff --git a/src/solve.jl b/src/solve.jl index ad99e1f34..7a6f72459 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,12 +1,12 @@ -function solve(prob::NonlinearProblem, - alg::AbstractNonlinearSolveAlgorithm, args...; - kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, + alg::AbstractNonlinearSolveAlgorithm, args...; + kwargs...) solver = init(prob, alg, args...; kwargs...) sol = solve!(solver) return sol end -function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...; +function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...; alias_u0 = false, maxiters = 1000, kwargs... @@ -33,7 +33,7 @@ function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorit return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip) end -function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...; +function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...; alias_u0 = false, maxiters = 1000, tol = 1e-6, @@ -58,7 +58,7 @@ function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip) end -function solve!(solver::AbstractImmutableNonlinearSolver) +function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) solver = mic_check(solver) while !solver.force_stop && solver.iter < solver.maxiters solver = perform_step(solver, solver.alg, Val(solver.iip)) @@ -115,14 +115,14 @@ end Reinitialize solver to the original starting conditions """ -function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType} +function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType} @. solver.u = prob.u0 @set! solver.iter = 1 @set! solver.force_stop = false return solver end -function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType} +function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType} @set! solver.u = prob.u0 @set! solver.iter = 1 @set! solver.force_stop = false diff --git a/src/types.jl b/src/types.jl index 9f1b03ee9..e7d38956e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,17 +1,3 @@ -struct NullParameters end - -struct NonlinearProblem{uType,isinplace,P,F,K} <: AbstractNonlinearProblem{uType,isinplace} - f::F - u0::uType - p::P - kwargs::K - @add_kwonly function NonlinearProblem{iip}(f,u0,p=NullParameters();kwargs...) where iip - new{typeof(u0),iip,typeof(p),typeof(f),typeof(kwargs)}(f,u0,p,kwargs) - end -end - -NonlinearProblem(f,u0,args...;kwargs...) = NonlinearProblem{isinplace(f, 3)}(f,u0,args...;kwargs...) - @enum Retcode::Int begin DEFAULT EXACT_SOLUTION_LEFT @@ -37,7 +23,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp end # function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) -# BracketingImmutableSolver{iip, typeof(f), typeof(alg), +# BracketingImmutableSolver{iip, typeof(f), typeof(alg), # typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) # end @@ -58,7 +44,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT end # function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip -# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u), +# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u), # typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) # end