Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SciML.build_solution #28

Merged
merged 9 commits into from
Feb 3, 2021
41 changes: 23 additions & 18 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = float(prob.u0)
T = typeof(x)
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
Expand All @@ -13,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a
fx = f(x)
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
end
iszero(fx) && return NewtonSolution(x, DEFAULT)
iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
Δx = dfx \ fx
x -= Δx
if isapprox(x, xo, atol=atol, rtol=rtol)
return NewtonSolution(x, DEFAULT)
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
end
xo = x
end
return NewtonSolution(x, MAXITERS_EXCEED)
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED))
end

function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
Expand All @@ -32,7 +33,7 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
sol = solve(newprob, alg, args...; kwargs...)

uu = getsolution(sol)
uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
Expand All @@ -50,39 +51,43 @@ end

function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
#return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
#return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid)
end

# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
end

function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)

if iszero(fl)
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
end

i = 1
if !iszero(fr)
while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -101,7 +106,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters

while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -113,23 +118,23 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
i += 1
end

return BracketingSolution(left, right, MAXITERS_EXCEED)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
end

function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)

if iszero(fl)
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
end

i = 1
if !iszero(fr)
while i < maxiters
if nextfloat_tdir(left, prob.u0...) == right
return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
end
mid = (fr * left - fl * right) / (fr - fl)
for i in 1:10
Expand All @@ -156,7 +161,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10

while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -171,5 +176,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10
i += 1
end

return BracketingSolution(left, right, MAXITERS_EXCEED)
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
end
15 changes: 9 additions & 6 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ function SciMLBase.solve(prob::NonlinearProblem,
kwargs...)
solver = init(prob, alg, args...; kwargs...)
sol = solve!(solver)
return sol
end

function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
Expand All @@ -30,7 +29,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracket
fl = f(left, p)
fr = f(right, p)
cache = alg_cache(alg, left, right,p, Val(iip))
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip,prob)
end

function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
Expand All @@ -55,7 +54,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonA
fu = f(u, p)
end
cache = alg_cache(alg, f, u, p, Val(iip))
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip, prob)
end

function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
Expand All @@ -68,7 +67,11 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
@set! solver.retcode = MAXITERS_EXCEED
end
sol = get_solution(solver)
return sol
if typeof(sol) <: NewtonSolution
SciMLBase.build_solution(solver.prob, solver.alg, getsolution(sol), sol.resid;retcode=Symbol(sol.retcode))
else
SciMLBase.build_solution(solver.prob, solver.alg, getsolution(sol),sol.resid;retcode=Symbol(sol.retcode),left = sol.left,right = sol.right)
end
end

"""
Expand Down Expand Up @@ -103,11 +106,11 @@ end
Form solution object from solver types
"""
function get_solution(solver::BracketingImmutableSolver)
return BracketingSolution(solver.left, solver.right, solver.retcode)
return BracketingSolution(solver.left, solver.right, solver.retcode, solver.fl)
end

function get_solution(solver::NewtonImmutableSolver)
return NewtonSolution(solver.u, solver.retcode)
return NewtonSolution(solver.u, solver.retcode, solver.fu)
end

"""
Expand Down
12 changes: 8 additions & 4 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
FLOATING_POINT_LIMIT
end

struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType} <: AbstractImmutableNonlinearSolver
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType} <: AbstractImmutableNonlinearSolver
iter::Int
f::fType
alg::algType
Expand All @@ -20,14 +20,15 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
retcode::Retcode
cache::cacheType
iip::Bool
prob::probType
end

# function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
# typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
# end

struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType} <: AbstractImmutableNonlinearSolver
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType, probType} <: AbstractImmutableNonlinearSolver
iter::Int
f::fType
alg::algType
Expand All @@ -41,22 +42,25 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
tol::tolType
cache::cacheType
iip::Bool
prob::probType
end

# function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
# typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache)
# end

struct BracketingSolution{uType}
struct BracketingSolution{uType,resType}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we still using the original solution type?

Copy link
Member Author

@utkarsh530 utkarsh530 Feb 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, we don't need it. Refactored the code. Thanks for the suggestion.

left::uType
right::uType
retcode::Retcode
resid::resType
end

struct NewtonSolution{uType}
struct NewtonSolution{uType,resType}
u::uType
retcode::Retcode
resid::resType
end

function sync_residuals!(solver::BracketingImmutableSolver)
Expand Down
15 changes: 0 additions & 15 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,6 @@ function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

function numargs(f)
typ = Tuple{Any, Val{:analytic}, Vararg}
typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types
typ3 = Tuple{Any, Val{:jac}, Vararg}
typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types
typ5 = Tuple{Any, Val{:tgrad}, Vararg}
typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types
numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)])
return (numparam-1) #-1 in v0.5 since it adds f as the first parameter
end

function isinplace(f,inplace_param_number)
numargs(f)>=inplace_param_number
end

### Default Linsolve

# Try to be as smart as possible
Expand Down
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ end
f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0]
sf, su0 = (u,p) -> u * u - 2, 1.0
sol = benchmark_immutable(f, u0)
@test sol.retcode === NonlinearSolve.DEFAULT
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_mutable(f, u0)
@test sol.retcode === NonlinearSolve.DEFAULT
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_scalar(sf, su0)
@test sol.retcode === NonlinearSolve.DEFAULT
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_immutable($f, $u0)) == 0
Expand Down Expand Up @@ -117,6 +117,7 @@ probN = NonlinearProblem(f, u0)
@test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u[end] ≈ sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
probN = NonlinearProblem(f, u0)
sol = sqrt(2) * u0
Expand Down