From 2eacfa82159ded06e66782c953be1833c772d49d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Feb 2023 17:45:59 -0500 Subject: [PATCH 1/3] Add Termination Conditions to Broyden --- Project.toml | 6 +- ext/SimpleBatchedNonlinearSolveExt.jl | 27 +++++--- src/broyden.jl | 41 ++++++++---- test/basictests.jl | 90 +++++++++++++++++---------- 4 files changed, 107 insertions(+), 57 deletions(-) diff --git a/Project.toml b/Project.toml index aac1c54..0368644 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.11" +version = "0.1.12" [deps] ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" @@ -23,7 +23,7 @@ SimpleBatchedNonlinearSolveExt = "NNlib" [compat] ArrayInterfaceCore = "0.1.1" -DiffEqBase = "6.114" +DiffEqBase = "6.118.1" FiniteDiff = "2" ForwardDiff = "0.10.3" NNlib = "0.8" @@ -43,4 +43,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"] +test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib", "DiffEqBase"] diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt.jl index d599704..76c81ef 100644 --- a/ext/SimpleBatchedNonlinearSolveExt.jl +++ b/ext/SimpleBatchedNonlinearSolveExt.jl @@ -1,6 +1,6 @@ module SimpleBatchedNonlinearSolveExt -using ArrayInterfaceCore, LinearAlgebra, SimpleNonlinearSolve, SciMLBase +using ArrayInterfaceCore, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib) _batch_transpose(x) = reshape(x, 1, size(x)...) @@ -31,6 +31,8 @@ end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) @@ -47,8 +49,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; end atol = abstol !== nothing ? abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) - rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + (tc.abstol !== nothing ? tc.abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) + rtol = reltol !== nothing ? reltol : + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + error("Broyden currently doesn't support SAFE_BEST termination modes") + end + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing + termination_condition = tc(storage) xₙ = x xₙ₋₁ = x @@ -63,14 +74,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; (_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))), _batched_mul(_batch_transpose(Δxₙ), J⁻¹)) - iszero(fₙ) && - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) - - if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol) - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) end + xₙ₋₁ = xₙ fₙ₋₁ = fₙ end diff --git a/src/broyden.jl b/src/broyden.jl index d0ae233..e8d339c 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -1,5 +1,7 @@ """ - Broyden(; batched = false) + Broyden(; batched = false, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, reltol = nothing)) A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. @@ -9,12 +11,22 @@ and static array problems. To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or `import NNlib` must be present in your code. """ -struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm - Broyden(; batched = false) = new{batched}() +struct Broyden{batched, TC <: NLSolveTerminationCondition} <: + AbstractSimpleNonlinearSolveAlgorithm + termination_condition::TC + + function Broyden(; batched = false, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + return new{batched, typeof(termination_condition)}(termination_condition) + end end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) @@ -27,8 +39,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; end atol = abstol !== nothing ? abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) - rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + (tc.abstol !== nothing ? tc.abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) + rtol = reltol !== nothing ? reltol : + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + error("Broyden currently doesn't support SAFE_BEST termination modes") + end + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing + termination_condition = tc(storage) xₙ = x xₙ₋₁ = x @@ -41,14 +62,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; J⁻¹Δfₙ = J⁻¹ * Δfₙ J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹) - iszero(fₙ) && - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) - - if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol) - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; - retcode = ReturnCode.Success) + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) end + xₙ₋₁ = xₙ fₙ₋₁ = fₙ end diff --git a/test/basictests.jl b/test/basictests.jl index 12525d8..8196212 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,8 +1,25 @@ using SimpleNonlinearSolve using StaticArrays using BenchmarkTools +using DiffEqBase using Test +const BATCHED_BROYDEN_SOLVERS = Broyden[] +const BROYDEN_SOLVERS = Broyden[] + +for mode in instances(NLSolveTerminationMode.T) + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + push!(BROYDEN_SOLVERS, Broyden(; batched = false, termination_condition)) + push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition)) +end + # SimpleNewtonRaphson function benchmark_scalar(f, u0) probN = NonlinearProblem{false}(f, u0) @@ -27,16 +44,19 @@ if VERSION >= v"1.7" end # Broyden -function benchmark_scalar(f, u0) +function benchmark_scalar(f, u0, alg) probN = NonlinearProblem{false}(f, u0) - sol = (solve(probN, Broyden())) + sol = (solve(probN, alg)) end -sol = benchmark_scalar(sf, csu0) -@test sol.retcode === ReturnCode.Success -@test sol.u * sol.u - 2 < 1e-9 -if VERSION >= v"1.7" - @test (@ballocated benchmark_scalar(sf, csu0)) == 0 +for alg in BROYDEN_SOLVERS + sol = benchmark_scalar(sf, csu0, alg) + @test sol.retcode === ReturnCode.Success + @test sol.u * sol.u - 2 < 1e-9 + # FIXME: Termination Condition Implementation is allocating. Not sure how to fix it. + # if VERSION >= v"1.7" + # @test (@ballocated benchmark_scalar($sf, $csu0, $termination_condition)) == 0 + # end end # Klement @@ -78,8 +98,8 @@ using ForwardDiff # Immutable f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0] -for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(), - SimpleDFSane()) +for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(), + SimpleDFSane(), BROYDEN_SOLVERS...) g = function (p) probN = NonlinearProblem{false}(f, csu0, p) sol = solve(probN, alg, abstol = 1e-9) @@ -94,8 +114,8 @@ end # Scalar f, u0 = (u, p) -> u * u - p, 1.0 -for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(), - SimpleDFSane()) +for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(), + SimpleDFSane(), BROYDEN_SOLVERS...) g = function (p) probN = NonlinearProblem{false}(f, oftype(p, u0), p) sol = solve(probN, alg) @@ -160,8 +180,8 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()] @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p) end -for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(), - SimpleDFSane()) +for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(), + SimpleDFSane(), BROYDEN_SOLVERS...) global g, p g = function (p) probN = NonlinearProblem{false}(f, 0.5, p) @@ -176,14 +196,15 @@ end f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] probN = NonlinearProblem(f, u0) -@test solve(probN, SimpleNewtonRaphson()).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleTrustRegion()).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] ≈ sqrt(2.0) -@test solve(probN, Broyden()).u[end] ≈ sqrt(2.0) -@test solve(probN, LBroyden()).u[end] ≈ sqrt(2.0) -@test solve(probN, Klement()).u[end] ≈ sqrt(2.0) -@test solve(probN, SimpleDFSane()).u[end] ≈ sqrt(2.0) +for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false), + SimpleTrustRegion(), + SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(), + BROYDEN_SOLVERS...) + sol = solve(probN, alg) + + @test sol.retcode == ReturnCode.Success + @test sol.u[end] ≈ sqrt(2.0) +end for u0 in [1.0, [1, 1.0]] local f, probN, sol @@ -191,18 +212,16 @@ for u0 in [1.0, [1, 1.0]] probN = NonlinearProblem(f, u0) sol = sqrt(2) * u0 - @test solve(probN, SimpleNewtonRaphson()).u ≈ sol - @test solve(probN, SimpleNewtonRaphson()).u ≈ sol - @test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u ≈ sol - - @test solve(probN, SimpleTrustRegion()).u ≈ sol - @test solve(probN, SimpleTrustRegion()).u ≈ sol - @test solve(probN, SimpleTrustRegion(; autodiff = false)).u ≈ sol + for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false), + SimpleTrustRegion(), + SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), + SimpleDFSane(), + BROYDEN_SOLVERS...) + sol2 = solve(probN, alg) - @test solve(probN, Broyden()).u ≈ sol - @test solve(probN, LBroyden()).u ≈ sol - @test solve(probN, Klement()).u ≈ sol - @test solve(probN, SimpleDFSane()).u ≈ sol + @test sol2.retcode == ReturnCode.Success + @test sol2.u ≈ sol + end end # Bisection Tests @@ -382,3 +401,10 @@ probN = NonlinearProblem{false}(f, u0, p); sol = solve(probN, Broyden(batched = true)) @test abs.(sol.u) ≈ sqrt.(p) + +for alg in BATCHED_BROYDEN_SOLVERS + sol = solve(probN, alg) + + @test sol.retcode == ReturnCode.Success + @test abs.(sol.u) ≈ sqrt.(p) +end From 3a0d9b459a29ded8169e88b59925de2a477ecf68 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 19 Feb 2023 10:01:24 -0500 Subject: [PATCH 2/3] Update Project.toml --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9ad93b1..796e541 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,6 @@ SimpleBatchedNonlinearSolveExt = "NNlib" [compat] ArrayInterface = "6, 7" -ArrayInterfaceCore = "0.1.1" FiniteDiff = "2" ForwardDiff = "0.10.3" NNlib = "0.8" From 5a3db2bcac8be7fa7ebfe8d708807316900e0b09 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 19 Feb 2023 10:10:37 -0500 Subject: [PATCH 3/3] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 796e541..78f3fcc 100644 --- a/Project.toml +++ b/Project.toml @@ -42,4 +42,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib", "DiffEqBase"] +test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]