Skip to content

Commit

Permalink
Make default polyalgs respect autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Feb 26, 2024
1 parent 7897f20 commit 716209f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.7.2"
version = "3.7.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
43 changes: 24 additions & 19 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,13 @@ function FastShortcutNonlinearPolyalg(
# and thus are not included in the polyalgorithm
if SA
if __is_complex(T)
algs = (SimpleBroyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
algs = (SimpleBroyden(),
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
SimpleKlement(),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
else
algs = (SimpleBroyden(),
Broyden(; init_jacobian = Val(:true_jacobian)),
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
SimpleKlement(),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
Expand All @@ -322,13 +323,13 @@ function FastShortcutNonlinearPolyalg(
end
else
if __is_complex(T)
algs = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
Klement(; linsolve, precs),
algs = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
Klement(; linsolve, prec, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
else
algs = (Broyden(),
Broyden(; init_jacobian = Val(:true_jacobian)),
Klement(; linsolve, precs),
algs = (Broyden(; autodiff),
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
Klement(; linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
Expand All @@ -343,7 +344,7 @@ end

"""
FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, kwargs...)
precs = DEFAULT_PRECS, autodiff = nothing, kwargs...)
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
for more performance and then tries more robust techniques if the faster ones fail.
Expand All @@ -353,21 +354,25 @@ for more performance and then tries more robust techniques if the faster ones fa
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
are compatible with the problem type. Defaults to `Float64`.
"""
function FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS, kwargs...) where {T}
function FastShortcutNLLSPolyalg(
::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, autodiff = nothing, kwargs...) where {T}
if __is_complex(T)
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
LevenbergMarquardt(; linsolve, precs, kwargs...))
algs = (GaussNewton(; concrete_jac, linsolve, precs, autodiff, kwargs...),
LevenbergMarquardt(;
linsolve, precs, autodiff, disable_geodesic = Val(true), kwargs...),
LevenbergMarquardt(; linsolve, precs, autodiff, kwargs...))
else
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
TrustRegion(; concrete_jac, linsolve, precs, kwargs...),
algs = (GaussNewton(; concrete_jac, linsolve, precs, autodiff, kwargs...),
LevenbergMarquardt(;
linsolve, precs, disable_geodesic = Val(true), autodiff, kwargs...),
TrustRegion(; concrete_jac, linsolve, precs, autodiff, kwargs...),
GaussNewton(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), kwargs...),
linesearch = LineSearchesJL(; method = BackTracking()),
autodiff, kwargs...),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, kwargs...),
LevenbergMarquardt(; linsolve, precs, kwargs...))
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff, kwargs...),
LevenbergMarquardt(; linsolve, precs, autodiff, kwargs...))
end
return NonlinearSolvePolyAlgorithm(algs, Val(:NLLS))
end
Expand Down
26 changes: 26 additions & 0 deletions test/misc/polyalg_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,32 @@ end
@test SciMLBase.successful_retcode(sol)
end

@testitem "PolyAlgorithms Autodiff" begin
cache = zeros(2)
function f(du, u, p)
cache .= u .* u
du .= cache .- 2
end
u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(f, u0)

custom_polyalg = NonlinearSolvePolyAlgorithm((
Broyden(; autodiff = AutoFiniteDiff()), LimitedMemoryBroyden()))

# Uses the `__solve` function
solver = solve(probN; abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
@test_throws MethodError solve(probN, RobustMultiNewton(); abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
solver = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
solver = solve(
probN, FastShortcutNonlinearPolyalg(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
solver = solve(probN, custom_polyalg; abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
end

@testitem "Simple Scalar Problem #187" begin
using NaNMath

Expand Down

0 comments on commit 716209f

Please sign in to comment.