diff --git a/src/scalar.jl b/src/scalar.jl index 9b08d1b20..5b3c25966 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,6 +1,7 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) + fx = float(prob.u0) T = typeof(x) atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5) rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5) @@ -13,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a fx = f(x) dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx) end - iszero(fx) && return NewtonSolution(x, DEFAULT) + iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT)) Δx = dfx \ fx x -= Δx if isapprox(x, xo, atol=atol, rtol=rtol) - return NewtonSolution(x, DEFAULT) + return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT)) end xo = x end - return NewtonSolution(x, MAXITERS_EXCEED) + return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED)) end function scalar_nlsolve_ad(prob, alg, args...; kwargs...) @@ -32,7 +33,7 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...) newprob = NonlinearProblem(f, u0, p; prob.kwargs...) sol = solve(newprob, alg, args...; kwargs...) - uu = getsolution(sol) + uu = sol.u if p isa Number f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) else @@ -50,39 +51,42 @@ end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) + end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) end # avoid ambiguities for Alg in [Bisection] @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials)) + #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) end @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials)) + #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) end end -function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.u0 fl, fr = f(left), f(right) if iszero(fl) - return BracketingSolution(left, right, EXACT_SOLUTION_LEFT) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right) end i = 1 if !iszero(fr) while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) fm = f(mid) if iszero(fm) right = mid @@ -101,7 +105,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) fm = f(mid) if iszero(fm) right = mid @@ -113,23 +117,23 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters i += 1 end - return BracketingSolution(left, right, MAXITERS_EXCEED) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right) end -function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.u0 fl, fr = f(left), f(right) if iszero(fl) - return BracketingSolution(left, right, EXACT_SOLUTION_LEFT) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right) end i = 1 if !iszero(fr) while i < maxiters if nextfloat_tdir(left, prob.u0...) == right - return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) end mid = (fr * left - fl * right) / (fr - fl) for i in 1:10 @@ -156,7 +160,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) fm = f(mid) if iszero(fm) right = mid @@ -171,5 +175,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 i += 1 end - return BracketingSolution(left, right, MAXITERS_EXCEED) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right) end diff --git a/src/solve.jl b/src/solve.jl index 7a6f72459..5235056bf 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -3,7 +3,6 @@ function SciMLBase.solve(prob::NonlinearProblem, kwargs...) solver = init(prob, alg, args...; kwargs...) sol = solve!(solver) - return sol end function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...; @@ -30,7 +29,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracket fl = f(left, p) fr = f(right, p) cache = alg_cache(alg, left, right,p, Val(iip)) - return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip) + return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip,prob) end function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...; @@ -55,7 +54,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonA fu = f(u, p) end cache = alg_cache(alg, f, u, p, Val(iip)) - return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip) + return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip, prob) end function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @@ -67,8 +66,11 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) if solver.iter == solver.maxiters @set! solver.retcode = MAXITERS_EXCEED end - sol = get_solution(solver) - return sol + if typeof(solver) <: NewtonImmutableSolver + SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;retcode=Symbol(solver.retcode)) + else + SciMLBase.build_solution(solver.prob, solver.alg, solver.left,solver.fl;retcode=Symbol(solver.retcode),left = solver.left,right = solver.right) + end end """ @@ -96,20 +98,6 @@ function mic_check(solver::NewtonImmutableSolver) solver end -""" - get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver}) - get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver}) - -Form solution object from solver types -""" -function get_solution(solver::BracketingImmutableSolver) - return BracketingSolution(solver.left, solver.right, solver.retcode) -end - -function get_solution(solver::NewtonImmutableSolver) - return NewtonSolution(solver.u, solver.retcode) -end - """ reinit!(solver, prob) diff --git a/src/types.jl b/src/types.jl index e7d38956e..9e673188a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -6,7 +6,7 @@ FLOATING_POINT_LIMIT end -struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType} <: AbstractImmutableNonlinearSolver +struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType} <: AbstractImmutableNonlinearSolver iter::Int f::fType alg::algType @@ -20,6 +20,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp retcode::Retcode cache::cacheType iip::Bool + prob::probType end # function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) @@ -27,7 +28,7 @@ end # typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) # end -struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType} <: AbstractImmutableNonlinearSolver +struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType, probType} <: AbstractImmutableNonlinearSolver iter::Int f::fType alg::algType @@ -41,6 +42,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT tol::tolType cache::cacheType iip::Bool + prob::probType end # function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip @@ -48,22 +50,9 @@ end # typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) # end -struct BracketingSolution{uType} - left::uType - right::uType - retcode::Retcode -end - -struct NewtonSolution{uType} - u::uType - retcode::Retcode -end function sync_residuals!(solver::BracketingImmutableSolver) @set! solver.fl = solver.f(solver.left, solver.p) @set! solver.fr = solver.f(solver.right, solver.p) solver -end - -getsolution(sol::NewtonSolution) = sol.u -getsolution(sol::BracketingSolution) = sol.left +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index df2ae74b9..4c0ec9d2b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -101,21 +101,6 @@ function num_types_in_tuple(sig::UnionAll) length(Base.unwrap_unionall(sig).parameters) end -function numargs(f) - typ = Tuple{Any, Val{:analytic}, Vararg} - typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types - typ3 = Tuple{Any, Val{:jac}, Vararg} - typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types - typ5 = Tuple{Any, Val{:tgrad}, Vararg} - typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types - numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)]) - return (numparam-1) #-1 in v0.5 since it adds f as the first parameter -end - -function isinplace(f,inplace_param_number) - numargs(f)>=inplace_param_number -end - ### Default Linsolve # Try to be as smart as possible diff --git a/test/runtests.jl b/test/runtests.jl index 8ab368153..c021c8a86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,13 +23,13 @@ end f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0] sf, su0 = (u,p) -> u * u - 2, 1.0 sol = benchmark_immutable(f, u0) -@test sol.retcode === NonlinearSolve.DEFAULT +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) @test all(sol.u .* sol.u .- 2 .< 1e-9) sol = benchmark_mutable(f, u0) -@test sol.retcode === NonlinearSolve.DEFAULT +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) @test all(sol.u .* sol.u .- 2 .< 1e-9) sol = benchmark_scalar(sf, su0) -@test sol.retcode === NonlinearSolve.DEFAULT +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) @test sol.u * sol.u - 2 < 1e-9 @test (@ballocated benchmark_immutable($f, $u0)) == 0 @@ -117,6 +117,7 @@ probN = NonlinearProblem(f, u0) @test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u[end] ≈ sqrt(2.0) for u0 in [1.0, [1, 1.0]] + local f, probN, sol f = (u, p) -> u .* u .- 2.0 probN = NonlinearProblem(f, u0) sol = sqrt(2) * u0