From bfc439bf84c229dddea7fb63698eec631728f5b7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 May 2024 14:26:47 -0700 Subject: [PATCH 1/4] Make explicit imports for extensions --- .../SimpleNonlinearSolveChainRulesCoreExt.jl | 11 ++-- ...leNonlinearSolvePolyesterForwardDiffExt.jl | 3 +- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 62 +++++++++++-------- .../SimpleNonlinearSolveStaticArraysExt.jl | 2 +- .../ext/SimpleNonlinearSolveTrackerExt.jl | 18 ++++-- .../ext/SimpleNonlinearSolveZygoteExt.jl | 3 +- .../src/SimpleNonlinearSolve.jl | 6 +- lib/SimpleNonlinearSolve/src/utils.jl | 10 +-- 8 files changed, 66 insertions(+), 49 deletions(-) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl index dc84cb3e8..814af8f39 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -1,14 +1,17 @@ module SimpleNonlinearSolveChainRulesCoreExt -using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve +using ChainRulesCore: ChainRulesCore, NoTangent +using DiffEqBase: DiffEqBase +using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem +using SimpleNonlinearSolve: SimpleNonlinearSolve # The expectation here is that no-one is using this directly inside a GPU kernel. We can # eventually lift this requirement using a custom adjoint function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up), - prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; - kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p, - SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...) + ChainRulesOriginator(), alg, args...; kwargs...) function ∇__internal_solve_up(Δ) ∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl index 81cee481d..aa38d1c4e 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl @@ -1,6 +1,7 @@ module SimpleNonlinearSolvePolyesterForwardDiffExt -using SimpleNonlinearSolve, PolyesterForwardDiff +using PolyesterForwardDiff: PolyesterForwardDiff +using SimpleNonlinearSolve: SimpleNonlinearSolve @inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index e0bbda27e..f9bfe0965 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -1,60 +1,68 @@ module SimpleNonlinearSolveReverseDiffExt -using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve -import ReverseDiff: TrackedArray, TrackedReal -import SimpleNonlinearSolve: __internal_solve_up +using ArrayInterface: ArrayInterface +using DiffEqBase: DiffEqBase +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal +using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem +using SimpleNonlinearSolve: SimpleNonlinearSolve -function __internal_solve_up( - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, - p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, + u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) end -function __internal_solve_up( - prob::NonlinearProblem, sensealg, u0, u0_changed, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) end -function __internal_solve_up( - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, - p, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, + u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) end -function __internal_solve_up(prob::NonlinearProblem, sensealg, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) - return __internal_solve_up( + return SimpleNonlinearSolve.__internal_solve_up( prob, sensealg, ArrayInterface.aos_to_soa(u0), true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) end -function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) - return __internal_solve_up( - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) + return SimpleNonlinearSolve.__internal_solve_up( + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; + kwargs...) end -function __internal_solve_up(prob::NonlinearProblem, sensealg, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...) - return __internal_solve_up( - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) + return SimpleNonlinearSolve.__internal_solve_up( + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; + kwargs...) end -ReverseDiff.@grad function __internal_solve_up( - prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) +ReverseDiff.@grad function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) out, ∇internal = DiffEqBase._solve_adjoint( prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), - SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...) - function ∇__internal_solve_up(_args...) + ReverseDiffOriginator(), alg, args...; kwargs...) + function ∇SimpleNonlinearSolve.__internal_solve_up(_args...) ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) end - return Array(out), ∇__internal_solve_up + return Array(out), ∇SimpleNonlinearSolve.__internal_solve_up end end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveStaticArraysExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveStaticArraysExt.jl index 90318a82a..c865084ce 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveStaticArraysExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveStaticArraysExt.jl @@ -1,6 +1,6 @@ module SimpleNonlinearSolveStaticArraysExt -using SimpleNonlinearSolve +using SimpleNonlinearSolve: SimpleNonlinearSolve @inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index 61ce14645..7b35de09b 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -1,8 +1,12 @@ module SimpleNonlinearSolveTrackerExt -using DiffEqBase, SciMLBase, SimpleNonlinearSolve, Tracker +using DiffEqBase: DiffEqBase +using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem +using SimpleNonlinearSolve: SimpleNonlinearSolve +using Tracker: Tracker, TrackedArray -function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) return Tracker.track( SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, @@ -10,21 +14,23 @@ function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem, end function SimpleNonlinearSolve.__internal_solve_up( - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, - p::TrackedArray, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, + u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) return Tracker.track( SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) end -function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem, +function SimpleNonlinearSolve.__internal_solve_up( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) return Tracker.track( SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) end -Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(_prob::NonlinearProblem, +Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up( + _prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...) u0, p = Tracker.data(u0_), Tracker.data(p_) prob = remake(_prob; u0, p) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveZygoteExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveZygoteExt.jl index b29a1529a..559930b08 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveZygoteExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveZygoteExt.jl @@ -1,6 +1,7 @@ module SimpleNonlinearSolveZygoteExt -import SimpleNonlinearSolve, Zygote +using SimpleNonlinearSolve: SimpleNonlinearSolve +using Zygote: Zygote SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 3a4e5ebdb..e38aaa54b 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -8,14 +8,12 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat import DiffEqBase: AbstractNonlinearTerminationMode, AbstractSafeNonlinearTerminationMode, - AbstractSafeBestNonlinearTerminationMode, - NonlinearSafeTerminationReturnCode, get_termination_mode, - NONLINEARSOLVE_DEFAULT_NORM + AbstractSafeBestNonlinearTerminationMode, NONLINEARSOLVE_DEFAULT_NORM import DiffResults import ForwardDiff: Dual import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val - import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size + import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size end @reexport using ADTypes, SciMLBase diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 76e91fcb4..43689a29d 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -77,7 +77,7 @@ except `cache` (& `J` if not nothing) are mutated. function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, X} if isinplace(f) _f = (du, u) -> f(du, u, p) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) f.jac(J, x, p) _f(y, x) return y, J @@ -97,7 +97,7 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, end else _f = Base.Fix2(f, p) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return _f(x), f.jac(x, p) elseif ad isa AutoForwardDiff if ArrayInterface.can_setindex(x) @@ -124,7 +124,7 @@ end function __polyester_forwarddiff_jacobian! end function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F} - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return f(x, p), f.jac(x, p) elseif ad isa AutoForwardDiff T = typeof(__standard_tag(ad.tag, x)) @@ -152,7 +152,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray} if isinplace(f) _f = (du, u) -> f(du, u, p) J = similar(y, length(y), length(x)) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return J, nothing elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff return J, __get_jacobian_config(ad, _f, y, x) @@ -163,7 +163,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray} end else _f = Base.Fix2(f, p) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return nothing, nothing elseif ad isa AutoForwardDiff J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing From c5d1a7d7a190de274eebf4c42e2fdbe2cabebdad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 May 2024 14:27:43 -0700 Subject: [PATCH 2/4] Run formatter --- lib/SimpleNonlinearSolve/.JuliaFormatter.toml | 3 +- .../SimpleNonlinearSolveChainRulesCoreExt.jl | 12 +-- ...leNonlinearSolvePolyesterForwardDiffExt.jl | 8 +- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 43 +++++------ .../ext/SimpleNonlinearSolveTrackerExt.jl | 27 +++---- .../src/SimpleNonlinearSolve.jl | 20 +++-- lib/SimpleNonlinearSolve/src/ad.jl | 34 +++++---- .../src/bracketing/alefeld.jl | 49 +++++------- .../src/bracketing/bisection.jl | 31 ++++---- .../src/bracketing/brent.jl | 26 +++---- .../src/bracketing/falsi.jl | 23 +++--- .../src/bracketing/itp.jl | 23 +++--- .../src/bracketing/ridder.jl | 27 ++++--- lib/SimpleNonlinearSolve/src/linesearch.jl | 14 ++-- .../src/nlsolve/broyden.jl | 12 +-- .../src/nlsolve/dfsane.jl | 8 +- .../src/nlsolve/halley.jl | 12 +-- .../src/nlsolve/klement.jl | 8 +- .../src/nlsolve/lbroyden.jl | 74 +++++++++---------- .../src/nlsolve/raphson.jl | 4 +- .../src/nlsolve/trustRegion.jl | 12 +-- lib/SimpleNonlinearSolve/src/utils.jl | 38 +++++----- .../test/core/23_test_problems_tests.jl | 11 ++- .../test/core/forward_ad_tests.jl | 21 +++--- .../test/core/least_squares_tests.jl | 3 +- .../test/core/matrix_resizing_tests.jl | 8 +- .../test/core/rootfind_tests.jl | 60 +++++++-------- .../test/gpu/cuda_tests.jl | 14 ++-- 28 files changed, 310 insertions(+), 315 deletions(-) diff --git a/lib/SimpleNonlinearSolve/.JuliaFormatter.toml b/lib/SimpleNonlinearSolve/.JuliaFormatter.toml index 4d06911d7..66c13bae3 100644 --- a/lib/SimpleNonlinearSolve/.JuliaFormatter.toml +++ b/lib/SimpleNonlinearSolve/.JuliaFormatter.toml @@ -1,4 +1,5 @@ style = "sciml" format_markdown = true annotate_untyped_fields_with_any = false -format_docstrings = true \ No newline at end of file +format_docstrings = true +join_lines_based_on_source = false diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl index 814af8f39..2987f1c5b 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -8,14 +8,14 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve # The expectation here is that no-one is using this directly inside a GPU kernel. We can # eventually lift this requirement using a custom adjoint function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up), - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) - out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p, - ChainRulesOriginator(), alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) + out, ∇internal = DiffEqBase._solve_adjoint( + prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...) function ∇__internal_solve_up(Δ) ∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) - return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), - ∂args...) + return ( + ∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...) end return out, ∇__internal_solve_up end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl index aa38d1c4e..ac898ac16 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl @@ -5,14 +5,14 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve @inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true -@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x, - chunksize) where {F} +@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!( + f!::F, y, J, x, chunksize) where {F} PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize) return J end -@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x, - chunksize) where {F} +@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!( + f::F, J, x, chunksize) where {F} PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize) return J end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index f9bfe0965..a6a1c2dbf 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -9,52 +9,53 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve function SimpleNonlinearSolve.__internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, - p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, + u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal}, - p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) return SimpleNonlinearSolve.__internal_solve_up( prob, sensealg, ArrayInterface.aos_to_soa(u0), true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, - p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, + u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) return SimpleNonlinearSolve.__internal_solve_up( - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; - kwargs...) + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), + true, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + sensealg, u0::AbstractArray{<:TrackedReal}, + u0_changed, p, p_changed, alg, args...; kwargs...) return SimpleNonlinearSolve.__internal_solve_up( - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; - kwargs...) + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), + true, alg, args...; kwargs...) end ReverseDiff.@grad function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) out, ∇internal = DiffEqBase._solve_adjoint( prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), ReverseDiffOriginator(), alg, args...; kwargs...) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index 7b35de09b..b49bd78cc 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -6,27 +6,24 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve using Tracker: Tracker, TrackedArray function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) - return Tracker.track( - SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, - p, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, + u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return Tracker.track( - SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, - p, p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) end function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return Tracker.track( - SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed, - p, p_changed, alg, args...; kwargs...) + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, + u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) end Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up( @@ -34,8 +31,8 @@ Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up( sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...) u0, p = Tracker.data(u0_), Tracker.data(p_) prob = remake(_prob; u0, p) - out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p, - SciMLBase.TrackerOriginator(), alg, args...; kwargs...) + out, ∇internal = DiffEqBase._solve_adjoint( + prob, sensealg, u0, p, SciMLBase.TrackerOriginator(), alg, args...; kwargs...) function ∇__internal_solve_up(Δ) ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index e38aaa54b..dfd650c4d 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -56,17 +56,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; end # By Pass the highlevel checks for NonlinearProblem for Simple Algorithms -function SciMLBase.solve( - prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, +function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end new_u0 = u0 !== nothing ? u0 : prob.u0 new_p = p !== nothing ? p : prob.p - return __internal_solve_up( - prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing, - alg, args...; prob.kwargs..., kwargs...) + return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, + p === nothing, alg, args...; prob.kwargs..., kwargs...) end function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, @@ -78,10 +76,10 @@ end @setup_workload begin for T in (Float32, Float64) prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) - prob_no_brack_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, - T.([1.0, 1.0, 1.0]), T(2)) - prob_no_brack_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, - T.([1.0, 1.0, 1.0]), T(2)) + prob_no_brack_iip = NonlinearProblem{true}( + (du, u, p) -> du .= u .* u .- p, T.([1.0, 1.0, 1.0]), T(2)) + prob_no_brack_oop = NonlinearProblem{false}( + (u, p) -> u .* u .- p, T.([1.0, 1.0, 1.0]), T(2)) algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(), SimpleDFSane(), SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2)] @@ -101,8 +99,8 @@ end end end - prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, - T.((0.0, 2.0)), T(2)) + prob_brack = IntervalNonlinearProblem{false}( + (u, p) -> u * u - p, T.((0.0, 2.0)), T(2)) algs = [Bisection(), Falsi(), Ridder(), Brent(), Alefeld(), ITP()] @compile_workload begin for alg in algs diff --git a/lib/SimpleNonlinearSolve/src/ad.jl b/lib/SimpleNonlinearSolve/src/ad.jl index d4e091c43..f42651bfe 100644 --- a/lib/SimpleNonlinearSolve/src/ad.jl +++ b/lib/SimpleNonlinearSolve/src/ad.jl @@ -1,7 +1,9 @@ function SciMLBase.solve( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, - iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip} + prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; + kwargs...) where {T, V, P, iip} sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) return SciMLBase.build_solution( @@ -9,9 +11,11 @@ function SciMLBase.solve( end function SciMLBase.solve( - prob::NonlinearLeastSquaresProblem{<:AbstractArray, - iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}}, - alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip} + prob::NonlinearLeastSquaresProblem{ + <:AbstractArray, iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}}, + alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; + kwargs...) where {T, V, P, iip} sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) return SciMLBase.build_solution( @@ -21,13 +25,16 @@ end for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder) @eval begin function SciMLBase.solve( - prob::IntervalNonlinearProblem{uType, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip} + prob::IntervalNonlinearProblem{ + uType, iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + alg::$(algType), + args...; + kwargs...) where {uType, T, V, P, iip} sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, - sol.stats, sol.original, left = Dual{T, V, P}(sol.left, partials), + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, + sol.original, left = Dual{T, V, P}(sol.left, partials), right = Dual{T, V, P}(sol.right, partials)) end end @@ -125,9 +132,8 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. else _F = @closure (u, p) -> begin T = promote_type(eltype(u), eltype(p)) - res = DiffResults.DiffResult( - similar(u, T, size(sol.resid)), similar( - u, T, length(sol.resid), length(u))) + res = DiffResults.DiffResult(similar(u, T, size(sol.resid)), + similar(u, T, length(sol.resid), length(u))) ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u) return reshape( 2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res), diff --git a/lib/SimpleNonlinearSolve/src/bracketing/alefeld.jl b/lib/SimpleNonlinearSolve/src/bracketing/alefeld.jl index 3b89751a7..55020c8a6 100644 --- a/lib/SimpleNonlinearSolve/src/bracketing/alefeld.jl +++ b/lib/SimpleNonlinearSolve/src/bracketing/alefeld.jl @@ -15,12 +15,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...; c = a - (b - a) / (f(b) - f(a)) * f(a) fc = f(c) - (a == c || b == c) && - return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, - left = a, right = b) - iszero(fc) && - return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success, left = a, - right = b) + (a == c || b == c) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = a, right = b) + iszero(fc) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b) a, b, d = _bracket(f, a, b, c) e = zero(a) # Set e as 0 before iteration to avoid a non-value f(e) @@ -38,12 +36,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...; end ē, fc = d, f(c) (a == c || b == c) && - return build_solution( - prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, - left = a, right = b) - iszero(fc) && - return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success, + return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = a, right = b) + iszero(fc) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b) ā, b̄, d̄ = _bracket(f, a, b, c) # The second bracketing block @@ -58,12 +54,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...; end fc = f(c) (ā == c || b̄ == c) && - return build_solution( - prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, - left = ā, right = b̄) - iszero(fc) && - return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success, + return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = ā, right = b̄) + iszero(fc) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄) ā, b̄, d̄ = _bracket(f, ā, b̄, c) # The third bracketing block @@ -78,12 +72,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...; end fc = f(c) (ā == c || b̄ == c) && - return build_solution( - prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, - left = ā, right = b̄) - iszero(fc) && - return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success, + return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = ā, right = b̄) + iszero(fc) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄) ā, b̄, d = _bracket(f, ā, b̄, c) # The last bracketing block @@ -93,12 +85,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...; e = d c = 0.5 * (ā + b̄) fc = f(c) - (ā == c || b̄ == c) && - return build_solution(prob, alg, c, fc; - retcode = ReturnCode.FloatingPointLimit, left = ā, right = b̄) - iszero(fc) && - return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success, - left = ā, right = b̄) + (ā == c || b̄ == c) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, + left = ā, right = b̄) + iszero(fc) && return build_solution( + prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄) a, b, d = _bracket(f, ā, b̄, c) end end @@ -112,8 +103,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...; fc = f(c) # Reuturn solution when run out of max interation - return build_solution(prob, alg, c, fc; retcode = ReturnCode.MaxIters, - left = a, right = b) + return build_solution( + prob, alg, c, fc; retcode = ReturnCode.MaxIters, left = a, right = b) end # Define subrotine function bracket, check fc before bracket to return solution diff --git a/lib/SimpleNonlinearSolve/src/bracketing/bisection.jl b/lib/SimpleNonlinearSolve/src/bracketing/bisection.jl index acadf6aa1..d55a0ce5e 100644 --- a/lib/SimpleNonlinearSolve/src/bracketing/bisection.jl +++ b/lib/SimpleNonlinearSolve/src/bracketing/bisection.jl @@ -19,25 +19,24 @@ A common bisection method. exact_right::Bool = false end -function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...; - maxiters = 1000, abstol = nothing, kwargs...) +function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, + args...; maxiters = 1000, abstol = nothing, kwargs...) @assert !isinplace(prob) "`Bisection` only supports OOP problems." f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, - promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) + abstol = __get_tolerance( + nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) - return build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, - left, right) + return build_solution( + prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right) end if iszero(fr) return build_solution( - prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, - left, right) + prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right) end i = 1 @@ -49,8 +48,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args... retcode = ReturnCode.FloatingPointLimit) fm = f(mid) if abs((right - left) / 2) < abstol - return build_solution(prob, alg, mid, fm; retcode = ReturnCode.Success, - left, right) + return build_solution( + prob, alg, mid, fm; retcode = ReturnCode.Success, left, right) end if iszero(fm) right = mid @@ -67,8 +66,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args... end end - sol, i, left, right, fl, fr = __bisection(left, right, fl, fr, f; abstol, - maxiters = maxiters - i, prob, alg) + sol, i, left, right, fl, fr = __bisection( + left, right, fl, fr, f; abstol, maxiters = maxiters - i, prob, alg) sol !== nothing && return sol @@ -81,15 +80,15 @@ function __bisection(left, right, fl, fr, f::F; abstol, maxiters, prob, alg) whe while i < maxiters mid = (left + right) / 2 if (mid == left || mid == right) - sol = build_solution(prob, alg, left, fl; left, right, - retcode = ReturnCode.FloatingPointLimit) + sol = build_solution( + prob, alg, left, fl; left, right, retcode = ReturnCode.FloatingPointLimit) break end fm = f(mid) if abs((right - left) / 2) < abstol - sol = build_solution(prob, alg, mid, fm; left, right, - retcode = ReturnCode.Success) + sol = build_solution( + prob, alg, mid, fm; left, right, retcode = ReturnCode.Success) break end diff --git a/lib/SimpleNonlinearSolve/src/bracketing/brent.jl b/lib/SimpleNonlinearSolve/src/bracketing/brent.jl index 89b2e60be..649286e03 100644 --- a/lib/SimpleNonlinearSolve/src/bracketing/brent.jl +++ b/lib/SimpleNonlinearSolve/src/bracketing/brent.jl @@ -13,18 +13,17 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; fl, fr = f(left), f(right) ϵ = eps(convert(typeof(fl), 1)) - abstol = __get_tolerance(nothing, abstol, - promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) + abstol = __get_tolerance( + nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) - return build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, - left, right) + return build_solution( + prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right) end if iszero(fr) return build_solution( - prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, - left, right) + prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right) end if abs(fl) < abs(fr) @@ -60,18 +59,17 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; (!cond && abs(c - d) ≤ ϵ) # Bisection method s = (left + right) / 2 - (s == left || s == right) && - return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.FloatingPointLimit, - left = left, right = right) + (s == left || s == right) && return SciMLBase.build_solution( + prob, alg, left, fl; retcode = ReturnCode.FloatingPointLimit, + left = left, right = right) cond = true else cond = false end fs = f(s) if abs((right - left) / 2) < abstol - return SciMLBase.build_solution(prob, alg, s, fs; - retcode = ReturnCode.Success, + return SciMLBase.build_solution( + prob, alg, s, fs; retcode = ReturnCode.Success, left = left, right = right) end if iszero(fs) @@ -105,8 +103,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; end end - sol, i, left, right, fl, fr = __bisection(left, right, fl, fr, f; abstol, - maxiters = maxiters - i, prob, alg) + sol, i, left, right, fl, fr = __bisection( + left, right, fl, fr, f; abstol, maxiters = maxiters - i, prob, alg) sol !== nothing && return sol diff --git a/lib/SimpleNonlinearSolve/src/bracketing/falsi.jl b/lib/SimpleNonlinearSolve/src/bracketing/falsi.jl index 896e07329..ee78b73fc 100644 --- a/lib/SimpleNonlinearSolve/src/bracketing/falsi.jl +++ b/lib/SimpleNonlinearSolve/src/bracketing/falsi.jl @@ -12,18 +12,17 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, - promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) + abstol = __get_tolerance( + nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) - return build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, - left, right) + return build_solution( + prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right) end if iszero(fr) return build_solution( - prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, - left, right) + prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right) end # Regula Falsi Steps @@ -44,8 +43,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; fm = f(mid) if abs((right - left) / 2) < abstol - return build_solution(prob, alg, mid, fm; left, right, - retcode = ReturnCode.Success) + return build_solution( + prob, alg, mid, fm; left, right, retcode = ReturnCode.Success) end if abs(fm) < abstol @@ -62,10 +61,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; end end - sol, i, left, right, fl, fr = __bisection(left, right, fl, fr, f; abstol, - maxiters = maxiters - i, prob, alg) + sol, i, left, right, fl, fr = __bisection( + left, right, fl, fr, f; abstol, maxiters = maxiters - i, prob, alg) sol !== nothing && return sol - return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, - left, right) + return SciMLBase.build_solution( + prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right) end diff --git a/lib/SimpleNonlinearSolve/src/bracketing/itp.jl b/lib/SimpleNonlinearSolve/src/bracketing/itp.jl index 2926dfd7a..2972d1c5c 100644 --- a/lib/SimpleNonlinearSolve/src/bracketing/itp.jl +++ b/lib/SimpleNonlinearSolve/src/bracketing/itp.jl @@ -58,18 +58,17 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...; left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, - promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) + abstol = __get_tolerance( + nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) - return build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, - left, right) + return build_solution( + prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right) end if iszero(fr) return build_solution( - prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, - left, right) + prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right) end ϵ = abstol #defining variables/cache @@ -110,8 +109,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...; end if abs((left - right) / 2) < ϵ - return build_solution(prob, alg, mid, f(mid); retcode = ReturnCode.Success, - left, right) + return build_solution( + prob, alg, mid, f(mid); retcode = ReturnCode.Success, left, right) end ## Update ## @@ -127,16 +126,16 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...; left = xp fl = yp else - return build_solution(prob, alg, xp, yps; retcode = ReturnCode.Success, - left = xp, right = xp) + return build_solution( + prob, alg, xp, yps; retcode = ReturnCode.Success, left = xp, right = xp) end i += 1 mid = (left + right) / 2 ϵ_s /= 2 if __nextfloat_tdir(left, prob.tspan...) == right - return build_solution(prob, alg, left, fl; left, right, - retcode = ReturnCode.FloatingPointLimit) + return build_solution( + prob, alg, left, fl; left, right, retcode = ReturnCode.FloatingPointLimit) end end diff --git a/lib/SimpleNonlinearSolve/src/bracketing/ridder.jl b/lib/SimpleNonlinearSolve/src/bracketing/ridder.jl index 3b23f4287..a974824c2 100644 --- a/lib/SimpleNonlinearSolve/src/bracketing/ridder.jl +++ b/lib/SimpleNonlinearSolve/src/bracketing/ridder.jl @@ -12,18 +12,17 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, - promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) + abstol = __get_tolerance( + nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) - return build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, - left, right) + return build_solution( + prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right) end if iszero(fr) return build_solution( - prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, - left, right) + prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right) end xo = oftype(left, Inf) @@ -37,15 +36,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; fm = f(mid) s = sqrt(fm^2 - fl * fr) if iszero(s) - return build_solution(prob, alg, left, fl; left, right, - retcode = ReturnCode.Failure) + return build_solution( + prob, alg, left, fl; left, right, retcode = ReturnCode.Failure) end x = mid + (mid - left) * sign(fl - fr) * fm / s fx = f(x) xo = x if abs((right - left) / 2) < abstol - return build_solution(prob, alg, mid, fm; retcode = ReturnCode.Success, - left, right) + return build_solution( + prob, alg, mid, fm; retcode = ReturnCode.Success, left, right) end if iszero(fx) right = x @@ -69,10 +68,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; end end - sol, i, left, right, fl, fr = __bisection(left, right, fl, fr, f; abstol, - maxiters = maxiters - i, prob, alg) + sol, i, left, right, fl, fr = __bisection( + left, right, fl, fr, f; abstol, maxiters = maxiters - i, prob, alg) sol !== nothing && return sol - return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, - left, right) + return SciMLBase.build_solution( + prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right) end diff --git a/lib/SimpleNonlinearSolve/src/linesearch.jl b/lib/SimpleNonlinearSolve/src/linesearch.jl index c33253f63..b8e830af2 100644 --- a/lib/SimpleNonlinearSolve/src/linesearch.jl +++ b/lib/SimpleNonlinearSolve/src/linesearch.jl @@ -37,8 +37,8 @@ end end (alg::LiFukushimaLineSearch)(prob, fu, u) = __generic_init(alg, prob, fu, u) -function (alg::LiFukushimaLineSearch)(prob, fu::Union{Number, SArray}, - u::Union{Number, SArray}) +function (alg::LiFukushimaLineSearch)( + prob, fu::Union{Number, SArray}, u::Union{Number, SArray}) (alg.nan_maxiters === missing || alg.nan_maxiters === nothing) && return __static_init(alg, prob, fu, u) @warn "`LiFukushimaLineSearch` with NaN checking is not non-allocating" maxlog=1 @@ -57,14 +57,16 @@ function __generic_init(alg::LiFukushimaLineSearch, prob, fu, u) nan_maxiters = ifelse(alg.nan_maxiters === missing, 5, alg.nan_maxiters) - return LiFukushimaLineSearchCache(ϕ, T(alg.lambda_0), T(alg.beta), T(alg.sigma_1), - T(alg.sigma_2), T(alg.eta), T(alg.rho), T(true), nan_maxiters, alg.maxiters) + return LiFukushimaLineSearchCache( + ϕ, T(alg.lambda_0), T(alg.beta), T(alg.sigma_1), T(alg.sigma_2), + T(alg.eta), T(alg.rho), T(true), nan_maxiters, alg.maxiters) end function __static_init(alg::LiFukushimaLineSearch, prob, fu, u) T = promote_type(eltype(fu), eltype(u)) - return StaticLiFukushimaLineSearchCache(prob.f, prob.p, T(alg.lambda_0), T(alg.beta), - T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho), alg.maxiters) + return StaticLiFukushimaLineSearchCache( + prob.f, prob.p, T(alg.lambda_0), T(alg.beta), T(alg.sigma_1), + T(alg.sigma_2), T(alg.eta), T(alg.rho), alg.maxiters) end function (cache::LiFukushimaLineSearchCache)(u, δu) diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl b/lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl index 6fe121481..890578f5d 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl @@ -23,8 +23,8 @@ end __get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS) function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, - termination_condition = nothing, kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) fx = _get_fx(prob, x) T = promote_type(eltype(x), eltype(fx)) @@ -48,11 +48,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...; @bb δJ⁻¹n = copy(x) @bb δJ⁻¹ = copy(J⁻¹) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) - ls_cache = __get_linesearch(alg) === Val(true) ? - LiFukushimaLineSearch()(prob, fx, x) : nothing + ls_cache = __get_linesearch(alg) === Val(true) ? LiFukushimaLineSearch()(prob, fx, x) : + nothing for _ in 1:maxiters @bb δx = J⁻¹ × vec(fprev) diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl b/lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl index 9f092648d..7dd152277 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl @@ -50,8 +50,8 @@ end function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, M::Union{Int, Val} = Val(10), γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, nexp::Int = 2, η_strategy::F = (f_1, k, x, F) -> f_1 ./ k^2) where {F} - return SimpleDFSane{_unwrap_val(M)}(σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, - η_strategy) + return SimpleDFSane{_unwrap_val(M)}( + σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy) end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...; @@ -70,8 +70,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... τ_min = T(alg.τ_min) τ_max = T(alg.τ_max) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) fx_norm = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp α_1 = one(T) diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl b/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl index 934dc4763..e550258ef 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl @@ -24,8 +24,8 @@ A low-overhead implementation of Halley's Method. end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, - termination_condition = nothing, kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs...) isinplace(prob) && error("SimpleHalley currently only supports out-of-place nonlinear problems") @@ -34,8 +34,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; T = eltype(x) autodiff = __get_concrete_autodiff(prob, alg.autodiff) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) @bb xo = copy(x) @@ -59,8 +59,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; dfx else fact = lu(dfx; check = false) - !issuccess(fact) && return build_solution(prob, alg, x, fx; - retcode = ReturnCode.Unstable) + !issuccess(fact) && + return build_solution(prob, alg, x, fx; retcode = ReturnCode.Unstable) fact end diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/klement.jl b/lib/SimpleNonlinearSolve/src/nlsolve/klement.jl index c2c8b446f..8041ef4b8 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/klement.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/klement.jl @@ -7,14 +7,14 @@ method is non-allocating on scalar and static array problems. struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, - termination_condition = nothing, kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) T = eltype(x) fx = _get_fx(prob, x) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) @bb δx = copy(x) @bb fprev = copy(fx) diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/lbroyden.jl b/lib/SimpleNonlinearSolve/src/nlsolve/lbroyden.jl index 600892edd..b34d4cddf 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/lbroyden.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/lbroyden.jl @@ -24,8 +24,8 @@ end __get_threshold(::SimpleLimitedMemoryBroyden{threshold}) where {threshold} = Val(threshold) __use_linesearch(::SimpleLimitedMemoryBroyden{Th, LS}) where {Th, LS} = Val(LS) -function SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27), - linesearch = Val(false), alpha = nothing) +function SimpleLimitedMemoryBroyden(; + threshold::Union{Val, Int} = Val(27), linesearch = Val(false), alpha = nothing) return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha) end @@ -45,24 +45,25 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd end @views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden, - args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, - termination_condition = nothing, kwargs...) + args...; abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) threshold = __get_threshold(alg) η = min(_unwrap_val(threshold), maxiters) # For scalar problems / if the threshold is larger than problem size just use Broyden if x isa Number || length(x) ≤ η - return SciMLBase.__solve(prob, SimpleBroyden(; linesearch = __use_linesearch(alg)), - args...; abstol, reltol, maxiters, termination_condition, kwargs...) + return SciMLBase.__solve( + prob, SimpleBroyden(; linesearch = __use_linesearch(alg)), args...; + abstol, reltol, maxiters, termination_condition, kwargs...) end fx = _get_fx(prob, x) U, Vᵀ = __init_low_rank_jacobian(x, fx, x isa StaticArray ? threshold : Val(η)) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) @bb xo = copy(x) @bb δx = copy(fx) @@ -74,8 +75,8 @@ end Tcache = __lbroyden_threshold_cache(x, x isa StaticArray ? threshold : Val(η)) @bb mat_cache = copy(x) - ls_cache = __use_linesearch(alg) === Val(true) ? - LiFukushimaLineSearch()(prob, fx, x) : nothing + ls_cache = __use_linesearch(alg) === Val(true) ? LiFukushimaLineSearch()(prob, fx, x) : + nothing for i in 1:maxiters α = ls_cache === nothing ? true : ls_cache(x, δx) @@ -125,8 +126,8 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo xo, δx, fo, δf = x, -fx, fx, fx - ls_cache = __use_linesearch(alg) === Val(true) ? - LiFukushimaLineSearch()(prob, fx, x) : nothing + ls_cache = __use_linesearch(alg) === Val(true) ? LiFukushimaLineSearch()(prob, fx, x) : + nothing T = promote_type(eltype(x), eltype(fx)) if alg.alpha === nothing @@ -138,8 +139,7 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo end converged, res = __unrolled_lbroyden_initial_iterations( - prob, xo, fo, δx, abstol, U, Vᵀ, - threshold, ls_cache, init_α) + prob, xo, fo, δx, abstol, U, Vᵀ, threshold, ls_cache, init_α) converged && return build_solution(prob, alg, res.x, res.fx; retcode = ReturnCode.Success) @@ -173,39 +173,39 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo return build_solution(prob, alg, xo, fo; retcode = ReturnCode.MaxIters) end -@generated function __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U, - Vᵀ, ::Val{threshold}, ls_cache, init_α) where {threshold} +@generated function __unrolled_lbroyden_initial_iterations( + prob, xo, fo, δx, abstol, U, Vᵀ, ::Val{threshold}, + ls_cache, init_α) where {threshold} calls = [] for i in 1:threshold static_idx, static_idx_p1 = Val(i - 1), Val(i) - push!(calls, - quote - α = ls_cache === nothing ? true : ls_cache(xo, δx) - x = xo .+ α .* δx - fx = prob.f(x, prob.p) - δf = fx - fo + push!(calls, quote + α = ls_cache === nothing ? true : ls_cache(xo, δx) + x = xo .+ α .* δx + fx = prob.f(x, prob.p) + δf = fx - fo - maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx) + maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx) - _U = __first_n_getindex(U, $(static_idx)) - _Vᵀ = __first_n_getindex(Vᵀ, $(static_idx)) + _U = __first_n_getindex(U, $(static_idx)) + _Vᵀ = __first_n_getindex(Vᵀ, $(static_idx)) - vᵀ = _restructure(x, _rmatvec!!(_U, _Vᵀ, vec(δx), init_α)) - mvec = _restructure(x, _matvec!!(_U, _Vᵀ, vec(δf), init_α)) + vᵀ = _restructure(x, _rmatvec!!(_U, _Vᵀ, vec(δx), init_α)) + mvec = _restructure(x, _matvec!!(_U, _Vᵀ, vec(δf), init_α)) - d = dot(vᵀ, δf) - δx = @. (δx - mvec) / d + d = dot(vᵀ, δf) + δx = @. (δx - mvec) / d - U = Base.setindex(U, vec(δx), $(i)) - Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), $(i)) + U = Base.setindex(U, vec(δx), $(i)) + Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), $(i)) - _U = __first_n_getindex(U, $(static_idx_p1)) - _Vᵀ = __first_n_getindex(Vᵀ, $(static_idx_p1)) - δx = -_restructure(fx, _matvec!!(_U, _Vᵀ, vec(fx), init_α)) + _U = __first_n_getindex(U, $(static_idx_p1)) + _Vᵀ = __first_n_getindex(Vᵀ, $(static_idx_p1)) + δx = -_restructure(fx, _matvec!!(_U, _Vᵀ, vec(fx), init_α)) - xo = x - fo = fx - end) + xo = x + fo = fx + end) end push!(calls, quote # Termination Check diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/raphson.jl b/lib/SimpleNonlinearSolve/src/nlsolve/raphson.jl index 9735d0c8c..7b419ce99 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/raphson.jl @@ -32,8 +32,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr @bb xo = copy(x) J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) for i in 1:maxiters fx, dfx = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J) diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/trustRegion.jl b/lib/SimpleNonlinearSolve/src/nlsolve/trustRegion.jl index e6ccf6536..a19cf2c2d 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/trustRegion.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/trustRegion.jl @@ -56,8 +56,8 @@ scalar and static array problems. end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, - termination_condition = nothing, kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs...) x = __maybe_unaliased(prob.u0, alias_u0) T = eltype(real(x)) Δₘₐₓ = T(alg.max_trust_radius) @@ -88,8 +88,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args. J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p) fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J) - abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, - termination_condition) + abstol, reltol, tc_cache = init_termination_cache( + prob, abstol, reltol, fx, x, termination_condition) # Set default trust region radius if not specified by user. Δₘₐₓ == 0 && (Δₘₐₓ = max(norm_fx, maximum(x) - minimum(x))) @@ -132,8 +132,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args. else Δ = t₁ * Δ shrink_counter += 1 - shrink_counter > max_shrink_times && return build_solution(prob, alg, x, fx; - retcode = ReturnCode.ShrinkThresholdExceeded) + shrink_counter > max_shrink_times && return build_solution( + prob, alg, x, fx; retcode = ReturnCode.ShrinkThresholdExceeded) end if r ≥ η₁ diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 43689a29d..333e54d3b 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -58,8 +58,8 @@ function __forwarddiff_jacobian_config( f::F, x::SArray, ck::ForwardDiff.Chunk{N}, tag) where {F, N} seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, eltype(x)}) duals = ForwardDiff.Dual{typeof(tag), eltype(x), N}.(x) - return ForwardDiff.JacobianConfig{typeof(tag), eltype(x), N, typeof(duals)}(seeds, - duals) + return ForwardDiff.JacobianConfig{typeof(tag), eltype(x), N, typeof(duals)}( + seeds, duals) end function __get_jacobian_config(ad::AutoPolyesterForwardDiff{CS}, args...) where {CS} @@ -205,8 +205,8 @@ end function compute_jacobian_and_hessian(ad::AutoFiniteDiff, prob, _, x::Number) fx = prob.f(x, prob.p) - J_fn = x -> FiniteDiff.finite_difference_derivative(Base.Fix2(prob.f, prob.p), x, - ad.fdtype) + J_fn = x -> FiniteDiff.finite_difference_derivative( + Base.Fix2(prob.f, prob.p), x, ad.fdtype) dfx = J_fn(x) d2fx = FiniteDiff.finite_difference_derivative(J_fn, x, ad.fdtype) return fx, dfx, d2fx @@ -262,7 +262,8 @@ end @inline _restructure(y, x) = ArrayInterface.restructure(y, x) @inline function _get_fx(prob::NonlinearLeastSquaresProblem, x) - isinplace(prob) && prob.f.resid_prototype === nothing && + isinplace(prob) && + prob.f.resid_prototype === nothing && error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`") return _get_fx(prob.f, x, prob.p) end @@ -289,17 +290,16 @@ end # is meant for low overhead solvers, users can opt into the other termination modes but the # default is to use the least overhead version. function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing) - return init_termination_cache(prob, abstol, reltol, du, u, - AbsNormTerminationMode(Base.Fix1(maximum, abs))) + return init_termination_cache( + prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs))) end function init_termination_cache( prob::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing) - return init_termination_cache(prob, abstol, reltol, du, u, - AbsNormTerminationMode(Base.Fix2(norm, 2))) + return init_termination_cache( + prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2))) end -function init_termination_cache( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) T = promote_type(eltype(du), eltype(u)) abstol = __get_tolerance(u, abstol, T) @@ -316,23 +316,23 @@ function init_termination_cache( end function check_termination(tc_cache, fx, x, xo, prob, alg) - return check_termination(tc_cache, fx, x, xo, prob, alg, - DiffEqBase.get_termination_mode(tc_cache)) + return check_termination( + tc_cache, fx, x, xo, prob, alg, DiffEqBase.get_termination_mode(tc_cache)) end -function check_termination(tc_cache, fx, x, xo, prob, alg, - ::AbstractNonlinearTerminationMode) +function check_termination( + tc_cache, fx, x, xo, prob, alg, ::AbstractNonlinearTerminationMode) tc_cache(fx, x, xo) && return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) return nothing end -function check_termination(tc_cache, fx, x, xo, prob, alg, - ::AbstractSafeNonlinearTerminationMode) +function check_termination( + tc_cache, fx, x, xo, prob, alg, ::AbstractSafeNonlinearTerminationMode) tc_cache(fx, x, xo) && return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode) return nothing end -function check_termination(tc_cache, fx, x, xo, prob, alg, - ::AbstractSafeBestNonlinearTerminationMode) +function check_termination( + tc_cache, fx, x, xo, prob, alg, ::AbstractSafeBestNonlinearTerminationMode) if tc_cache(fx, x, xo) if isinplace(prob) prob.f(fx, x, prob.p) diff --git a/lib/SimpleNonlinearSolve/test/core/23_test_problems_tests.jl b/lib/SimpleNonlinearSolve/test/core/23_test_problems_tests.jl index 2180943fc..9625c6872 100644 --- a/lib/SimpleNonlinearSolve/test/core/23_test_problems_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/23_test_problems_tests.jl @@ -4,8 +4,8 @@ using LinearAlgebra, NonlinearProblemLibrary, DiffEqBase, Test problems = NonlinearProblemLibrary.problems dicts = NonlinearProblemLibrary.dicts -function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4; - skip_tests = nothing) +function test_on_library( + problems, dicts, alg_ops, broken_tests, ϵ = 1e-4; skip_tests = nothing) for (idx, (problem, dict)) in enumerate(zip(problems, dicts)) x = dict["start"] res = similar(x) @@ -13,8 +13,8 @@ function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4; @testset "$idx: $(dict["title"])" begin for alg in alg_ops try - sol = solve(nlprob, alg; - termination_condition = AbsNormTerminationMode()) + sol = solve( + nlprob, alg; termination_condition = AbsNormTerminationMode()) problem(res, sol.u, nothing) skip = skip_tests !== nothing && idx in skip_tests[alg] @@ -51,8 +51,7 @@ end end @testitem "SimpleTrustRegion" setup=[RobustnessTesting] tags=[:core] begin - alg_ops = (SimpleTrustRegion(), - SimpleTrustRegion(; nlsolve_update_rule = Val(true))) + alg_ops = (SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true))) broken_tests = Dict(alg => Int[] for alg in alg_ops) broken_tests[alg_ops[1]] = [3, 15, 16, 21] diff --git a/lib/SimpleNonlinearSolve/test/core/forward_ad_tests.jl b/lib/SimpleNonlinearSolve/test/core/forward_ad_tests.jl index 50fc18e94..ab1db6cb0 100644 --- a/lib/SimpleNonlinearSolve/test/core/forward_ad_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/forward_ad_tests.jl @@ -41,8 +41,9 @@ export test_f, test_f!, jacobian_f, solve_with, __compatible end @testitem "ForwardDiff.jl Integration: Rootfinding" setup=[ForwardADRootfindingTesting] tags=[:core] begin - @testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(), - SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)), + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleNewtonRaphson(), SimpleTrustRegion(), + SimpleTrustRegion(; nlsolve_update_rule = Val(true)), SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane()) us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2)) @@ -171,25 +172,27 @@ end function obj_4(p) prob_iip = NonlinearLeastSquaresProblem( NonlinearFunction{true}( - loss_function!; resid_prototype = zeros(length(y_target))), θ_init, p) + loss_function!; resid_prototype = zeros(length(y_target))), + θ_init, + p) sol = solve(prob_iip, alg) return sum(abs2, sol.u) end function obj_5(p) ff = NonlinearFunction{true}( - loss_function!; resid_prototype = zeros(length(y_target)), jac = loss_function_jac!) - prob_iip = NonlinearLeastSquaresProblem( - ff, θ_init, p) + loss_function!; resid_prototype = zeros(length(y_target)), + jac = loss_function_jac!) + prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p) sol = solve(prob_iip, alg) return sum(abs2, sol.u) end function obj_6(p) ff = NonlinearFunction{true}( - loss_function!; resid_prototype = zeros(length(y_target)), vjp = loss_function_vjp!) - prob_iip = NonlinearLeastSquaresProblem( - ff, θ_init, p) + loss_function!; resid_prototype = zeros(length(y_target)), + vjp = loss_function_vjp!) + prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p) sol = solve(prob_iip, alg) return sum(abs2, sol.u) end diff --git a/lib/SimpleNonlinearSolve/test/core/least_squares_tests.jl b/lib/SimpleNonlinearSolve/test/core/least_squares_tests.jl index ef3a05504..005c463ff 100644 --- a/lib/SimpleNonlinearSolve/test/core/least_squares_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/least_squares_tests.jl @@ -29,7 +29,8 @@ end prob_iip = NonlinearLeastSquaresProblem( - NonlinearFunction{true}(loss_function!, resid_prototype = zeros(length(y_target))), θ_init, x) + NonlinearFunction{true}(loss_function!, resid_prototype = zeros(length(y_target))), + θ_init, x) @testset "Solver: $(nameof(typeof(solver)))" for solver in [ SimpleNewtonRaphson(AutoForwardDiff()), SimpleGaussNewton(AutoForwardDiff()), diff --git a/lib/SimpleNonlinearSolve/test/core/matrix_resizing_tests.jl b/lib/SimpleNonlinearSolve/test/core/matrix_resizing_tests.jl index 17cd3d674..3f7dfbb40 100644 --- a/lib/SimpleNonlinearSolve/test/core/matrix_resizing_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/matrix_resizing_tests.jl @@ -5,10 +5,10 @@ vecprob = NonlinearProblem(ff, vec(u0), p) prob = NonlinearProblem(ff, u0, p) - @testset "$(nameof(typeof(alg)))" for alg in (SimpleKlement(), SimpleBroyden(), - SimpleNewtonRaphson(), SimpleDFSane(), - SimpleLimitedMemoryBroyden(; threshold = Val(2)), - SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true))) + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleKlement(), SimpleBroyden(), SimpleNewtonRaphson(), SimpleDFSane(), + SimpleLimitedMemoryBroyden(; threshold = Val(2)), SimpleTrustRegion(), + SimpleTrustRegion(; nlsolve_update_rule = Val(true))) @test vec(solve(prob, alg).u) ≈ solve(vecprob, alg).u end end diff --git a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl index ca0e26ef6..1ef0757b8 100644 --- a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl @@ -1,7 +1,7 @@ @testsetup module RootfindingTesting using Reexport -@reexport using AllocCheck, - LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase +@reexport using AllocCheck, LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, + DiffEqBase import PolyesterForwardDiff quadratic_f(u, p) = u .* u .- p @@ -14,16 +14,14 @@ function newton_fails(u, p) (0.21640425613334457 .+ 216.40425613334457 ./ (1 .+ (0.21640425613334457 .+ - 216.40425613334457 ./ - (1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ 2.0) .- - 0.0011552453009332421u .- p + 216.40425613334457 ./ (1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ + 2.0) .- 0.0011552453009332421u .- p end const TERMINATION_CONDITIONS = [ NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(), AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(), - AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode() -] + AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode()] function benchmark_nlsolve_oop(f::F, u0, p = 2.0; solver) where {F} prob = NonlinearProblem{false}(f, u0, p) @@ -40,14 +38,14 @@ export quadratic_f, quadratic_f!, quadratic_f2, newton_fails, TERMINATION_CONDIT end @testitem "First Order Methods" setup=[RootfindingTesting] tags=[:core] begin - @testset "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion, - (args...; kwargs...) -> SimpleTrustRegion(args...; nlsolve_update_rule = Val(true), - kwargs...)) + @testset "$(alg)" for alg in (SimpleNewtonRaphson, + SimpleTrustRegion, + (args...; kwargs...) -> SimpleTrustRegion( + args...; nlsolve_update_rule = Val(true), kwargs...)) @testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in ( - AutoFiniteDiff(), - AutoForwardDiff(), AutoPolyesterForwardDiff()) - @testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], - @SVector[1.0, 1.0], 1.0) + AutoFiniteDiff(), AutoForwardDiff(), AutoPolyesterForwardDiff()) + @testset "[OOP] u0: $(typeof(u0))" for u0 in ( + [1.0, 1.0], @SVector[1.0, 1.0], 1.0) u0 isa SVector && autodiff isa AutoPolyesterForwardDiff && continue sol = benchmark_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff)) @test SciMLBase.successful_retcode(sol) @@ -71,10 +69,10 @@ end end @testitem "SimpleHalley" setup=[RootfindingTesting] tags=[:core] begin - @testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in (AutoFiniteDiff(), - AutoForwardDiff()) - @testset "[OOP] u0: $(nameof(typeof(u0)))" for u0 in ([1.0, 1.0], - @SVector[1.0, 1.0], 1.0) + @testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in ( + AutoFiniteDiff(), AutoForwardDiff()) + @testset "[OOP] u0: $(nameof(typeof(u0)))" for u0 in ( + [1.0, 1.0], @SVector[1.0, 1.0], 1.0) sol = benchmark_nlsolve_oop(quadratic_f, u0; solver = SimpleHalley(; autodiff)) @test SciMLBase.successful_retcode(sol) @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) @@ -90,9 +88,9 @@ end end @testitem "Derivative Free Metods" setup=[RootfindingTesting] tags=[:core] begin - @testset "$(nameof(typeof(alg)))" for alg in [SimpleBroyden(), SimpleKlement(), - SimpleDFSane(), SimpleLimitedMemoryBroyden(), - SimpleBroyden(; linesearch = Val(true)), + @testset "$(nameof(typeof(alg)))" for alg in [ + SimpleBroyden(), SimpleKlement(), SimpleDFSane(), + SimpleLimitedMemoryBroyden(), SimpleBroyden(; linesearch = Val(true)), SimpleLimitedMemoryBroyden(; linesearch = Val(true))] @testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) sol = benchmark_nlsolve_oop(quadratic_f, u0; solver = alg) @@ -119,8 +117,9 @@ end u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0] p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - @testset "$(nameof(typeof(alg)))" for alg in (SimpleDFSane(), SimpleTrustRegion(), - SimpleHalley(), SimpleTrustRegion(; nlsolve_update_rule = Val(true))) + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleDFSane(), SimpleTrustRegion(), SimpleHalley(), + SimpleTrustRegion(; nlsolve_update_rule = Val(true))) sol = benchmark_nlsolve_oop(newton_fails, u0, p; solver = alg) @test SciMLBase.successful_retcode(sol) @test all(abs.(newton_fails(sol.u, p)) .< 1e-9) @@ -135,9 +134,10 @@ end @testitem "Allocation Checks" setup=[RootfindingTesting] tags=[:core] begin if Sys.islinux() # Very slow on other OS - @testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(), - SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleLimitedMemoryBroyden(), - SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)), + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleNewtonRaphson(), SimpleHalley(), SimpleBroyden(), + SimpleKlement(), SimpleLimitedMemoryBroyden(), SimpleTrustRegion(), + SimpleTrustRegion(; nlsolve_update_rule = Val(true)), SimpleDFSane(), SimpleBroyden(; linesearch = Val(true)), SimpleLimitedMemoryBroyden(; linesearch = Val(true))) @check_allocs nlsolve(prob, alg) = SciMLBase.solve(prob, alg; abstol = 1e-9) @@ -166,8 +166,8 @@ end end @testitem "Interval Nonlinear Problems" setup=[RootfindingTesting] tags=[:core] begin - @testset "$(nameof(typeof(alg)))" for alg in (Bisection(), Falsi(), Ridder(), Brent(), - ITP(), Alefeld()) + @testset "$(nameof(typeof(alg)))" for alg in ( + Bisection(), Falsi(), Ridder(), Brent(), ITP(), Alefeld()) tspan = (1.0, 20.0) function g(p) @@ -240,8 +240,8 @@ end end @testitem "Flipped Signs and Reversed Tspan" setup=[RootfindingTesting] tags=[:core] begin - @testset "$(nameof(typeof(alg)))" for alg in (Alefeld(), Bisection(), Falsi(), Brent(), - ITP(), Ridder()) + @testset "$(nameof(typeof(alg)))" for alg in ( + Alefeld(), Bisection(), Falsi(), Brent(), ITP(), Ridder()) f1(u, p) = u * u - p f2(u, p) = p - u * u diff --git a/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl b/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl index 39ac422a4..efc03403a 100644 --- a/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl +++ b/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl @@ -6,8 +6,9 @@ f(u, p) = u .* u .- 2 f!(du, u, p) = du .= u .* u .- 2 - @testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(), SimpleDFSane(), - SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)), + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(), + SimpleTrustRegion(; nlsolve_update_rule = Val(true)), SimpleBroyden(), SimpleLimitedMemoryBroyden(), SimpleKlement(), SimpleHalley(), SimpleBroyden(; linesearch = Val(true)), SimpleLimitedMemoryBroyden(; linesearch = Val(true))) @@ -51,10 +52,11 @@ end prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0]) - @testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(), SimpleDFSane(), - SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)), - SimpleBroyden(), SimpleLimitedMemoryBroyden(), SimpleKlement(), SimpleHalley(), - SimpleBroyden(; linesearch = Val(true)), + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(), + SimpleTrustRegion(; nlsolve_update_rule = Val(true)), + SimpleBroyden(), SimpleLimitedMemoryBroyden(), SimpleKlement(), + SimpleHalley(), SimpleBroyden(; linesearch = Val(true)), SimpleLimitedMemoryBroyden(; linesearch = Val(true))) @test begin try From efa09d8714c3e1ac39ee5b8dc89aa0281a2476b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 May 2024 14:43:04 -0700 Subject: [PATCH 3/4] Add explicit imports --- lib/SimpleNonlinearSolve/Project.toml | 4 +- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 42 +++++++++---------- .../ext/SimpleNonlinearSolveTrackerExt.jl | 2 +- .../src/SimpleNonlinearSolve.jl | 36 ++++++++++------ .../test/core/aqua_tests.jl | 9 ---- .../test/core/qa_tests.jl | 23 ++++++++++ 6 files changed, 69 insertions(+), 47 deletions(-) delete mode 100644 lib/SimpleNonlinearSolve/test/core/aqua_tests.jl create mode 100644 lib/SimpleNonlinearSolve/test/core/qa_tests.jl diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 74b6cde87..0fbdfd614 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -45,6 +45,7 @@ ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" DiffEqBase = "6.149" DiffResults = "1.1" +ExplicitImports = "1.5.0" FastClosures = "0.3.2" FiniteDiff = "2.22" ForwardDiff = "0.10.36" @@ -73,6 +74,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff", "ReverseDiff", "Tracker"] +test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "Reexport", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index a6a1c2dbf..c5f0286f1 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -5,65 +5,61 @@ using DiffEqBase: DiffEqBase using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem using SimpleNonlinearSolve: SimpleNonlinearSolve +import SimpleNonlinearSolve: __internal_solve_up -function SimpleNonlinearSolve.__internal_solve_up( +function __internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, - u0, u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) end -function SimpleNonlinearSolve.__internal_solve_up( +function __internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, - u0, u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) end -function SimpleNonlinearSolve.__internal_solve_up( +function __internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, - u0, u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) end -function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function __internal_solve_up(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) - return SimpleNonlinearSolve.__internal_solve_up( - prob, sensealg, ArrayInterface.aos_to_soa(u0), true, + return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) end -function SimpleNonlinearSolve.__internal_solve_up( +function __internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) - return SimpleNonlinearSolve.__internal_solve_up( - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), + return __internal_solve_up(prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) end -function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, +function __internal_solve_up(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...) - return SimpleNonlinearSolve.__internal_solve_up( - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), + return __internal_solve_up(prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) end -ReverseDiff.@grad function SimpleNonlinearSolve.__internal_solve_up( +ReverseDiff.@grad function __internal_solve_up( prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) out, ∇internal = DiffEqBase._solve_adjoint( prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), ReverseDiffOriginator(), alg, args...; kwargs...) - function ∇SimpleNonlinearSolve.__internal_solve_up(_args...) + function ∇__internal_solve_up(_args...) ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) end - return Array(out), ∇SimpleNonlinearSolve.__internal_solve_up + return Array(out), ∇__internal_solve_up end end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index b49bd78cc..85b84f80f 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -32,7 +32,7 @@ Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up( u0, p = Tracker.data(u0_), Tracker.data(p_) prob = remake(_prob; u0, p) out, ∇internal = DiffEqBase._solve_adjoint( - prob, sensealg, u0, p, SciMLBase.TrackerOriginator(), alg, args...; kwargs...) + prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...) function ∇__internal_solve_up(Δ) ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index dfd650c4d..1ff66b995 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -1,22 +1,31 @@ module SimpleNonlinearSolve -import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations +using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations @recompile_invalidations begin - using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff, - ForwardDiff, Reexport, LinearAlgebra, SciMLBase - - import DiffEqBase: AbstractNonlinearTerminationMode, - AbstractSafeNonlinearTerminationMode, - AbstractSafeBestNonlinearTerminationMode, NONLINEARSOLVE_DEFAULT_NORM - import DiffResults - import ForwardDiff: Dual - import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex - import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val - import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size + using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff + using ArrayInterface: ArrayInterface + using ConcreteStructs: @concrete + using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, + AbstractSafeNonlinearTerminationMode, + AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode, + NONLINEARSOLVE_DEFAULT_NORM + using DiffResults: DiffResults + using FastClosures: @closure + using FiniteDiff: FiniteDiff + using ForwardDiff: ForwardDiff, Dual + using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, + mul!, norm, transpose + using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex + using Reexport: @reexport + using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearFunction, + NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, + remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace, + _unwrap_val + using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size end -@reexport using ADTypes, SciMLBase +@reexport using SciMLBase abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end @@ -110,6 +119,7 @@ end end end +export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff export SimpleBroyden, SimpleDFSane, SimpleGaussNewton, SimpleHalley, SimpleKlement, SimpleLimitedMemoryBroyden, SimpleNewtonRaphson, SimpleTrustRegion export Alefeld, Bisection, Brent, Falsi, ITP, Ridder diff --git a/lib/SimpleNonlinearSolve/test/core/aqua_tests.jl b/lib/SimpleNonlinearSolve/test/core/aqua_tests.jl deleted file mode 100644 index 364f51b59..000000000 --- a/lib/SimpleNonlinearSolve/test/core/aqua_tests.jl +++ /dev/null @@ -1,9 +0,0 @@ -@testitem "Aqua" tags=[:core] begin - using Aqua - - Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false) - Aqua.test_piracies(SimpleNonlinearSolve; - treat_as_own = [ - NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem]) - Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false) -end diff --git a/lib/SimpleNonlinearSolve/test/core/qa_tests.jl b/lib/SimpleNonlinearSolve/test/core/qa_tests.jl new file mode 100644 index 000000000..fbdb813ee --- /dev/null +++ b/lib/SimpleNonlinearSolve/test/core/qa_tests.jl @@ -0,0 +1,23 @@ +@testitem "Aqua" tags=[:core] begin + using Aqua + + Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false) + Aqua.test_piracies(SimpleNonlinearSolve; + treat_as_own = [ + NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem]) + Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false) +end + +@testitem "Explicit Imports" tags=[:core] begin + import PolyesterForwardDiff, ReverseDiff, Tracker, StaticArrays, Zygote + + using ExplicitImports + + @test check_no_implicit_imports( + SimpleNonlinearSolve; skip = (SimpleNonlinearSolve, Base, Core, SciMLBase)) === + nothing + + @test check_no_stale_explicit_imports(SimpleNonlinearSolve) === nothing + + @test check_all_qualified_accesses_via_owners(SimpleNonlinearSolve) === nothing +end From 248088c258d39bd7adf2801ab8240f17d7295436 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 May 2024 15:01:59 -0700 Subject: [PATCH 4/4] Resolve ambiguity --- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 94 ++++++++++--------- .../ext/SimpleNonlinearSolveTrackerExt.jl | 74 ++++++++------- .../test/core/exotic_type_tests.jl | 3 +- 3 files changed, 88 insertions(+), 83 deletions(-) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index c5f0286f1..249bbbedb 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -7,59 +7,61 @@ using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresP using SimpleNonlinearSolve: SimpleNonlinearSolve import SimpleNonlinearSolve: __internal_solve_up -function __internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) -end +for pType in (NonlinearProblem, NonlinearLeastSquaresProblem) + @eval begin + function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed, + p::TrackedArray, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) + end -function __internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) -end + function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed, + p::TrackedArray, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) + end -function __internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) -end + function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, + u0_changed, p, p_changed, alg, args...; kwargs...) + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) + end -function __internal_solve_up(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, - p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) - return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true, - ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) -end + function __internal_solve_up( + prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) + return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true, + ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) + end -function __internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, - u0_changed, p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) - return __internal_solve_up(prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), - true, alg, args...; kwargs...) -end + function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed, + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) + return __internal_solve_up( + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), + true, alg, args...; kwargs...) + end -function __internal_solve_up(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0::AbstractArray{<:TrackedReal}, - u0_changed, p, p_changed, alg, args...; kwargs...) - return __internal_solve_up(prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), - true, alg, args...; kwargs...) -end + function __internal_solve_up( + prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, + u0_changed, p, p_changed, alg, args...; kwargs...) + return __internal_solve_up( + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), + true, alg, args...; kwargs...) + end -ReverseDiff.@grad function __internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) - out, ∇internal = DiffEqBase._solve_adjoint( - prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), - ReverseDiffOriginator(), alg, args...; kwargs...) - function ∇__internal_solve_up(_args...) - ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) - return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) + ReverseDiff.@grad function __internal_solve_up( + prob::$(pType), sensealg, u0, u0_changed, + p, p_changed, alg, args...; kwargs...) + out, ∇internal = DiffEqBase._solve_adjoint( + prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), + ReverseDiffOriginator(), alg, args...; kwargs...) + function ∇__internal_solve_up(_args...) + ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) + end + return Array(out), ∇__internal_solve_up + end end - return Array(out), ∇__internal_solve_up end end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index 85b84f80f..a212b220e 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -1,45 +1,49 @@ module SimpleNonlinearSolveTrackerExt using DiffEqBase: DiffEqBase -using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem +using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake using SimpleNonlinearSolve: SimpleNonlinearSolve using Tracker: Tracker, TrackedArray -function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...) - return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, - u0, u0_changed, p, p_changed, alg, args...; kwargs...) -end - -function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0::TrackedArray, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, - u0, u0_changed, p, p_changed, alg, args...; kwargs...) -end - -function SimpleNonlinearSolve.__internal_solve_up( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, - u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...) - return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, - u0, u0_changed, p, p_changed, alg, args...; kwargs...) -end - -Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up( - _prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...) - u0, p = Tracker.data(u0_), Tracker.data(p_) - prob = remake(_prob; u0, p) - out, ∇internal = DiffEqBase._solve_adjoint( - prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...) - - function ∇__internal_solve_up(Δ) - ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) - return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) +for pType in (NonlinearProblem, NonlinearLeastSquaresProblem) + @eval begin + function SimpleNonlinearSolve.__internal_solve_up( + prob::$(pType), sensealg, u0::TrackedArray, + u0_changed, p, p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) + end + + function SimpleNonlinearSolve.__internal_solve_up( + prob::$(pType), sensealg, u0::TrackedArray, u0_changed, + p::TrackedArray, p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) + end + + function SimpleNonlinearSolve.__internal_solve_up( + prob::$(pType), sensealg, u0, u0_changed, + p::TrackedArray, p_changed, alg, args...; kwargs...) + return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, + u0, u0_changed, p, p_changed, alg, args...; kwargs...) + end + + Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up( + _prob::$(pType), sensealg, u0_, u0_changed, + p_, p_changed, alg, args...; kwargs...) + u0, p = Tracker.data(u0_), Tracker.data(p_) + prob = remake(_prob; u0, p) + out, ∇internal = DiffEqBase._solve_adjoint( + prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...) + + function ∇__internal_solve_up(Δ) + ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) + end + + return out, ∇__internal_solve_up + end end - - return out, ∇__internal_solve_up end end diff --git a/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl b/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl index fff77a3d4..302d2402e 100644 --- a/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl @@ -16,8 +16,7 @@ end using SimpleNonlinearSolve, LinearAlgebra for alg in [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(), SimpleDFSane(), - SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2), - SimpleHalley()] + SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2), SimpleHalley()] sol = solve(prob_oop_bf, alg) @test norm(sol.resid, Inf) < 1e-6 @test SciMLBase.successful_retcode(sol.retcode)