Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SIAMFANLEquations wrapper #333

Merged
merged 19 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -47,6 +48,7 @@ NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLsolveExt = "NLsolve"
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSpeedMappingExt = "SpeedMapping"
NonlinearSolveSymbolicsExt = "Symbolics"
NonlinearSolveZygoteExt = "Zygote"
Expand All @@ -55,8 +57,8 @@ NonlinearSolveZygoteExt = "Zygote"
ADTypes = "0.2.5"
Aqua = "0.8"
ArrayInterface = "7.7"
BandedMatrices = "1.3"
BenchmarkTools = "1"
BandedMatrices = "1.4"
BenchmarkTools = "1.4"
ConcreteStructs = "0.2"
DiffEqBase = "6.144"
EnumX = "1"
Expand Down Expand Up @@ -86,6 +88,7 @@ Reexport = "1.2"
SafeTestsets = "0.1"
SciMLBase = "2.11"
SciMLOperators = "0.3.7"
SIAMFANLEquations = "1.0.1"
SimpleNonlinearSolve = "1.0.2"
SparseArrays = "1.9"
SparseDiffTools = "2.14"
Expand Down Expand Up @@ -118,6 +121,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -127,4 +131,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration"]
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations"]
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pages = ["index.md",
"api/leastsquaresoptim.md",
"api/fastlevenbergmarquardt.md",
"api/speedmapping.md",
"api/fixedpointacceleration.md"],
"api/fixedpointacceleration.md",
"api/siamfanlequations.md"]],
"Release Notes" => "release_notes.md",
]
2 changes: 1 addition & 1 deletion docs/src/api/fastlevenbergmarquardt.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# FastLevenbergMarquardt.jl

This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/leastsquaresoptim.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# LeastSquaresOptim.jl

This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML
This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

Expand Down
17 changes: 17 additions & 0 deletions docs/src/api/siamfanlequations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SIAMFANLEquations.jl

This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
interface. Note that these solvers do not come by default, and thus one needs to install
the package before using these solvers:

```julia
using Pkg
Pkg.add("SIAMFANLEquations")
using SIAMFANLEquations, NonlinearSolve
```

## Solver API

```@docs
SIAMFANLEquationsJL
```
141 changes: 141 additions & 0 deletions ext/NonlinearSolveSIAMFANLEquationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
module NonlinearSolveSIAMFANLEquationsExt

using NonlinearSolve, SciMLBase
using SIAMFANLEquations
import ConcreteStructs: @concrete
import UnPack: @unpack
import FiniteDiff, ForwardDiff

function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing,

Check warning on line 9 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L9

Added line #L9 was not covered by tests
reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
@assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"

Check warning on line 11 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L11

Added line #L11 was not covered by tests

@unpack method, autodiff, show_trace, delta, linsolve = alg

Check warning on line 13 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L13

Added line #L13 was not covered by tests

iip = SciMLBase.isinplace(prob)
T = eltype(u0)

Check warning on line 16 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L15-L16

Added lines #L15 - L16 were not covered by tests

atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol

Check warning on line 19 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L18-L19

Added lines #L18 - L19 were not covered by tests

if prob.u0 isa Number
f! = if iip
function (u)
du = similar(u)
prob.f(du, u, prob.p)
return du

Check warning on line 26 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L21-L26

Added lines #L21 - L26 were not covered by tests
end
else
u -> prob.f(u, prob.p)

Check warning on line 29 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L29

Added line #L29 was not covered by tests
end

if method == :newton
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace)
elseif method == :secant
sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)

Check warning on line 37 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L32-L37

Added lines #L32 - L37 were not covered by tests
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default

Check warning on line 47 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L40-L47

Added lines #L40 - L47 were not covered by tests
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)

Check warning on line 50 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L49-L50

Added lines #L49 - L50 were not covered by tests
else
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)

Check warning on line 52 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L52

Added line #L52 was not covered by tests
end

fu = NonlinearSolve.evaluate_f(prob, u)

Check warning on line 55 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L55

Added line #L55 was not covered by tests

if iip
f! = function (du, u)
prob.f(du, u, prob.p)
return du

Check warning on line 60 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L57-L60

Added lines #L57 - L60 were not covered by tests
end
else
f! = function (du, u)
du .= prob.f(u, prob.p)
return du

