-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #337 from avik-pal/ap/fixed_point
FixedPointSolvers: Recommendations and Wrappers
- Loading branch information
Showing
15 changed files
with
440 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# FixedPointAcceleration.jl | ||
|
||
This is a extension for importing solvers from FixedPointAcceleration.jl into the SciML | ||
interface. Note that these solvers do not come by default, and thus one needs to install | ||
the package before using these solvers: | ||
|
||
```julia | ||
using Pkg | ||
Pkg.add("FixedPointAcceleration") | ||
using FixedPointAcceleration, NonlinearSolve | ||
``` | ||
|
||
## Solver API | ||
|
||
```@docs | ||
FixedPointAccelerationJL | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# SpeedMapping.jl | ||
|
||
This is a extension for importing solvers from SpeedMapping.jl into the SciML | ||
interface. Note that these solvers do not come by default, and thus one needs to install | ||
the package before using these solvers: | ||
|
||
```julia | ||
using Pkg | ||
Pkg.add("SpeedMapping") | ||
using SpeedMapping, NonlinearSolve | ||
``` | ||
|
||
## Solver API | ||
|
||
```@docs | ||
SpeedMappingJL | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Fixed Point Solvers | ||
|
||
Currently we don't have an API to directly specify Fixed Point Solvers. However, a Fixed | ||
Point Problem can be trivially converted to a Root Finding Problem. Say we want to solve: | ||
|
||
```math | ||
f(u) = u | ||
``` | ||
|
||
This can be written as: | ||
|
||
```math | ||
g(u) = f(u) - u = 0 | ||
``` | ||
|
||
``g(u) = 0`` is a root finding problem. Note that we can use any root finding | ||
algorithm to solve this problem. However, this is often not the most efficient way to | ||
solve a fixed point problem. We provide a few algorithms available via extensions that | ||
are more efficient for fixed point problems. | ||
|
||
Note that even if you use one of the Fixed Point Solvers mentioned here, you must still | ||
use the `NonlinearProblem` API to specify the problem, i.e., ``g(u) = 0``. | ||
|
||
## Recommended Methods | ||
|
||
Using [native NonlinearSolve.jl methods](@ref nonlinearsystemsolvers) is the recommended | ||
approach. For systems where constructing Jacobian Matrices are expensive, we recommend | ||
using a Krylov Method with one of those solvers. | ||
|
||
## Full List of Methods | ||
|
||
We are only listing the methods that natively solve fixed point problems. | ||
|
||
### SpeedMapping.jl | ||
|
||
- `SpeedMappingJL()`: accelerates the convergence of a mapping to a fixed point by the | ||
Alternating cyclic extrapolation algorithm (ACX). | ||
|
||
### FixedPointAcceleration.jl | ||
|
||
- `FixedPointAccelerationJL()`: accelerates the convergence of a mapping to a fixed point | ||
by the Anderson acceleration algorithm and a few other methods. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
module NonlinearSolveFixedPointAccelerationExt | ||
|
||
using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase | ||
|
||
function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...; | ||
abstol = nothing, maxiters = 1000, alias_u0::Bool = false, | ||
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing, | ||
kwargs...) where {PrintReports} | ||
@assert (termination_condition === | ||
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!" | ||
|
||
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) | ||
u_size = size(u0) | ||
T = eltype(u0) | ||
iip = isinplace(prob) | ||
p = prob.p | ||
|
||
if !iip && prob.u0 isa Number | ||
# FixedPointAcceleration makes the scalar problem into a vector problem | ||
f = (u) -> [prob.f(u[1], p) .+ u[1]] | ||
elseif !iip && prob.u0 isa AbstractVector | ||
f = (u) -> (prob.f(u, p) .+ u) | ||
elseif !iip && prob.u0 isa AbstractArray | ||
f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u) | ||
elseif iip && prob.u0 isa AbstractVector | ||
du = similar(u0) | ||
f = (u) -> (prob.f(du, u, p); du .+ u) | ||
else | ||
du = similar(u0) | ||
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u) | ||
end | ||
|
||
tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol | ||
|
||
sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm, | ||
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m, | ||
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening, | ||
PrintReports, ReplaceInvalids = alg.replace_invalids, | ||
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true) | ||
|
||
res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_ | ||
if res === missing | ||
resid = NonlinearSolve.evaluate_f(prob, u0) | ||
res = u0 | ||
converged = false | ||
else | ||
resid = NonlinearSolve.evaluate_f(prob, res) | ||
converged = maximum(abs, resid) ≤ tol | ||
end | ||
return SciMLBase.build_solution(prob, alg, res, resid; | ||
retcode = converged ? ReturnCode.Success : ReturnCode.Failure, | ||
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_), | ||
original = sol) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
module NonlinearSolveSpeedMappingExt | ||
|
||
using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase | ||
|
||
function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...; | ||
abstol = nothing, maxiters = 1000, alias_u0::Bool = false, | ||
store_trace::Val{store_info} = Val(false), termination_condition = nothing, | ||
kwargs...) where {store_info} | ||
@assert (termination_condition === | ||
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!" | ||
|
||
if typeof(prob.u0) <: Number | ||
u0 = [prob.u0] | ||
else | ||
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) | ||
end | ||
|
||
T = eltype(u0) | ||
iip = isinplace(prob) | ||
p = prob.p | ||
|
||
if !iip && prob.u0 isa Number | ||
m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u)) | ||
elseif !iip | ||
m! = (du, u) -> (du .= prob.f(u, p) .+ u) | ||
else | ||
m! = (du, u) -> (prob.f(du, u, p); du .+= u) | ||
end | ||
|
||
tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol | ||
|
||
sol = speedmapping(u0; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders, | ||
alg.check_obj, store_info, alg.σ_min, alg.stabilize) | ||
res = prob.u0 isa Number ? first(sol.minimizer) : sol.minimizer | ||
resid = NonlinearSolve.evaluate_f(prob, sol.minimizer) | ||
|
||
return SciMLBase.build_solution(prob, alg, res, resid; | ||
retcode = sol.converged ? ReturnCode.Success : ReturnCode.Failure, | ||
stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.