Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SIAMFANLEquations wrapper #333

Merged
merged 19 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -47,6 +48,7 @@ NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLsolveExt = "NLsolve"
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSpeedMappingExt = "SpeedMapping"
NonlinearSolveSymbolicsExt = "Symbolics"
NonlinearSolveZygoteExt = "Zygote"
Expand All @@ -55,8 +57,8 @@ NonlinearSolveZygoteExt = "Zygote"
ADTypes = "0.2.5"
Aqua = "0.8"
ArrayInterface = "7.7"
BandedMatrices = "1.3"
BenchmarkTools = "1"
BandedMatrices = "1.4"
BenchmarkTools = "1.4"
ConcreteStructs = "0.2"
DiffEqBase = "6.144"
EnumX = "1"
Expand Down Expand Up @@ -86,6 +88,7 @@ Reexport = "1.2"
SafeTestsets = "0.1"
SciMLBase = "2.11"
SciMLOperators = "0.3.7"
SIAMFANLEquations = "1.0.1"
SimpleNonlinearSolve = "1.0.2"
SparseArrays = "1.9"
SparseDiffTools = "2.14"
Expand Down Expand Up @@ -118,6 +121,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -127,4 +131,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration"]
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations"]
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pages = ["index.md",
"api/sundials.md",
"api/steadystatediffeq.md",
"api/leastsquaresoptim.md",
"api/fastlevenbergmarquardt.md"],
"api/fastlevenbergmarquardt.md",
"api/siamfanlequations.md"],
"Release Notes" => "release_notes.md",
]
2 changes: 1 addition & 1 deletion docs/src/api/fastlevenbergmarquardt.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# FastLevenbergMarquardt.jl

This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/leastsquaresoptim.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# LeastSquaresOptim.jl

This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML
This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
17 changes: 17 additions & 0 deletions docs/src/api/siamfanlequations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SIAMFANLEquations.jl

This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

```julia
using Pkg
Pkg.add("SIAMFANLEquations")
using SIAMFANLEquations, NonlinearSolve
```

## Solver API

```@docs
SIAMFANLEquationsJL
```
171 changes: 171 additions & 0 deletions ext/NonlinearSolveSIAMFANLEquationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
module NonlinearSolveSIAMFANLEquationsExt

using NonlinearSolve, SciMLBase
using SIAMFANLEquations
import ConcreteStructs: @concrete
import UnPack: @unpack
import FiniteDiff, ForwardDiff

function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8,
reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
@assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"

@unpack method, autodiff, show_trace, delta, linsolve = alg

iip = SciMLBase.isinplace(prob)
if typeof(prob.u0) <: Number
f! = if iip
function (u)
du = similar(u)
prob.f(du, u, prob.p)
return du

Check warning on line 21 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L18-L21

Added lines #L18 - L21 were not covered by tests
end
else
u -> prob.f(u, prob.p)
end

if method == :newton
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace)
elseif method == :secant
sol = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)

Check warning on line 32 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L29-L32

Added lines #L29 - L32 were not covered by tests
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default

Check warning on line 42 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L37-L42

Added lines #L37 - L42 were not covered by tests
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
else
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

fu = NonlinearSolve.evaluate_f(prob, u)

if iip
f! = function (du, u)
prob.f(du, u, prob.p)
return du
end
else
f! = function (du, u)
du .= prob.f(u, prob.p)
return du
end
end

# Allocate ahead for function and Jacobian
N = length(u)
FS = zeros(eltype(u), N)
FPS = zeros(eltype(u), N, N)
# Allocate ahead for Krylov basis

# Jacobian free Newton Krylov
if linsolve !== nothing
JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N)

Check warning on line 72 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L72

Added line #L72 was not covered by tests
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
linsolve_alg = String(linsolve)

Check warning on line 74 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L74

Added line #L74 was not covered by tests

if method == :newton
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)

Check warning on line 79 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L76-L79

Added lines #L76 - L79 were not covered by tests
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default

Check warning on line 89 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L82-L89

Added lines #L82 - L89 were not covered by tests
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)

Check warning on line 92 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

if prob.f.jac === nothing
use_forward_diff = if alg.autodiff === nothing
ForwardDiff.can_dual(eltype(u))

Check warning on line 97 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L97

Added line #L97 was not covered by tests
else
alg.autodiff isa AutoForwardDiff
end
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
if use_forward_diff
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :

Check warning on line 103 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L103

Added line #L103 was not covered by tests
ForwardDiff.JacobianConfig(uf, u)
else
cache = FiniteDiff.JacobianCache(u, fu)
end
J! = if iip
if use_forward_diff
fu_cache = similar(fu)
function (J, x, p)
uf.p = p
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
return J

Check warning on line 114 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L110-L114

Added lines #L110 - L114 were not covered by tests
end
else
function (J, x, p)
uf.p = p
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return J
end
end
else
if use_forward_diff
function (J, x, p)
uf.p = p
ForwardDiff.jacobian!(J, uf, x, cache)
return J

Check warning on line 128 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L125-L128

Added lines #L125 - L128 were not covered by tests
end
else
function (J, x, p)
uf.p = p
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
copyto!(J, J_)
return J
end
end
end
else
J! = prob.f.jac
end

AJ!(J, u, x) = J!(J, x, prob.p)

if method == :newton
sol = nsol(f!, u, FS, FPS, AJ!;
sham=1, rtol = reltol, atol = abstol, maxit = maxiters,
printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsol(f!, u, FS, FPS, AJ!;

Check warning on line 150 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L149-L150

Added lines #L149 - L150 were not covered by tests
rtol = reltol, atol = abstol, maxit = maxiters,
delta0 = delta, printerr = show_trace)
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default

Check warning on line 162 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L157-L162

Added lines #L157 - L162 were not covered by tests
end

# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
println(sol.stats)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
end

end
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ export RadiusUpdateSchemes
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
Broyden, Klement, LimitedMemoryBroyden
export LeastSquaresOptimJL,
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export NonlinearSolvePolyAlgorithm,
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg

Expand Down
35 changes: 35 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@
end

"""

SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)

Expand Down Expand Up @@ -322,3 +323,37 @@
return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids,
dampening, m, condition_number_threshold)
end

"""

SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)

### Keyword Arguments

- `method`: the choice of method for solving the nonlinear system.
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or
central differencing via FiniteDiff.jl. The other choices are `:forward`.
- `show_trace`: whether to show the trace.
- `delta`: initial pseudo time step, default is 1e-3.
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.

### Submethod Choice

- `:newton`: Classical Newton method.
- `:pseudotransient`: Pseudo transient method.
- `:secant`: Secant method for scalar equations.
"""
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
method::Symbol
autodiff::Symbol
show_trace::Bool
delta
linsolve::Union{Symbol, Nothing}
end

function SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")

Check warning on line 356 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L356

Added line #L356 was not covered by tests
end
return SIAMFANLEquationsJL(method, autodiff, show_trace, delta, linsolve)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
if GROUP == "All" || GROUP == "Wrappers"
@time @safetestset "MINPACK" include("minpack.jl")
@time @safetestset "NLsolve" include("nlsolve.jl")
@time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl")
@time @safetestset "SpeedMapping" include("speedmapping.jl")
@time @safetestset "FixedPointAcceleration" include("fixed_point_acceleration.jl")
end
Expand Down
Loading
Loading