diff --git a/Project.toml b/Project.toml index 4d78f7d..78f3fcc 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,6 @@ SimpleBatchedNonlinearSolveExt = "NNlib" [compat] ArrayInterface = "6, 7" -DiffEqBase = "6.114" FiniteDiff = "2" ForwardDiff = "0.10.3" NNlib = "0.8" diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt.jl index 4d5b455..07c117f 100644 --- a/ext/SimpleBatchedNonlinearSolveExt.jl +++ b/ext/SimpleBatchedNonlinearSolveExt.jl @@ -1,6 +1,7 @@ module SimpleBatchedNonlinearSolveExt -using ArrayInterface, LinearAlgebra, SimpleNonlinearSolve, SciMLBase +using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase + isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib) _batch_transpose(x) = reshape(x, 1, size(x)...) @@ -31,6 +32,8 @@ end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) @@ -47,8 +50,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; end atol = abstol !== nothing ? abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) - rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + (tc.abstol !== nothing ? tc.abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) + rtol = reltol !== nothing ? reltol : + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + error("Broyden currently doesn't support SAFE_BEST termination modes") + end + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? 
Dict() : nothing + termination_condition = tc(storage) xₙ = x xₙ₋₁ = x @@ -63,14 +75,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; (_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))), _batched_mul(_batch_transpose(Δxₙ), J⁻¹)) - iszero(fₙ) && - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) - - if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol) - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) end + xₙ₋₁ = xₙ fₙ₋₁ = fₙ end diff --git a/src/broyden.jl b/src/broyden.jl index d0ae233..e8d339c 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -1,5 +1,7 @@ """ - Broyden(; batched = false) + Broyden(; batched = false, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, reltol = nothing)) A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. @@ -9,12 +11,22 @@ and static array problems. To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or `import NNlib` must be present in your code. """ -struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm - Broyden(; batched = false) = new{batched}() +struct Broyden{batched, TC <: NLSolveTerminationCondition} <: + AbstractSimpleNonlinearSolveAlgorithm + termination_condition::TC + + function Broyden(; batched = false, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + return new{batched, typeof(termination_condition)}(termination_condition) + end end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) 
+ tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) @@ -27,8 +39,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; end atol = abstol !== nothing ? abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) - rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + (tc.abstol !== nothing ? tc.abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) + rtol = reltol !== nothing ? reltol : + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + error("Broyden currently doesn't support SAFE_BEST termination modes") + end + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing + termination_condition = tc(storage) xₙ = x xₙ₋₁ = x @@ -41,14 +62,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; J⁻¹Δfₙ = J⁻¹ * Δfₙ J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹) - iszero(fₙ) && - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) - - if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol) - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) end + xₙ₋₁ = xₙ fₙ₋₁ = fₙ end diff --git a/test/basictests.jl b/test/basictests.jl index 5f8f17a..26e92df 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,8 +1,25 @@ using SimpleNonlinearSolve using StaticArrays using BenchmarkTools +using DiffEqBase using Test +const BATCHED_BROYDEN_SOLVERS = Broyden[] +const BROYDEN_SOLVERS = Broyden[] + +for mode in instances(NLSolveTerminationMode.T) + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + + 
termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + push!(BROYDEN_SOLVERS, Broyden(; batched = false, termination_condition)) + push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition)) +end + # SimpleNewtonRaphson function benchmark_scalar(f, u0) probN = NonlinearProblem{false}(f, u0) @@ -50,16 +67,19 @@ if VERSION >= v"1.7" end # Broyden -function benchmark_scalar(f, u0) +function benchmark_scalar(f, u0, alg) probN = NonlinearProblem{false}(f, u0) - sol = (solve(probN, Broyden())) + sol = (solve(probN, alg)) end -sol = benchmark_scalar(sf, csu0) -@test sol.retcode === ReturnCode.Success -@test sol.u * sol.u - 2 < 1e-9 -if VERSION >= v"1.7" - @test (@ballocated benchmark_scalar(sf, csu0)) == 0 +for alg in BROYDEN_SOLVERS + sol = benchmark_scalar(sf, csu0, alg) + @test sol.retcode === ReturnCode.Success + @test sol.u * sol.u - 2 < 1e-9 + # FIXME: The termination condition implementation allocates (no known fix yet), so the zero-allocation check is disabled. + # if VERSION >= v"1.7" + # @test (@ballocated benchmark_scalar($sf, $csu0, $alg)) == 0 + # end end # Klement @@ -101,8 +121,8 @@ using ForwardDiff # Immutable f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0] -for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(), - SimpleDFSane()) +for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(), + SimpleDFSane(), BROYDEN_SOLVERS...) g = function (p) probN = NonlinearProblem{false}(f, csu0, p) sol = solve(probN, alg, abstol = 1e-9) @@ -117,8 +137,8 @@ end # Scalar f, u0 = (u, p) -> u * u - p, 1.0 -for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(), - SimpleDFSane(), Halley()) +for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(), + SimpleDFSane(), Halley(), BROYDEN_SOLVERS...) 
g = function (p) probN = NonlinearProblem{false}(f, oftype(p, u0), p) sol = solve(probN, alg) @@ -183,8 +203,8 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()] @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p) end -for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(), - SimpleDFSane(), Halley()) +for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(), + SimpleDFSane(), Halley(), BROYDEN_SOLVERS...) global g, p g = function (p) probN = NonlinearProblem{false}(f, 0.5, p) @@ -199,14 +219,15 @@ end f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] probN = NonlinearProblem(f, u0) -@test solve(probN, SimpleNewtonRaphson()).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleTrustRegion()).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] ≈ sqrt(2.0) -@test solve(probN, Broyden()).u[end] ≈ sqrt(2.0) -@test solve(probN, LBroyden()).u[end] ≈ sqrt(2.0) -@test solve(probN, Klement()).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleDFSane()).u[end] ≈ sqrt(2.0) +for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false), + SimpleTrustRegion(), + SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(), + BROYDEN_SOLVERS...) 
+ sol = solve(probN, alg) + + @test sol.retcode == ReturnCode.Success + @test sol.u[end] ≈ sqrt(2.0) +end # Separate Error check for Halley; will be included in above error checks for the improved Halley f, u0 = (u, p) -> u * u - 2.0, 1.0 @@ -220,18 +241,16 @@ for u0 in [1.0, [1, 1.0]] probN = NonlinearProblem(f, u0) sol = sqrt(2) * u0 - @test solve(probN, SimpleNewtonRaphson()).u ≈ sol - @test solve(probN, SimpleNewtonRaphson()).u ≈ sol - @test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u ≈ sol - - @test solve(probN, SimpleTrustRegion()).u ≈ sol - @test solve(probN, SimpleTrustRegion()).u ≈ sol - @test solve(probN, SimpleTrustRegion(; autodiff = false)).u ≈ sol + for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false), + SimpleTrustRegion(), + SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), + SimpleDFSane(), + BROYDEN_SOLVERS...) + sol2 = solve(probN, alg) - @test solve(probN, Broyden()).u ≈ sol - @test solve(probN, LBroyden()).u ≈ sol - @test solve(probN, Klement()).u ≈ sol - @test solve(probN, SimpleDFSane()).u ≈ sol + @test sol2.retcode == ReturnCode.Success + @test sol2.u ≈ sol + end end # Bisection Tests @@ -411,3 +430,10 @@ probN = NonlinearProblem{false}(f, u0, p); sol = solve(probN, Broyden(batched = true)) @test abs.(sol.u) ≈ sqrt.(p) + +for alg in BATCHED_BROYDEN_SOLVERS + sol = solve(probN, alg) + + @test sol.retcode == ReturnCode.Success + @test abs.(sol.u) ≈ sqrt.(p) +end