From 89855a444c48196c53c73913304436f0d4ceffbf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Dec 2023 15:12:28 -0500 Subject: [PATCH 1/5] Add wrapper for SpeedMapping --- Project.toml | 12 ++++- docs/pages.jl | 1 + docs/src/api/fixedpointacceleration.md | 17 +++++++ docs/src/api/speedmapping.md | 17 +++++++ docs/src/solvers/FixedPointSolvers.md | 42 ++++++++++++++++ docs/src/solvers/NonlinearSystemSolvers.md | 2 +- ...NonlinearSolveFixedPointAccelerationExt.jl | 3 ++ ext/NonlinearSolveSpeedMappingExt.jl | 49 +++++++++++++++++++ src/NonlinearSolve.jl | 3 +- src/extension_algs.jl | 46 ++++++++++++++++- test/fixed_point_acceleration.jl | 0 test/runtests.jl | 2 + test/speedmapping.jl | 27 ++++++++++ 13 files changed, 215 insertions(+), 6 deletions(-) create mode 100644 docs/src/api/fixedpointacceleration.md create mode 100644 docs/src/api/speedmapping.md create mode 100644 docs/src/solvers/FixedPointSolvers.md create mode 100644 ext/NonlinearSolveFixedPointAccelerationExt.jl create mode 100644 ext/NonlinearSolveSpeedMappingExt.jl create mode 100644 test/fixed_point_acceleration.jl create mode 100644 test/speedmapping.jl diff --git a/Project.toml b/Project.toml index 6a1f5835c..00c636d0e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.1.2" +version = "3.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -32,18 +32,22 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" +FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] NonlinearSolveBandedMatricesExt = "BandedMatrices" NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" +NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration" NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim" NonlinearSolveMINPACKExt = "MINPACK" NonlinearSolveNLsolveExt = "NLsolve" +NonlinearSolveSpeedMappingExt = "SpeedMapping" NonlinearSolveSymbolicsExt = "Symbolics" NonlinearSolveZygoteExt = "Zygote" @@ -60,6 +64,7 @@ Enzyme = "0.11.11" FastBroadcast = "0.2.8" FastLevenbergMarquardt = "0.1" FiniteDiff = "2.21" +FixedPointAcceleration = "0.3" ForwardDiff = "0.10.36" LazyArrays = "1.8.2" LeastSquaresOptim = "0.8.5" @@ -84,6 +89,7 @@ SciMLOperators = "0.3.7" SimpleNonlinearSolve = "1.0.2" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" +SpeedMapping = "0.3" StableRNGs = "1" StaticArrays = "1.7" Symbolics = "5.13" @@ -99,6 +105,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" +FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -112,6 +119,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -119,4 +127,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq"] +test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration"] diff --git a/docs/pages.jl b/docs/pages.jl index a8107bea2..d17ea58a6 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -20,6 +20,7 @@ pages = ["index.md", "solvers/BracketingSolvers.md", "solvers/SteadyStateSolvers.md", "solvers/NonlinearLeastSquaresSolvers.md", + "solvers/FixedPointSolvers.md", "solvers/LineSearch.md"], "Detailed Solver APIs" => Any["api/nonlinearsolve.md", "api/simplenonlinearsolve.md", diff --git a/docs/src/api/fixedpointacceleration.md b/docs/src/api/fixedpointacceleration.md new file mode 100644 index 000000000..38e6ae0b5 --- /dev/null +++ b/docs/src/api/fixedpointacceleration.md @@ -0,0 +1,17 @@ +# FixedPointAcceleration.jl + +This is a extension for importing solvers from FixedPointAcceleration.jl into the SciML +interface. Note that these solvers do not come by default, and thus one needs to install +the package before using these solvers: + +```julia +using Pkg +Pkg.add("FixedPointAcceleration") +using FixedPointAcceleration, NonlinearSolve +``` + +## Solver API + +```@docs +FixedPointAccelerationJL +``` diff --git a/docs/src/api/speedmapping.md b/docs/src/api/speedmapping.md new file mode 100644 index 000000000..05a1931fe --- /dev/null +++ b/docs/src/api/speedmapping.md @@ -0,0 +1,17 @@ +# SppedMapping.jl + +This is a extension for importing solvers from SppedMapping.jl into the SciML +interface. Note that these solvers do not come by default, and thus one needs to install +the package before using these solvers: + +```julia +using Pkg +Pkg.add("SppedMapping") +using SppedMapping, NonlinearSolve +``` + +## Solver API + +```@docs +SppedMappingJL +``` diff --git a/docs/src/solvers/FixedPointSolvers.md b/docs/src/solvers/FixedPointSolvers.md new file mode 100644 index 000000000..b94dd3a32 --- /dev/null +++ b/docs/src/solvers/FixedPointSolvers.md @@ -0,0 +1,42 @@ +# Fixed Point Solvers + +Currently we don't have an API to directly specify Fixed Point Solvers. However, a Fixed +Point Problem can be triviall converted to a Root Finding Problem. Say we want to solve: + +```math +f(u) = u +``` + +This can be written as: + +```math +g(u) = f(u) - u = 0 +``` + +Where ``g(u) = 0`` is a root finding problem. Note that we can use any root finding +algorithm to solve this problem. However, this is often not the most efficient way to +solve a fixed point problem. We provide a few algorithms available via extensions that +are more efficient for fixed point problems. + +Note that even if you use one of the Fixed Point Solvers mentioned here, you must still +use the `NonlinearProblem` API to specify the problem, i.e., ``g(u) = 0``. + +## Recommended Methods + +Using [native NonlinearSolve.jl methods](@ref nonlinearsystemsolvers) is the recommended +approach. For systems where constructing Jacobian Matrices are expensive, we recommend +using a Krylov Method with one of those solvers. + +## Full List of Methods + +We are only listing the methods that natively solve fixed point problems. + +### SpeedMapping.jl + + - `SpeedMappingJL()`: accelerates the convergence of a mapping to a fixed point by the + Alternating cyclic extrapolation algorithm (ACX). + +### FixedPointAcceleration.jl + + - `FixedPointAccelerationJL()`: accelerates the convergence of a mapping to a fixed point + by the Anderson acceleration algorithm and a few other methods. diff --git a/docs/src/solvers/NonlinearSystemSolvers.md b/docs/src/solvers/NonlinearSystemSolvers.md index 776e58b9c..bba941d86 100644 --- a/docs/src/solvers/NonlinearSystemSolvers.md +++ b/docs/src/solvers/NonlinearSystemSolvers.md @@ -1,6 +1,6 @@ # [Nonlinear System Solvers](@id nonlinearsystemsolvers) -`solve(prob::NonlinearProblem,alg;kwargs)` +`solve(prob::NonlinearProblem, alg; kwargs)` Solves for ``f(u)=0`` in the problem defined by `prob` using the algorithm `alg`. If no algorithm is given, a default algorithm will be chosen. diff --git a/ext/NonlinearSolveFixedPointAccelerationExt.jl b/ext/NonlinearSolveFixedPointAccelerationExt.jl new file mode 100644 index 000000000..4c0021712 --- /dev/null +++ b/ext/NonlinearSolveFixedPointAccelerationExt.jl @@ -0,0 +1,3 @@ +module NonlinearSolveFixedPointAccelerationExt + +end \ No newline at end of file diff --git a/ext/NonlinearSolveSpeedMappingExt.jl b/ext/NonlinearSolveSpeedMappingExt.jl new file mode 100644 index 000000000..444a7e381 --- /dev/null +++ b/ext/NonlinearSolveSpeedMappingExt.jl @@ -0,0 +1,49 @@ +module NonlinearSolveSpeedMappingExt + +using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase +import UnPack: @unpack + +function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; + abstol = nothing, maxiters = 1000, alias_u0::Bool = false, + store_trace::Val{store_info} = Val(false), termination_condition = nothing, + kwargs...) where {store_info} + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!" + + if typeof(prob.u0) <: Number + u0 = [prob.u0] + else + u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + end + + T = eltype(u0) + iip = isinplace(prob) + p = prob.p + + if prob.u0 isa Number + resid = [NonlinearSolve.evaluate_f(prob, first(u0))] + else + resid = NonlinearSolve.evaluate_f(prob, u0) + end + + if !iip && prob.u0 isa Number + m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u)) + elseif !iip + m! = (du, u) -> (du .= prob.f(u, p) .+ u) + else + m! = (du, u) -> (prob.f(du, u, p); du .+= u) + end + + tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + + sol = speedmapping(u0; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders, + alg.check_obj, store_info, alg.σ_min, alg.stabilize) + res = prob.u0 isa Number ? first(sol.minimizer) : sol.minimizer + resid = NonlinearSolve.evaluate_f(prob, sol.minimizer) + + return SciMLBase.build_solution(prob, alg, res, resid; + retcode = sol.converged ? ReturnCode.Success : ReturnCode.Failure, + stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol) +end + +end \ No newline at end of file diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index dd0a6cc33..4c602991c 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -236,7 +236,8 @@ export RadiusUpdateSchemes export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient, Broyden, Klement, LimitedMemoryBroyden -export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL +export LeastSquaresOptimJL, + FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg diff --git a/src/extension_algs.jl b/src/extension_algs.jl index b06414f0f..a076f4173 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -19,7 +19,7 @@ for solving `NonlinearLeastSquaresProblem`. This algorithm is only available if `LeastSquaresOptim.jl` is installed. """ -struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm +struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearAlgorithm autodiff::Symbol end @@ -58,7 +58,7 @@ for solving `NonlinearLeastSquaresProblem`. This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ -@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm +@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearAlgorithm autodiff factor factoraccept @@ -206,3 +206,45 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) end + +""" + SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false, + orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000) + +Wrapper over [SpeedMapping.jl](https://nicolasl-s.github.io/SpeedMapping.jl) for solving +Fixed Point Problems. We allow using this algorithm to solve root finding problems as well. + +## Arguments: + + - `σ_min`: Setting to `1` may avoid stalling (see paper). + - `stabilize`: performs a stabilization mapping before extrapolating. Setting to `true` + may improve the performance for applications like accelerating the EM or MM algorithms + (see paper). + - `check_obj`: In case of NaN or Inf values, the algorithm restarts at the best past + iterate. + - `orders`: determines ACX's alternating order. Must be between `1` and `3` (where `1` + means no extrapolation). The two recommended orders are `[3, 2]` and `[3, 3, 2]`, the + latter being potentially better for highly non-linear applications (see paper). + - `time_limit`: time limit for the algorithm. + +## References: + + - N. Lepage-Saucier, Alternating cyclic extrapolation methods for optimization algorithms, + arXiv:2104.04974 (2021). https://arxiv.org/abs/2104.04974. +""" +@concrete struct SpeedMappingJL <: AbstractNonlinearAlgorithm + σ_min + stabilize::Bool + check_obj::Bool + orders::Vector{Int} + time_limit +end + +function SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false, + orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000) + if Base.get_extension(@__MODULE__, :NonlinearSolveSpeedMappingExt) === nothing + error("SpeedMappingJL requires SpeedMapping.jl to be loaded") + end + + return SpeedMappingJL(σ_min, stabilize, check_obj, orders, time_limit) +end diff --git a/test/fixed_point_acceleration.jl b/test/fixed_point_acceleration.jl new file mode 100644 index 000000000..e69de29bb diff --git a/test/runtests.jl b/test/runtests.jl index 2e74e905c..f761b08ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,8 @@ end if GROUP == "All" || GROUP == "Wrappers" @time @safetestset "MINPACK" include("minpack.jl") @time @safetestset "NLsolve" include("nlsolve.jl") + @time @safetestset "SpeedMapping" include("speedmapping.jl") + @time @safetestset "FixedPointAcceleration" include("fixed_point_acceleration.jl") end if GROUP == "All" || GROUP == "23TestProblems" diff --git a/test/speedmapping.jl b/test/speedmapping.jl new file mode 100644 index 000000000..2794e4a29 --- /dev/null +++ b/test/speedmapping.jl @@ -0,0 +1,27 @@ +using NonlinearSolve, SpeedMapping, LinearAlgebra, Test + +# Fixed Point for Power Method +# Taken from https://github.com/NicolasL-S/SpeedMapping.jl/blob/95951db8f8a4457093090e18802ad382db1c76da/test/runtests.jl +@testset "Power Method" begin + C = [1 2 3; 4 5 6; 7 8 9] + A = C + C' + B = Hermitian(ones(10) * ones(10)' .* im + Diagonal(1:10)) + + function power_method!(du, u, A) + mul!(du, A, u) + du ./= norm(du, Inf) + du .-= u # Convert to a root finding problem + return nothing + end + + prob = NonlinearProblem(power_method!, ones(3), A) + + sol = solve(prob, SpeedMappingJL()) + @test sol.u' * A[:, 3] ≈ 32.916472867168096 + + sol = solve(prob, SpeedMappingJL(; orders = [3, 2])) + @test sol.u' * A[:, 3] ≈ 32.916472867168096 + + sol = solve(prob, SpeedMappingJL(; stabilize = true)) + @test sol.u' * A[:, 3] ≈ 32.91647286145264 +end From b7ba2c7eb4530a1f093dd0e81d57986806a7453e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Dec 2023 15:22:23 -0500 Subject: [PATCH 2/5] Update docs/src/solvers/FixedPointSolvers.md Co-authored-by: Christopher Rackauckas --- docs/src/solvers/FixedPointSolvers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/solvers/FixedPointSolvers.md b/docs/src/solvers/FixedPointSolvers.md index b94dd3a32..e247e7c15 100644 --- a/docs/src/solvers/FixedPointSolvers.md +++ b/docs/src/solvers/FixedPointSolvers.md @@ -1,7 +1,7 @@ # Fixed Point Solvers Currently we don't have an API to directly specify Fixed Point Solvers. However, a Fixed -Point Problem can be triviall converted to a Root Finding Problem. Say we want to solve: +Point Problem can be trivially converted to a Root Finding Problem. Say we want to solve: ```math f(u) = u From 6604ea6f98b1a371263d5d73a04d764115beaa29 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Dec 2023 15:23:34 -0500 Subject: [PATCH 3/5] Fix typo --- docs/src/api/speedmapping.md | 10 +++++----- docs/src/solvers/FixedPointSolvers.md | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/src/api/speedmapping.md b/docs/src/api/speedmapping.md index 05a1931fe..76b13ba18 100644 --- a/docs/src/api/speedmapping.md +++ b/docs/src/api/speedmapping.md @@ -1,17 +1,17 @@ -# SppedMapping.jl +# SpeedMapping.jl -This is a extension for importing solvers from SppedMapping.jl into the SciML +This is a extension for importing solvers from SpeedMapping.jl into the SciML interface. Note that these solvers do not come by default, and thus one needs to install the package before using these solvers: ```julia using Pkg -Pkg.add("SppedMapping") -using SppedMapping, NonlinearSolve +Pkg.add("SpeedMapping") +using SpeedMapping, NonlinearSolve ``` ## Solver API ```@docs -SppedMappingJL +SpeedMappingJL ``` diff --git a/docs/src/solvers/FixedPointSolvers.md b/docs/src/solvers/FixedPointSolvers.md index e247e7c15..4eb3fdf0e 100644 --- a/docs/src/solvers/FixedPointSolvers.md +++ b/docs/src/solvers/FixedPointSolvers.md @@ -13,7 +13,7 @@ This can be written as: g(u) = f(u) - u = 0 ``` -Where ``g(u) = 0`` is a root finding problem. Note that we can use any root finding +``g(u) = 0`` is a root finding problem. Note that we can use any root finding algorithm to solve this problem. However, this is often not the most efficient way to solve a fixed point problem. We provide a few algorithms available via extensions that are more efficient for fixed point problems. From 831dacc9bd908180d2354318b71905c2eba3939e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Dec 2023 15:26:18 -0500 Subject: [PATCH 4/5] Test non-vector inputs --- test/speedmapping.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/speedmapping.jl b/test/speedmapping.jl index 2794e4a29..97bfed4ef 100644 --- a/test/speedmapping.jl +++ b/test/speedmapping.jl @@ -24,4 +24,23 @@ using NonlinearSolve, SpeedMapping, LinearAlgebra, Test sol = solve(prob, SpeedMappingJL(; stabilize = true)) @test sol.u' * A[:, 3] ≈ 32.91647286145264 + + # Non vector inputs + function power_method_nonvec!(du, u, A) + mul!(vec(du), A, vec(u)) + du ./= norm(du, Inf) + du .-= u # Convert to a root finding problem + return nothing + end + + prob = NonlinearProblem(power_method_nonvec!, ones(1, 3, 1), A) + + sol = solve(prob, SpeedMappingJL()) + @test vec(sol.u)' * A[:, 3] ≈ 32.916472867168096 + + sol = solve(prob, SpeedMappingJL(; orders = [3, 2])) + @test vec(sol.u)' * A[:, 3] ≈ 32.916472867168096 + + sol = solve(prob, SpeedMappingJL(; stabilize = true)) + @test vec(sol.u)' * A[:, 3] ≈ 32.91647286145264 end From 29917d8388c0e4b86c87afb129b25697dc82bc70 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Dec 2023 16:32:53 -0500 Subject: [PATCH 5/5] Add wrapper for FixedPointAcceleration --- ...NonlinearSolveFixedPointAccelerationExt.jl | 55 +++++++++++- ext/NonlinearSolveMINPACKExt.jl | 13 +-- ext/NonlinearSolveNLsolveExt.jl | 16 ++-- ext/NonlinearSolveSpeedMappingExt.jl | 9 +- src/extension_algs.jl | 84 +++++++++++++++++-- test/fixed_point_acceleration.jl | 68 +++++++++++++++ 6 files changed, 220 insertions(+), 25 deletions(-) diff --git a/ext/NonlinearSolveFixedPointAccelerationExt.jl b/ext/NonlinearSolveFixedPointAccelerationExt.jl index 4c0021712..e39946652 100644 --- a/ext/NonlinearSolveFixedPointAccelerationExt.jl +++ b/ext/NonlinearSolveFixedPointAccelerationExt.jl @@ -1,3 +1,56 @@ module NonlinearSolveFixedPointAccelerationExt -end \ No newline at end of file +using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase + +function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...; + abstol = nothing, maxiters = 1000, alias_u0::Bool = false, + show_trace::Val{PrintReports} = Val(false), termination_condition = nothing, + kwargs...) where {PrintReports} + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!" + + u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + u_size = size(u0) + T = eltype(u0) + iip = isinplace(prob) + p = prob.p + + if !iip && prob.u0 isa Number + # FixedPointAcceleration makes the scalar problem into a vector problem + f = (u) -> [prob.f(u[1], p) .+ u[1]] + elseif !iip && prob.u0 isa AbstractVector + f = (u) -> (prob.f(u, p) .+ u) + elseif !iip && prob.u0 isa AbstractArray + f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u) + elseif iip && prob.u0 isa AbstractVector + du = similar(u0) + f = (u) -> (prob.f(du, u, p); du .+ u) + else + du = similar(u0) + f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u) + end + + tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + + sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm, + ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m, + ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening, + PrintReports, ReplaceInvalids = alg.replace_invalids, + ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true) + + res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_ + if res === missing + resid = NonlinearSolve.evaluate_f(prob, u0) + res = u0 + converged = false + else + resid = NonlinearSolve.evaluate_f(prob, res) + converged = maximum(abs, resid) ≤ tol + end + return SciMLBase.build_solution(prob, alg, res, resid; + retcode = converged ? ReturnCode.Success : ReturnCode.Failure, + stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_), + original = sol) +end + +end diff --git a/ext/NonlinearSolveMINPACKExt.jl b/ext/NonlinearSolveMINPACKExt.jl index b86d78199..b6f051ade 100644 --- a/ext/NonlinearSolveMINPACKExt.jl +++ b/ext/NonlinearSolveMINPACKExt.jl @@ -5,7 +5,7 @@ using MINPACK function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...; - abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false, + abstol = nothing, maxiters = 100000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) where {uType, iip} @assert (termination_condition === nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!" @@ -16,6 +16,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end + T = eltype(u0) sizeu = size(prob.u0) p = prob.p @@ -25,11 +26,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, if !iip && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector f! = (du, u) -> prob.f(du, u, p) else # Then it's an in-place function on an abstract array f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0) @@ -43,14 +44,16 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, method = ifelse(alg.method === :auto, ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method) + abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + if SciMLBase.has_jac(prob.f) if !iip && prob.u0 isa Number g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector g! = (du, u) -> prob.f.jac(du, u, p) else # Then it's an in-place function on an abstract array g! = function (du, u) diff --git a/ext/NonlinearSolveNLsolveExt.jl b/ext/NonlinearSolveNLsolveExt.jl index 1b8d7e3f1..4c5223540 100644 --- a/ext/NonlinearSolveNLsolveExt.jl +++ b/ext/NonlinearSolveNLsolveExt.jl @@ -3,8 +3,9 @@ module NonlinearSolveNLsolveExt using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase import UnPack: @unpack -function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6, - maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) +function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; + abstol = nothing, maxiters = 1000, alias_u0::Bool = false, + termination_condition = nothing, kwargs...) @assert (termination_condition === nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!" @@ -14,6 +15,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end + T = eltype(u0) iip = isinplace(prob) sizeu = size(prob.u0) @@ -25,11 +27,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst if !iip && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector f! = (du, u) -> prob.f(du, u, p) else # Then it's an in-place function on an abstract array f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0) @@ -46,11 +48,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst if SciMLBase.has_jac(prob.f) if !iip && prob.u0 isa Number g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector g! = (du, u) -> prob.f.jac(du, u, p) else # Then it's an in-place function on an abstract array g! = function (du, u) @@ -68,6 +70,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff) end + abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method, store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) diff --git a/ext/NonlinearSolveSpeedMappingExt.jl b/ext/NonlinearSolveSpeedMappingExt.jl index 444a7e381..be637b124 100644 --- a/ext/NonlinearSolveSpeedMappingExt.jl +++ b/ext/NonlinearSolveSpeedMappingExt.jl @@ -1,7 +1,6 @@ module NonlinearSolveSpeedMappingExt using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase -import UnPack: @unpack function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; abstol = nothing, maxiters = 1000, alias_u0::Bool = false, @@ -20,12 +19,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; iip = isinplace(prob) p = prob.p - if prob.u0 isa Number - resid = [NonlinearSolve.evaluate_f(prob, first(u0))] - else - resid = NonlinearSolve.evaluate_f(prob, u0) - end - if !iip && prob.u0 isa Number m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u)) elseif !iip @@ -46,4 +39,4 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol) end -end \ No newline at end of file +end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index a076f4173..4fcad43ad 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -19,7 +19,7 @@ for solving `NonlinearLeastSquaresProblem`. This algorithm is only available if `LeastSquaresOptim.jl` is installed. """ -struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearAlgorithm +struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm autodiff::Symbol end @@ -58,7 +58,7 @@ for solving `NonlinearLeastSquaresProblem`. This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ -@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearAlgorithm +@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm autodiff factor factoraccept @@ -128,7 +128,7 @@ then the following methods are allowed: The default choice of `:auto` selects `:hybr` for NonlinearProblem and `:lm` for NonlinearLeastSquaresProblem. """ -struct CMINPACK <: AbstractNonlinearAlgorithm +struct CMINPACK <: AbstractNonlinearSolveAlgorithm show_trace::Bool tracing::Bool method::Symbol @@ -181,7 +181,7 @@ Choices for methods in `NLsolveJL`: these arguments, consult the [NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl). """ -@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm +@concrete struct NLsolveJL <: AbstractNonlinearSolveAlgorithm method::Symbol autodiff::Symbol store_trace::Bool @@ -232,7 +232,7 @@ Fixed Point Problems. We allow using this algorithm to solve root finding proble - N. Lepage-Saucier, Alternating cyclic extrapolation methods for optimization algorithms, arXiv:2104.04974 (2021). https://arxiv.org/abs/2104.04974. """ -@concrete struct SpeedMappingJL <: AbstractNonlinearAlgorithm +@concrete struct SpeedMappingJL <: AbstractNonlinearSolveAlgorithm σ_min stabilize::Bool check_obj::Bool @@ -248,3 +248,77 @@ function SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool return SpeedMappingJL(σ_min, stabilize, check_obj, orders, time_limit) end + +""" + FixedPointAccelerationJL(; algorithm = :Anderson, m = missing, + condition_number_threshold = missing, extrapolation_period = missing, + replace_invalids = :NoAction) + +Wrapper over [FixedPointAcceleration.jl](https://s-baumann.github.io/FixedPointAcceleration.jl/) +for solving Fixed Point Problems. We allow using this algorithm to solve root finding +problems as well. + +## Arguments: + + - `algorithm`: The algorithm to use. Can be `:Anderson`, `:MPE`, `:RRE`, `:VEA`, `:SEA`, + `:Simple`, `:Aitken` or `:Newton`. + - `m`: The number of previous iterates to use for the extrapolation. Only valid for + `:Anderson`. + - `condition_number_threshold`: The condition number threshold for Least Squares Problem. + Only valid for `:Anderson`. + - `extrapolation_period`: The number of iterates between extrapolations. Only valid for + `:MPE`, `:RRE`, `:VEA` and `:SEA`. Defaults to `7` for `:MPE` & `:RRE`, and `6` for + `:SEA` and `:VEA`. For `:SEA` and `:VEA`, this must be a multiple of `2`. + - `replace_invalids`: The method to use for replacing invalid iterates. Can be + `:ReplaceInvalids`, `:ReplaceVector` or `:NoAction`. +""" +@concrete struct FixedPointAccelerationJL <: AbstractNonlinearSolveAlgorithm + algorithm::Symbol + extrapolation_period::Int + replace_invalids::Symbol + dampening + m::Int + condition_number_threshold +end + +function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing, + condition_number_threshold = missing, extrapolation_period = missing, + replace_invalids = :NoAction, dampening = 1.0) + if Base.get_extension(@__MODULE__, :NonlinearSolveFixedPointAccelerationExt) === nothing + error("FixedPointAccelerationJL requires FixedPointAcceleration.jl to be loaded") + end + + @assert algorithm in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + @assert replace_invalids in (:ReplaceInvalids, :ReplaceVector, :NoAction) + + if algorithm !== :Anderson + if condition_number_threshold !== missing + error("`condition_number_threshold` is only valid for Anderson acceleration") + end + if m !== missing + error("`m` is only valid for Anderson acceleration") + end + end + condition_number_threshold === missing && (condition_number_threshold = 1e3) + m === missing && (m = 10) + + if algorithm !== :MPE && algorithm !== :RRE && algorithm !== :VEA && algorithm !== :SEA + if extrapolation_period !== missing + error("`extrapolation_period` is only valid for MPE, RRE, VEA and SEA") + end + end + if extrapolation_period === missing + if algorithm === :SEA || algorithm === :VEA + extrapolation_period = 6 + else + extrapolation_period = 7 + end + else + if (algorithm === :SEA || algorithm === :VEA) && extrapolation_period % 2 != 0 + error("`extrapolation_period` must be multiples of 2 for SEA and VEA") + end + end + + return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids, + dampening, m, condition_number_threshold) +end diff --git a/test/fixed_point_acceleration.jl b/test/fixed_point_acceleration.jl index e69de29bb..099ccff9e 100644 --- a/test/fixed_point_acceleration.jl +++ b/test/fixed_point_acceleration.jl @@ -0,0 +1,68 @@ +using NonlinearSolve, FixedPointAcceleration, LinearAlgebra, Test + +# Simple Scalar Problem +@testset "Simple Scalar Problem" begin + f1(x, p) = cos(x) - x + prob = NonlinearProblem(f1, 1.1) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + @test abs(solve(prob, FixedPointAccelerationJL()).resid) ≤ 1e-10 + end +end + +# Simple Vector Problem +@testset "Simple Vector Problem" begin + f2(x, p) = cos.(x) .- x + prob = NonlinearProblem(f2, [1.1, 1.1]) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + @test maximum(abs.(solve(prob, FixedPointAccelerationJL()).resid)) ≤ 1e-10 + end +end + +# Fixed Point for Power Method +# Taken from https://github.com/NicolasL-S/SpeedMapping.jl/blob/95951db8f8a4457093090e18802ad382db1c76da/test/runtests.jl +@testset "Power Method" begin + C = [1 2 3; 4 5 6; 7 8 9] + A = C + C' + B = Hermitian(ones(10) * ones(10)' .* im + Diagonal(1:10)) + + function power_method!(du, u, A) + mul!(du, A, u) + du ./= norm(du, Inf) + du .-= u # Convert to a root finding problem + return nothing + end + + prob = NonlinearProblem(power_method!, ones(3), A) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg)) + if SciMLBase.successful_retcode(sol) + @test sol.u' * A[:, 3] ≈ 32.916472867168096 + else + @warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)" + @test_broken sol.u' * A[:, 3] ≈ 32.916472867168096 + end + end + + # Non vector inputs + function power_method_nonvec!(du, u, A) + mul!(vec(du), A, vec(u)) + du ./= norm(du, Inf) + du .-= u # Convert to a root finding problem + return nothing + end + + prob = NonlinearProblem(power_method_nonvec!, ones(1, 3, 1), A) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg)) + if SciMLBase.successful_retcode(sol) + @test sol.u' * A[:, 3] ≈ 32.916472867168096 + else + @warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)" + @test_broken sol.u' * A[:, 3] ≈ 32.916472867168096 + end + end +end