From 949d46e62222c48e5b71caf827793c6dff8ff5e7 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Mon, 1 Feb 2021 02:43:11 +0530 Subject: [PATCH 1/9] Remove local isinplace after SciMLBase import --- src/utils.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index df2ae74b9..4c0ec9d2b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -101,21 +101,6 @@ function num_types_in_tuple(sig::UnionAll) length(Base.unwrap_unionall(sig).parameters) end -function numargs(f) - typ = Tuple{Any, Val{:analytic}, Vararg} - typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types - typ3 = Tuple{Any, Val{:jac}, Vararg} - typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types - typ5 = Tuple{Any, Val{:tgrad}, Vararg} - typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types - numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)]) - return (numparam-1) #-1 in v0.5 since it adds f as the first parameter -end - -function isinplace(f,inplace_param_number) - numargs(f)>=inplace_param_number -end - ### Default Linsolve # Try to be as smart as possible From d5ed5970ffff6dc9d0d686a0ef0acabe68203437 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Mon, 1 Feb 2021 02:43:45 +0530 Subject: [PATCH 2/9] Add build_solution for Newton Raphson --- src/solve.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index 7a6f72459..0b4340a4f 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -3,7 +3,15 @@ function SciMLBase.solve(prob::NonlinearProblem, kwargs...) solver = init(prob, alg, args...; kwargs...) sol = solve!(solver) - return sol + if typeof(sol) <: NewtonSolution + resid = zero(prob.u0) + if isinplace(prob) + prob.f(resid,sol.u,prob.p) + else + resid = prob.f(sol.u,prob.p) + end + return SciMLBase.build_solution(prob, alg, sol.u, resid;retcode=:Success) + end end function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...; From df1fd74e480703fdaf89b2686b2192ec458f040a Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Mon, 1 Feb 2021 22:41:38 +0530 Subject: [PATCH 3/9] pass resid from solve! --- src/solve.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 0b4340a4f..2dca2fa7f 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -2,14 +2,8 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...) solver = init(prob, alg, args...; kwargs...) - sol = solve!(solver) + sol, resid = solve!(solver) if typeof(sol) <: NewtonSolution - resid = zero(prob.u0) - if isinplace(prob) - prob.f(resid,sol.u,prob.p) - else - resid = prob.f(sol.u,prob.p) - end return SciMLBase.build_solution(prob, alg, sol.u, resid;retcode=:Success) end end @@ -76,7 +70,7 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @set! solver.retcode = MAXITERS_EXCEED end sol = get_solution(solver) - return sol + return sol, solver.fu end """ From 99ebebc5a17f5c844c85b6d48a8c7e27f911e22b Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Tue, 2 Feb 2021 23:58:28 +0530 Subject: [PATCH 4/9] Add resid param in struct and update build_solution in solve.jl --- src/scalar.jl | 31 ++++++++++++++++--------------- src/solve.jl | 12 +++++++----- src/types.jl | 6 ++++-- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 9b08d1b20..8f32ce5dd 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,6 +1,7 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) + fx = float(prob.u0) T = typeof(x) atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5) rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5) @@ -13,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a fx = f(x) dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx) end - iszero(fx) && return NewtonSolution(x, DEFAULT) + iszero(fx) && return NewtonSolution(x, DEFAULT, fx) Δx = dfx \ fx x -= Δx if isapprox(x, xo, atol=atol, rtol=rtol) - return NewtonSolution(x, DEFAULT) + return NewtonSolution(x, DEFAULT, fx) end xo = x end - return NewtonSolution(x, MAXITERS_EXCEED) + return NewtonSolution(x, MAXITERS_EXCEED, fx) end function scalar_nlsolve_ad(prob, alg, args...; kwargs...) @@ -50,22 +51,22 @@ end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode) + return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode) + return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) end # avoid ambiguities for Alg in [Bisection] @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) + return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) end @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) + return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) end end @@ -75,14 +76,14 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters fl, fr = f(left), f(right) if iszero(fl) - return BracketingSolution(left, right, EXACT_SOLUTION_LEFT) + return BracketingSolution(left, right, EXACT_SOLUTION_LEFT,fl) end i = 1 if !iszero(fr) while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl) fm = f(mid) if iszero(fm) right = mid @@ -101,7 +102,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl) fm = f(mid) if iszero(fm) right = mid @@ -113,7 +114,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters i += 1 end - return BracketingSolution(left, right, MAXITERS_EXCEED) + return BracketingSolution(left, right, MAXITERS_EXCEED,fl) end function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...) @@ -122,14 +123,14 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 fl, fr = f(left), f(right) if iszero(fl) - return BracketingSolution(left, right, EXACT_SOLUTION_LEFT) + return BracketingSolution(left, right, EXACT_SOLUTION_LEFT,fl) end i = 1 if !iszero(fr) while i < maxiters if nextfloat_tdir(left, prob.u0...) == right - return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fx) end mid = (fr * left - fl * right) / (fr - fl) for i in 1:10 @@ -156,7 +157,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl) fm = f(mid) if iszero(fm) right = mid @@ -171,5 +172,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 i += 1 end - return BracketingSolution(left, right, MAXITERS_EXCEED) + return BracketingSolution(left, right, MAXITERS_EXCEED,fl) end diff --git a/src/solve.jl b/src/solve.jl index 2dca2fa7f..d3ad620d0 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -2,9 +2,11 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...) solver = init(prob, alg, args...; kwargs...) - sol, resid = solve!(solver) + sol = solve!(solver) if typeof(sol) <: NewtonSolution - return SciMLBase.build_solution(prob, alg, sol.u, resid;retcode=:Success) + SciMLBase.build_solution(prob, alg, getsolution(sol), sol.resid;retcode=Symbol(sol.retcode)) + else + SciMLBase.build_solution(prob, alg, get_solution(sol),sol.resid;retcode=Symbol(sol.retcode),left = sol.left,right = sol.right) end end @@ -70,7 +72,7 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @set! solver.retcode = MAXITERS_EXCEED end sol = get_solution(solver) - return sol, solver.fu + return sol end """ @@ -105,11 +107,11 @@ end Form solution object from solver types """ function get_solution(solver::BracketingImmutableSolver) - return BracketingSolution(solver.left, solver.right, solver.retcode) + return BracketingSolution(solver.left, solver.right, solver.retcode, solver.fl) end function get_solution(solver::NewtonImmutableSolver) - return NewtonSolution(solver.u, solver.retcode) + return NewtonSolution(solver.u, solver.retcode, solver.fu) end """ diff --git a/src/types.jl b/src/types.jl index e7d38956e..ba4708b65 100644 --- a/src/types.jl +++ b/src/types.jl @@ -48,15 +48,17 @@ end # typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) # end -struct BracketingSolution{uType} +struct BracketingSolution{uType,resType} left::uType right::uType retcode::Retcode + resid::resType end -struct NewtonSolution{uType} +struct NewtonSolution{uType,resType} u::uType retcode::Retcode + resid::resType end function sync_residuals!(solver::BracketingImmutableSolver) From 17898cf5fc24a2d8e8765421d2c5b07196726d58 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Wed, 3 Feb 2021 23:48:24 +0530 Subject: [PATCH 5/9] add prob to integrator struct --- src/types.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index ba4708b65..d25f52baa 100644 --- a/src/types.jl +++ b/src/types.jl @@ -6,7 +6,7 @@ FLOATING_POINT_LIMIT end -struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType} <: AbstractImmutableNonlinearSolver +struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType} <: AbstractImmutableNonlinearSolver iter::Int f::fType alg::algType @@ -20,6 +20,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp retcode::Retcode cache::cacheType iip::Bool + prob::probType end # function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) @@ -27,7 +28,7 @@ end # typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) # end -struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType} <: AbstractImmutableNonlinearSolver +struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType, probType} <: AbstractImmutableNonlinearSolver iter::Int f::fType alg::algType @@ -41,6 +42,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT tol::tolType cache::cacheType iip::Bool + prob::probType end # function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip From 254ce8bd25c9ee9ff6f2bb575a10c62683451b95 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Wed, 3 Feb 2021 23:48:52 +0530 Subject: [PATCH 6/9] return build_solution in scalar.jl --- src/scalar.jl | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 8f32ce5dd..a0467ffc2 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -14,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a fx = f(x) dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx) end - iszero(fx) && return NewtonSolution(x, DEFAULT, fx) + iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT)) Δx = dfx \ fx x -= Δx if isapprox(x, xo, atol=atol, rtol=rtol) - return NewtonSolution(x, DEFAULT, fx) + return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT)) end xo = x end - return NewtonSolution(x, MAXITERS_EXCEED, fx) + return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED)) end function scalar_nlsolve_ad(prob, alg, args...; kwargs...) @@ -33,7 +33,7 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...) newprob = NonlinearProblem(f, u0, p; prob.kwargs...) sol = solve(newprob, alg, args...; kwargs...) - uu = getsolution(sol) + uu = sol.u if p isa Number f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) else @@ -51,39 +51,43 @@ end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) + #return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) + #return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) end # avoid ambiguities for Alg in [Bisection] @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials)) + #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) end @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) + return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials)) + #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) end end -function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.u0 fl, fr = f(left), f(right) if iszero(fl) - return BracketingSolution(left, right, EXACT_SOLUTION_LEFT,fl) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right) end i = 1 if !iszero(fr) while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl) + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) fm = f(mid) if iszero(fm) right = mid @@ -102,7 +106,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl) + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) fm = f(mid) if iszero(fm) right = mid @@ -114,23 +118,23 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters i += 1 end - return BracketingSolution(left, right, MAXITERS_EXCEED,fl) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right) end -function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...) +function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.u0 fl, fr = f(left), f(right) if iszero(fl) - return BracketingSolution(left, right, EXACT_SOLUTION_LEFT,fl) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right) end i = 1 if !iszero(fr) while i < maxiters if nextfloat_tdir(left, prob.u0...) == right - return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fx) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) end mid = (fr * left - fl * right) / (fr - fl) for i in 1:10 @@ -157,7 +161,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 while i < maxiters mid = (left + right) / 2 - (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl) + (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) fm = f(mid) if iszero(fm) right = mid @@ -172,5 +176,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10 i += 1 end - return BracketingSolution(left, right, MAXITERS_EXCEED,fl) + return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right) end From 4119a8e22f283634931a9f9b0fb8efab201d7bcf Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Wed, 3 Feb 2021 23:49:25 +0530 Subject: [PATCH 7/9] handle build_solution in solve.jl --- src/solve.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index d3ad620d0..43f9c550c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -3,11 +3,6 @@ function SciMLBase.solve(prob::NonlinearProblem, kwargs...) solver = init(prob, alg, args...; kwargs...) sol = solve!(solver) - if typeof(sol) <: NewtonSolution - SciMLBase.build_solution(prob, alg, getsolution(sol), sol.resid;retcode=Symbol(sol.retcode)) - else - SciMLBase.build_solution(prob, alg, get_solution(sol),sol.resid;retcode=Symbol(sol.retcode),left = sol.left,right = sol.right) - end end function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...; @@ -34,7 +29,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracket fl = f(left, p) fr = f(right, p) cache = alg_cache(alg, left, right,p, Val(iip)) - return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip) + return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip,prob) end function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...; @@ -59,7 +54,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonA fu = f(u, p) end cache = alg_cache(alg, f, u, p, Val(iip)) - return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip) + return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip, prob) end function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @@ -72,7 +67,11 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) @set! solver.retcode = MAXITERS_EXCEED end sol = get_solution(solver) - return sol + if typeof(sol) <: NewtonSolution + SciMLBase.build_solution(solver.prob, solver.alg, getsolution(sol), sol.resid;retcode=Symbol(sol.retcode)) + else + SciMLBase.build_solution(solver.prob, solver.alg, getsolution(sol),sol.resid;retcode=Symbol(sol.retcode),left = sol.left,right = sol.right) + end end """ From 973d04166d97472a10be0cfa7a925194cd1381f0 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Wed, 3 Feb 2021 23:49:46 +0530 Subject: [PATCH 8/9] Fix local var warnings --- test/runtests.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8ab368153..c021c8a86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,13 +23,13 @@ end f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0] sf, su0 = (u,p) -> u * u - 2, 1.0 sol = benchmark_immutable(f, u0) -@test sol.retcode === NonlinearSolve.DEFAULT +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) @test all(sol.u .* sol.u .- 2 .< 1e-9) sol = benchmark_mutable(f, u0) -@test sol.retcode === NonlinearSolve.DEFAULT +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) @test all(sol.u .* sol.u .- 2 .< 1e-9) sol = benchmark_scalar(sf, su0) -@test sol.retcode === NonlinearSolve.DEFAULT +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) @test sol.u * sol.u - 2 < 1e-9 @test (@ballocated benchmark_immutable($f, $u0)) == 0 @@ -117,6 +117,7 @@ probN = NonlinearProblem(f, u0) @test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u[end] ≈ sqrt(2.0) for u0 in [1.0, [1, 1.0]] + local f, probN, sol f = (u, p) -> u .* u .- 2.0 probN = NonlinearProblem(f, u0) sol = sqrt(2) * u0 From f081339065f55e0250d083d472d23501b24c4c58 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Thu, 4 Feb 2021 00:09:47 +0530 Subject: [PATCH 9/9] Remove local algo solution types --- src/scalar.jl | 3 +-- src/solve.jl | 21 +++------------------ src/types.jl | 17 +---------------- 3 files changed, 5 insertions(+), 36 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index a0467ffc2..5b3c25966 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -52,12 +52,11 @@ end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) - #return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) + end function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) - #return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid) end # avoid ambiguities diff --git a/src/solve.jl b/src/solve.jl index 43f9c550c..5235056bf 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -66,11 +66,10 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) if solver.iter == solver.maxiters @set! solver.retcode = MAXITERS_EXCEED end - sol = get_solution(solver) - if typeof(sol) <: NewtonSolution - SciMLBase.build_solution(solver.prob, solver.alg, getsolution(sol), sol.resid;retcode=Symbol(sol.retcode)) + if typeof(solver) <: NewtonImmutableSolver + SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;retcode=Symbol(solver.retcode)) else - SciMLBase.build_solution(solver.prob, solver.alg, getsolution(sol),sol.resid;retcode=Symbol(sol.retcode),left = sol.left,right = sol.right) + SciMLBase.build_solution(solver.prob, solver.alg, solver.left,solver.fl;retcode=Symbol(solver.retcode),left = solver.left,right = solver.right) end end @@ -99,20 +98,6 @@ function mic_check(solver::NewtonImmutableSolver) solver end -""" - get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver}) - get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver}) - -Form solution object from solver types -""" -function get_solution(solver::BracketingImmutableSolver) - return BracketingSolution(solver.left, solver.right, solver.retcode, solver.fl) -end - -function get_solution(solver::NewtonImmutableSolver) - return NewtonSolution(solver.u, solver.retcode, solver.fu) -end - """ reinit!(solver, prob) diff --git a/src/types.jl b/src/types.jl index d25f52baa..9e673188a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -50,24 +50,9 @@ end # typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) # end -struct BracketingSolution{uType,resType} - left::uType - right::uType - retcode::Retcode - resid::resType -end - -struct NewtonSolution{uType,resType} - u::uType - retcode::Retcode - resid::resType -end function sync_residuals!(solver::BracketingImmutableSolver) @set! solver.fl = solver.f(solver.left, solver.p) @set! solver.fr = solver.f(solver.right, solver.p) solver -end - -getsolution(sol::NewtonSolution) = sol.u -getsolution(sol::BracketingSolution) = sol.left +end \ No newline at end of file