Check warning on line 65 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L63-L65

Added lines #L63 - L65 were not covered by tests
end
end

# Allocate ahead for function
N = length(u)
FS = zeros(T, N)

Check warning on line 71 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L70-L71

Added lines #L70 - L71 were not covered by tests

# Jacobian free Newton Krylov
if linsolve !== nothing

Check warning on line 74 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L74

Added line #L74 was not covered by tests
# Allocate ahead for Krylov basis
JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N)

Check warning on line 76 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L76

Added line #L76 was not covered by tests
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
linsolve_alg = String(linsolve)

Check warning on line 78 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L78

Added line #L78 was not covered by tests

if method == :newton
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)

Check warning on line 83 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L80-L83

Added lines #L80 - L83 were not covered by tests
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default

Check warning on line 93 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L86-L93

Added lines #L86 - L93 were not covered by tests
end
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)

Check warning on line 96 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L95-L96

Added lines #L95 - L96 were not covered by tests
end

# Allocate ahead for Jacobian
FPS = zeros(T, N, N)
if prob.f.jac === nothing

Check warning on line 101 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L100-L101

Added lines #L100 - L101 were not covered by tests
# Use the built-in Jacobian machinery
if method == :newton
sol = nsol(f!, u, FS, FPS;

Check warning on line 104 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsol(f!, u, FS, FPS;

Check warning on line 108 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
atol = atol, rtol = rtol, maxit = maxiters,
delta0 = delta, printerr = show_trace)
end
else
AJ!(J, u, x) = prob.f.jac(J, x, prob.p)
if method == :newton
sol = nsol(f!, u, FS, FPS, AJ!;

Check warning on line 115 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L113-L115

Added lines #L113 - L115 were not covered by tests
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
printerr = show_trace)
elseif method == :pseudotransient
sol = ptcsol(f!, u, FS, FPS, AJ!;

Check warning on line 119 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
atol = atol, rtol = rtol, maxit = maxiters,
delta0 = delta, printerr = show_trace)
end
end

if sol.errcode == 0
retcode = ReturnCode.Success
elseif sol.errcode == 10
retcode = ReturnCode.MaxIters
elseif sol.errcode == 1
retcode = ReturnCode.Failure
elseif sol.errcode == -1
retcode = ReturnCode.Default

Check warning on line 132 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L125-L132

Added lines #L125 - L132 were not covered by tests
end

# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
println(sol.stats)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)

Check warning on line 138 in ext/NonlinearSolveSIAMFANLEquationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveSIAMFANLEquationsExt.jl#L136-L138

Added lines #L136 - L138 were not covered by tests
end

end
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ export RadiusUpdateSchemes
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
Broyden, Klement, LimitedMemoryBroyden
export LeastSquaresOptimJL,
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export NonlinearSolvePolyAlgorithm,
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg

Expand Down
35 changes: 35 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@
end

"""

SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)

Expand Down Expand Up @@ -322,3 +323,37 @@
return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids,
dampening, m, condition_number_threshold)
end

"""

SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)

### Keyword Arguments

- `method`: the choice of method for solving the nonlinear system.
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or
central differencing via FiniteDiff.jl. The other choices are `:forward`.
- `show_trace`: whether to show the trace.
- `delta`: initial pseudo time step, default is 1e-3.
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.

### Submethod Choice

- `:newton`: Classical Newton method.
- `:pseudotransient`: Pseudo transient method.
- `:secant`: Secant method for scalar equations.
"""
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
method::Symbol
autodiff::Symbol
show_trace::Bool
delta
linsolve::Union{Symbol, Nothing}
end

function SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")

Check warning on line 356 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L354-L356

Added lines #L354 - L356 were not covered by tests
end
return SIAMFANLEquationsJL(method, autodiff, show_trace, delta, linsolve)

Check warning on line 358 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L358

Added line #L358 was not covered by tests
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
if GROUP == "All" || GROUP == "Wrappers"
@time @safetestset "MINPACK" include("minpack.jl")
@time @safetestset "NLsolve" include("nlsolve.jl")
@time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl")
@time @safetestset "SpeedMapping" include("speedmapping.jl")
@time @safetestset "FixedPointAcceleration" include("fixed_point_acceleration.jl")
end
Expand Down
Loading
Loading