Skip to content

Commit

Permalink
Add wrapper for FixedPointAcceleration
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 22, 2023
1 parent 831dacc commit 29917d8
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 25 deletions.
55 changes: 54 additions & 1 deletion ext/NonlinearSolveFixedPointAccelerationExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,56 @@
module NonlinearSolveFixedPointAccelerationExt

end
using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase

function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing,
kwargs...) where {PrintReports}
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"

u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
u_size = size(u0)
T = eltype(u0)
iip = isinplace(prob)
p = prob.p

if !iip && prob.u0 isa Number
# FixedPointAcceleration makes the scalar problem into a vector problem
f = (u) -> [prob.f(u[1], p) .+ u[1]]
elseif !iip && prob.u0 isa AbstractVector
f = (u) -> (prob.f(u, p) .+ u)
elseif !iip && prob.u0 isa AbstractArray
f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u)

Check warning on line 24 in ext/NonlinearSolveFixedPointAccelerationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFixedPointAccelerationExt.jl#L24

Added line #L24 was not covered by tests
elseif iip && prob.u0 isa AbstractVector
du = similar(u0)
f = (u) -> (prob.f(du, u, p); du .+ u)
else
du = similar(u0)
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u)
end

tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol

sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm,
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening,
PrintReports, ReplaceInvalids = alg.replace_invalids,
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)

res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_
if res === missing
resid = NonlinearSolve.evaluate_f(prob, u0)
res = u0
converged = false
else
resid = NonlinearSolve.evaluate_f(prob, res)
converged = maximum(abs, resid) tol
end
return SciMLBase.build_solution(prob, alg, res, resid;
retcode = converged ? ReturnCode.Success : ReturnCode.Failure,
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_),
original = sol)
end

end
13 changes: 8 additions & 5 deletions ext/NonlinearSolveMINPACKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using MINPACK

function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false,
abstol = nothing, maxiters = 100000, alias_u0::Bool = false,
termination_condition = nothing, kwargs...) where {uType, iip}
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"
Expand All @@ -16,6 +16,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

T = eltype(u0)
sizeu = size(prob.u0)
p = prob.p

Expand All @@ -25,11 +26,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},

if !iip && prob.u0 isa Number
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
elseif !iip && prob.u0 isa Vector{Float64}
elseif !iip && prob.u0 isa AbstractVector
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
elseif !iip && prob.u0 isa AbstractArray
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
elseif prob.u0 isa Vector{Float64}
elseif prob.u0 isa AbstractVector
f! = (du, u) -> prob.f(du, u, p)
else # Then it's an in-place function on an abstract array
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
Expand All @@ -43,14 +44,16 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
method = ifelse(alg.method === :auto,
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)

abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol

if SciMLBase.has_jac(prob.f)
if !iip && prob.u0 isa Number
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
elseif !iip && prob.u0 isa Vector{Float64}
elseif !iip && prob.u0 isa AbstractVector
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
elseif !iip && prob.u0 isa AbstractArray
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
elseif prob.u0 isa Vector{Float64}
elseif prob.u0 isa AbstractVector
g! = (du, u) -> prob.f.jac(du, u, p)
else # Then it's an in-place function on an abstract array
g! = function (du, u)
Expand Down
16 changes: 10 additions & 6 deletions ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module NonlinearSolveNLsolveExt
using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
import UnPack: @unpack

function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...)
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
termination_condition = nothing, kwargs...)
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"

Expand All @@ -14,6 +15,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

T = eltype(u0)
iip = isinplace(prob)

sizeu = size(prob.u0)
Expand All @@ -25,11 +27,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst

if !iip && prob.u0 isa Number
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
elseif !iip && prob.u0 isa Vector{Float64}
elseif !iip && prob.u0 isa AbstractVector
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
elseif !iip && prob.u0 isa AbstractArray
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
elseif prob.u0 isa Vector{Float64}
elseif prob.u0 isa AbstractVector
f! = (du, u) -> prob.f(du, u, p)
else # Then it's an in-place function on an abstract array
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
Expand All @@ -46,11 +48,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
if SciMLBase.has_jac(prob.f)
if !iip && prob.u0 isa Number
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
elseif !iip && prob.u0 isa Vector{Float64}
elseif !iip && prob.u0 isa AbstractVector
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
elseif !iip && prob.u0 isa AbstractArray
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
elseif prob.u0 isa Vector{Float64}
elseif prob.u0 isa AbstractVector
g! = (du, u) -> prob.f.jac(du, u, p)
else # Then it's an in-place function on an abstract array
g! = function (du, u)
Expand All @@ -68,6 +70,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
end

abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol

original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
show_trace)
Expand Down
9 changes: 1 addition & 8 deletions ext/NonlinearSolveSpeedMappingExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module NonlinearSolveSpeedMappingExt

using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase
import UnPack: @unpack

function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
Expand All @@ -20,12 +19,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
iip = isinplace(prob)
p = prob.p

if prob.u0 isa Number
resid = [NonlinearSolve.evaluate_f(prob, first(u0))]
else
resid = NonlinearSolve.evaluate_f(prob, u0)
end

if !iip && prob.u0 isa Number
m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u))

Check warning on line 23 in ext/NonlinearSolveSpeedMappingExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSpeedMappingExt.jl#L23

Added line #L23 was not covered by tests
elseif !iip
Expand All @@ -46,4 +39,4 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol)
end

end
end
84 changes: 79 additions & 5 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ for solving `NonlinearLeastSquaresProblem`.
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearAlgorithm
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff::Symbol
end

Expand Down Expand Up @@ -58,7 +58,7 @@ for solving `NonlinearLeastSquaresProblem`.
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearAlgorithm
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff
factor
factoraccept
Expand Down Expand Up @@ -128,7 +128,7 @@ then the following methods are allowed:
The default choice of `:auto` selects `:hybr` for NonlinearProblem and `:lm` for
NonlinearLeastSquaresProblem.
"""
struct CMINPACK <: AbstractNonlinearAlgorithm
struct CMINPACK <: AbstractNonlinearSolveAlgorithm
show_trace::Bool
tracing::Bool
method::Symbol
Expand Down Expand Up @@ -181,7 +181,7 @@ Choices for methods in `NLsolveJL`:
these arguments, consult the
[NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl).
"""
@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm
@concrete struct NLsolveJL <: AbstractNonlinearSolveAlgorithm
method::Symbol
autodiff::Symbol
store_trace::Bool
Expand Down Expand Up @@ -232,7 +232,7 @@ Fixed Point Problems. We allow using this algorithm to solve root finding proble
- N. Lepage-Saucier, Alternating cyclic extrapolation methods for optimization algorithms,
arXiv:2104.04974 (2021). https://arxiv.org/abs/2104.04974.
"""
@concrete struct SpeedMappingJL <: AbstractNonlinearAlgorithm
@concrete struct SpeedMappingJL <: AbstractNonlinearSolveAlgorithm
σ_min
stabilize::Bool
check_obj::Bool
Expand All @@ -248,3 +248,77 @@ function SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool

return SpeedMappingJL(σ_min, stabilize, check_obj, orders, time_limit)
end

"""
FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
condition_number_threshold = missing, extrapolation_period = missing,
replace_invalids = :NoAction)
Wrapper over [FixedPointAcceleration.jl](https://s-baumann.github.io/FixedPointAcceleration.jl/)
for solving Fixed Point Problems. We allow using this algorithm to solve root finding
problems as well.
## Arguments:
- `algorithm`: The algorithm to use. Can be `:Anderson`, `:MPE`, `:RRE`, `:VEA`, `:SEA`,
`:Simple`, `:Aitken` or `:Newton`.
- `m`: The number of previous iterates to use for the extrapolation. Only valid for
`:Anderson`.
- `condition_number_threshold`: The condition number threshold for Least Squares Problem.
Only valid for `:Anderson`.
- `extrapolation_period`: The number of iterates between extrapolations. Only valid for
`:MPE`, `:RRE`, `:VEA` and `:SEA`. Defaults to `7` for `:MPE` & `:RRE`, and `6` for
`:SEA` and `:VEA`. For `:SEA` and `:VEA`, this must be a multiple of `2`.
- `replace_invalids`: The method to use for replacing invalid iterates. Can be
`:ReplaceInvalids`, `:ReplaceVector` or `:NoAction`.
"""
@concrete struct FixedPointAccelerationJL <: AbstractNonlinearSolveAlgorithm
algorithm::Symbol
extrapolation_period::Int
replace_invalids::Symbol
dampening
m::Int
condition_number_threshold
end

function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
condition_number_threshold = missing, extrapolation_period = missing,
replace_invalids = :NoAction, dampening = 1.0)
if Base.get_extension(@__MODULE__, :NonlinearSolveFixedPointAccelerationExt) === nothing
error("FixedPointAccelerationJL requires FixedPointAcceleration.jl to be loaded")

Check warning on line 288 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L288

Added line #L288 was not covered by tests
end

@assert algorithm in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
@assert replace_invalids in (:ReplaceInvalids, :ReplaceVector, :NoAction)

if algorithm !== :Anderson
if condition_number_threshold !== missing
error("`condition_number_threshold` is only valid for Anderson acceleration")

Check warning on line 296 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L296

Added line #L296 was not covered by tests
end
if m !== missing
error("`m` is only valid for Anderson acceleration")

Check warning on line 299 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L299

Added line #L299 was not covered by tests
end
end
condition_number_threshold === missing && (condition_number_threshold = 1e3)
m === missing && (m = 10)

if algorithm !== :MPE && algorithm !== :RRE && algorithm !== :VEA && algorithm !== :SEA
if extrapolation_period !== missing
error("`extrapolation_period` is only valid for MPE, RRE, VEA and SEA")

Check warning on line 307 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L307

Added line #L307 was not covered by tests
end
end
if extrapolation_period === missing
if algorithm === :SEA || algorithm === :VEA
extrapolation_period = 6
else
extrapolation_period = 7
end
else
if (algorithm === :SEA || algorithm === :VEA) && extrapolation_period % 2 != 0
error("`extrapolation_period` must be multiples of 2 for SEA and VEA")

Check warning on line 318 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L317-L318

Added lines #L317 - L318 were not covered by tests
end
end

return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids,
dampening, m, condition_number_threshold)
end
68 changes: 68 additions & 0 deletions test/fixed_point_acceleration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using NonlinearSolve, FixedPointAcceleration, LinearAlgebra, Test

# Simple Scalar Problem
@testset "Simple Scalar Problem" begin
f1(x, p) = cos(x) - x
prob = NonlinearProblem(f1, 1.1)

for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
@test abs(solve(prob, FixedPointAccelerationJL()).resid) 1e-10
end
end

# Simple Vector Problem
@testset "Simple Vector Problem" begin
f2(x, p) = cos.(x) .- x
prob = NonlinearProblem(f2, [1.1, 1.1])

for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
@test maximum(abs.(solve(prob, FixedPointAccelerationJL()).resid)) 1e-10
end
end

# Fixed Point for Power Method
# Taken from https://github.com/NicolasL-S/SpeedMapping.jl/blob/95951db8f8a4457093090e18802ad382db1c76da/test/runtests.jl
@testset "Power Method" begin
C = [1 2 3; 4 5 6; 7 8 9]
A = C + C'
B = Hermitian(ones(10) * ones(10)' .* im + Diagonal(1:10))

function power_method!(du, u, A)
mul!(du, A, u)
du ./= norm(du, Inf)
du .-= u # Convert to a root finding problem
return nothing
end

prob = NonlinearProblem(power_method!, ones(3), A)

for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg))
if SciMLBase.successful_retcode(sol)
@test sol.u' * A[:, 3] 32.916472867168096
else
@warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)"
@test_broken sol.u' * A[:, 3] 32.916472867168096
end
end

# Non vector inputs
function power_method_nonvec!(du, u, A)
mul!(vec(du), A, vec(u))
du ./= norm(du, Inf)
du .-= u # Convert to a root finding problem
return nothing
end

prob = NonlinearProblem(power_method_nonvec!, ones(1, 3, 1), A)

for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg))
if SciMLBase.successful_retcode(sol)
@test sol.u' * A[:, 3] 32.916472867168096
else
@warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)"
@test_broken sol.u' * A[:, 3] 32.916472867168096
end
end
end

0 comments on commit 29917d8

Please sign in to comment.