From 29917d8388c0e4b86c87afb129b25697dc82bc70 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Dec 2023 16:32:53 -0500 Subject: [PATCH] Add wrapper for FixedPointAcceleration --- ...NonlinearSolveFixedPointAccelerationExt.jl | 55 +++++++++++- ext/NonlinearSolveMINPACKExt.jl | 13 +-- ext/NonlinearSolveNLsolveExt.jl | 16 ++-- ext/NonlinearSolveSpeedMappingExt.jl | 9 +- src/extension_algs.jl | 84 +++++++++++++++++-- test/fixed_point_acceleration.jl | 68 +++++++++++++++ 6 files changed, 220 insertions(+), 25 deletions(-) diff --git a/ext/NonlinearSolveFixedPointAccelerationExt.jl b/ext/NonlinearSolveFixedPointAccelerationExt.jl index 4c0021712..e39946652 100644 --- a/ext/NonlinearSolveFixedPointAccelerationExt.jl +++ b/ext/NonlinearSolveFixedPointAccelerationExt.jl @@ -1,3 +1,56 @@ module NonlinearSolveFixedPointAccelerationExt -end \ No newline at end of file +using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase + +function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...; + abstol = nothing, maxiters = 1000, alias_u0::Bool = false, + show_trace::Val{PrintReports} = Val(false), termination_condition = nothing, + kwargs...) where {PrintReports} + @assert (termination_condition === + nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!" + + u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + u_size = size(u0) + T = eltype(u0) + iip = isinplace(prob) + p = prob.p + + if !iip && prob.u0 isa Number + # FixedPointAcceleration makes the scalar problem into a vector problem + f = (u) -> [prob.f(u[1], p) .+ u[1]] + elseif !iip && prob.u0 isa AbstractVector + f = (u) -> (prob.f(u, p) .+ u) + elseif !iip && prob.u0 isa AbstractArray + f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u) + elseif iip && prob.u0 isa AbstractVector + du = similar(u0) + f = (u) -> (prob.f(du, u, p); du .+ u) + else + du = similar(u0) + f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u) + end + + tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + + sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm, + ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m, + ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening, + PrintReports, ReplaceInvalids = alg.replace_invalids, + ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true) + + res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_ + if res === missing + resid = NonlinearSolve.evaluate_f(prob, u0) + res = u0 + converged = false + else + resid = NonlinearSolve.evaluate_f(prob, res) + converged = maximum(abs, resid) ≤ tol + end + return SciMLBase.build_solution(prob, alg, res, resid; + retcode = converged ? ReturnCode.Success : ReturnCode.Failure, + stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_), + original = sol) +end + +end diff --git a/ext/NonlinearSolveMINPACKExt.jl b/ext/NonlinearSolveMINPACKExt.jl index b86d78199..b6f051ade 100644 --- a/ext/NonlinearSolveMINPACKExt.jl +++ b/ext/NonlinearSolveMINPACKExt.jl @@ -5,7 +5,7 @@ using MINPACK function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...; - abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false, + abstol = nothing, maxiters = 100000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) where {uType, iip} @assert (termination_condition === nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!" @@ -16,6 +16,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end + T = eltype(u0) sizeu = size(prob.u0) p = prob.p @@ -25,11 +26,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, if !iip && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector f! = (du, u) -> prob.f(du, u, p) else # Then it's an in-place function on an abstract array f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0) @@ -43,14 +44,16 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, method = ifelse(alg.method === :auto, ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method) + abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + if SciMLBase.has_jac(prob.f) if !iip && prob.u0 isa Number g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector g! = (du, u) -> prob.f.jac(du, u, p) else # Then it's an in-place function on an abstract array g! = function (du, u) diff --git a/ext/NonlinearSolveNLsolveExt.jl b/ext/NonlinearSolveNLsolveExt.jl index 1b8d7e3f1..4c5223540 100644 --- a/ext/NonlinearSolveNLsolveExt.jl +++ b/ext/NonlinearSolveNLsolveExt.jl @@ -3,8 +3,9 @@ module NonlinearSolveNLsolveExt using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase import UnPack: @unpack -function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6, - maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...) +function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; + abstol = nothing, maxiters = 1000, alias_u0::Bool = false, + termination_condition = nothing, kwargs...) @assert (termination_condition === nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!" @@ -14,6 +15,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) end + T = eltype(u0) iip = isinplace(prob) sizeu = size(prob.u0) @@ -25,11 +27,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst if !iip && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector f! = (du, u) -> prob.f(du, u, p) else # Then it's an in-place function on an abstract array f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0) @@ -46,11 +48,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst if SciMLBase.has_jac(prob.f) if !iip && prob.u0 isa Number g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0)) - elseif !iip && prob.u0 isa Vector{Float64} + elseif !iip && prob.u0 isa AbstractVector g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0)) elseif !iip && prob.u0 isa AbstractArray g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0)) - elseif prob.u0 isa Vector{Float64} + elseif prob.u0 isa AbstractVector g! = (du, u) -> prob.f.jac(du, u, p) else # Then it's an in-place function on an abstract array g! = function (du, u) @@ -68,6 +70,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff) end + abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol + original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method, store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) diff --git a/ext/NonlinearSolveSpeedMappingExt.jl b/ext/NonlinearSolveSpeedMappingExt.jl index 444a7e381..be637b124 100644 --- a/ext/NonlinearSolveSpeedMappingExt.jl +++ b/ext/NonlinearSolveSpeedMappingExt.jl @@ -1,7 +1,6 @@ module NonlinearSolveSpeedMappingExt using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase -import UnPack: @unpack function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; abstol = nothing, maxiters = 1000, alias_u0::Bool = false, @@ -20,12 +19,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; iip = isinplace(prob) p = prob.p - if prob.u0 isa Number - resid = [NonlinearSolve.evaluate_f(prob, first(u0))] - else - resid = NonlinearSolve.evaluate_f(prob, u0) - end - if !iip && prob.u0 isa Number m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u)) elseif !iip @@ -46,4 +39,4 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol) end -end \ No newline at end of file +end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index a076f4173..4fcad43ad 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -19,7 +19,7 @@ for solving `NonlinearLeastSquaresProblem`. This algorithm is only available if `LeastSquaresOptim.jl` is installed. """ -struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearAlgorithm +struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm autodiff::Symbol end @@ -58,7 +58,7 @@ for solving `NonlinearLeastSquaresProblem`. This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ -@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearAlgorithm +@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm autodiff factor factoraccept @@ -128,7 +128,7 @@ then the following methods are allowed: The default choice of `:auto` selects `:hybr` for NonlinearProblem and `:lm` for NonlinearLeastSquaresProblem. """ -struct CMINPACK <: AbstractNonlinearAlgorithm +struct CMINPACK <: AbstractNonlinearSolveAlgorithm show_trace::Bool tracing::Bool method::Symbol @@ -181,7 +181,7 @@ Choices for methods in `NLsolveJL`: these arguments, consult the [NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl). """ -@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm +@concrete struct NLsolveJL <: AbstractNonlinearSolveAlgorithm method::Symbol autodiff::Symbol store_trace::Bool @@ -232,7 +232,7 @@ Fixed Point Problems. We allow using this algorithm to solve root finding proble - N. Lepage-Saucier, Alternating cyclic extrapolation methods for optimization algorithms, arXiv:2104.04974 (2021). https://arxiv.org/abs/2104.04974. """ -@concrete struct SpeedMappingJL <: AbstractNonlinearAlgorithm +@concrete struct SpeedMappingJL <: AbstractNonlinearSolveAlgorithm σ_min stabilize::Bool check_obj::Bool @@ -248,3 +248,77 @@ function SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool return SpeedMappingJL(σ_min, stabilize, check_obj, orders, time_limit) end + +""" + FixedPointAccelerationJL(; algorithm = :Anderson, m = missing, + condition_number_threshold = missing, extrapolation_period = missing, + replace_invalids = :NoAction) + +Wrapper over [FixedPointAcceleration.jl](https://s-baumann.github.io/FixedPointAcceleration.jl/) +for solving Fixed Point Problems. We allow using this algorithm to solve root finding +problems as well. + +## Arguments: + + - `algorithm`: The algorithm to use. Can be `:Anderson`, `:MPE`, `:RRE`, `:VEA`, `:SEA`, + `:Simple`, `:Aitken` or `:Newton`. + - `m`: The number of previous iterates to use for the extrapolation. Only valid for + `:Anderson`. + - `condition_number_threshold`: The condition number threshold for Least Squares Problem. + Only valid for `:Anderson`. + - `extrapolation_period`: The number of iterates between extrapolations. Only valid for + `:MPE`, `:RRE`, `:VEA` and `:SEA`. Defaults to `7` for `:MPE` & `:RRE`, and `6` for + `:SEA` and `:VEA`. For `:SEA` and `:VEA`, this must be a multiple of `2`. + - `replace_invalids`: The method to use for replacing invalid iterates. Can be + `:ReplaceInvalids`, `:ReplaceVector` or `:NoAction`. +""" +@concrete struct FixedPointAccelerationJL <: AbstractNonlinearSolveAlgorithm + algorithm::Symbol + extrapolation_period::Int + replace_invalids::Symbol + dampening + m::Int + condition_number_threshold +end + +function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing, + condition_number_threshold = missing, extrapolation_period = missing, + replace_invalids = :NoAction, dampening = 1.0) + if Base.get_extension(@__MODULE__, :NonlinearSolveFixedPointAccelerationExt) === nothing + error("FixedPointAccelerationJL requires FixedPointAcceleration.jl to be loaded") + end + + @assert algorithm in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + @assert replace_invalids in (:ReplaceInvalids, :ReplaceVector, :NoAction) + + if algorithm !== :Anderson + if condition_number_threshold !== missing + error("`condition_number_threshold` is only valid for Anderson acceleration") + end + if m !== missing + error("`m` is only valid for Anderson acceleration") + end + end + condition_number_threshold === missing && (condition_number_threshold = 1e3) + m === missing && (m = 10) + + if algorithm !== :MPE && algorithm !== :RRE && algorithm !== :VEA && algorithm !== :SEA + if extrapolation_period !== missing + error("`extrapolation_period` is only valid for MPE, RRE, VEA and SEA") + end + end + if extrapolation_period === missing + if algorithm === :SEA || algorithm === :VEA + extrapolation_period = 6 + else + extrapolation_period = 7 + end + else + if (algorithm === :SEA || algorithm === :VEA) && extrapolation_period % 2 != 0 + error("`extrapolation_period` must be multiples of 2 for SEA and VEA") + end + end + + return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids, + dampening, m, condition_number_threshold) +end diff --git a/test/fixed_point_acceleration.jl b/test/fixed_point_acceleration.jl index e69de29bb..099ccff9e 100644 --- a/test/fixed_point_acceleration.jl +++ b/test/fixed_point_acceleration.jl @@ -0,0 +1,68 @@ +using NonlinearSolve, FixedPointAcceleration, LinearAlgebra, Test + +# Simple Scalar Problem +@testset "Simple Scalar Problem" begin + f1(x, p) = cos(x) - x + prob = NonlinearProblem(f1, 1.1) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + @test abs(solve(prob, FixedPointAccelerationJL()).resid) ≤ 1e-10 + end +end + +# Simple Vector Problem +@testset "Simple Vector Problem" begin + f2(x, p) = cos.(x) .- x + prob = NonlinearProblem(f2, [1.1, 1.1]) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + @test maximum(abs.(solve(prob, FixedPointAccelerationJL()).resid)) ≤ 1e-10 + end +end + +# Fixed Point for Power Method +# Taken from https://github.com/NicolasL-S/SpeedMapping.jl/blob/95951db8f8a4457093090e18802ad382db1c76da/test/runtests.jl +@testset "Power Method" begin + C = [1 2 3; 4 5 6; 7 8 9] + A = C + C' + B = Hermitian(ones(10) * ones(10)' .* im + Diagonal(1:10)) + + function power_method!(du, u, A) + mul!(du, A, u) + du ./= norm(du, Inf) + du .-= u # Convert to a root finding problem + return nothing + end + + prob = NonlinearProblem(power_method!, ones(3), A) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg)) + if SciMLBase.successful_retcode(sol) + @test sol.u' * A[:, 3] ≈ 32.916472867168096 + else + @warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)" + @test_broken sol.u' * A[:, 3] ≈ 32.916472867168096 + end + end + + # Non vector inputs + function power_method_nonvec!(du, u, A) + mul!(vec(du), A, vec(u)) + du ./= norm(du, Inf) + du .-= u # Convert to a root finding problem + return nothing + end + + prob = NonlinearProblem(power_method_nonvec!, ones(1, 3, 1), A) + + for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton) + sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg)) + if SciMLBase.successful_retcode(sol) + @test sol.u' * A[:, 3] ≈ 32.916472867168096 + else + @warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)" + @test_broken sol.u' * A[:, 3] ≈ 32.916472867168096 + end + end +end