Skip to content

Commit

Permalink
refactor: migrate to LineSearch.jl (#461)
Browse files Browse the repository at this point in the history
* refactor: migrate to LineSearch.jl

* fix: forward reinit_cache to SciMLBase.reinit

* fix: remaining tests from the migration

* chore: bump min versions

* fix: final set of bumps

* fix: stop using deprecated API in default solvers

* docs: fix references
  • Loading branch information
avik-pal authored Oct 4, 2024
1 parent f1969a2 commit c35f0f4
Show file tree
Hide file tree
Showing 18 changed files with 105 additions and 514 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.15.0-DEV"
version = "3.15.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -14,6 +14,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down Expand Up @@ -79,11 +80,12 @@ Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LazyArrays = "1.8.2, 2"
LeastSquaresOptim = "0.8.5"
LineSearches = "7.2"
LineSearch = "0.1.2"
LineSearches = "7.3"
LinearAlgebra = "1.10"
LinearSolve = "2.35"
MINPACK = "1.2"
MaybeInplace = "0.1.3"
MaybeInplace = "0.1.4"
ModelingToolkit = "9.41.0"
NLSolvers = "0.5"
NLsolve = "4.5"
Expand Down
7 changes: 5 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))

interlinks = InterLinks(
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
"LineSearch" => "https://sciml.github.io/LineSearch.jl/dev/"
)

makedocs(; sitename = "NonlinearSolve.jl",
makedocs(;
sitename = "NonlinearSolve.jl",
authors = "Chris Rackauckas",
modules = [NonlinearSolve, SimpleNonlinearSolve, SteadyStateDiffEq,
Sundials, DiffEqBase, SciMLBase, SciMLJacobianOperators],
Expand All @@ -30,6 +32,7 @@ makedocs(; sitename = "NonlinearSolve.jl",
plugins = [bib, interlinks],
format = Documenter.HTML(assets = ["assets/favicon.ico", "assets/citations.css"],
canonical = "https://docs.sciml.ai/NonlinearSolve/stable/"),
pages)
pages
)

deploydocs(repo = "github.com/SciML/NonlinearSolve.jl.git"; push_preview = true)
7 changes: 0 additions & 7 deletions docs/src/devdocs/internal_interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ NonlinearSolve.AbstractDampingFunction
NonlinearSolve.AbstractDampingFunctionCache
```

## Line Search

```@docs
NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm
NonlinearSolve.AbstractNonlinearSolveLineSearchCache
```

## Trust Region

```@docs
Expand Down
9 changes: 3 additions & 6 deletions docs/src/native/globalization.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ Pages = ["globalization.md"]

## [Line Search Algorithms](@id line-search)

```@docs
LiFukushimaLineSearch
LineSearchesJL
RobustNonMonotoneLineSearch
NoLineSearch
```
Line Searches have been moved to an external package. Take a look at the
[LineSearch.jl](https://github.com/SciML/LineSearch.jl) package and its
[documentation](https://sciml.github.io/LineSearch.jl/dev/).

## Radius Update Schemes for Trust Region

Expand Down
2 changes: 1 addition & 1 deletion docs/src/native/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ documentation.
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`NoLineSearch()`](@ref),
- `linesearch`: the line search algorithm to use. Defaults to [`NoLineSearch()`](@extref LineSearch.NoLineSearch),
which means that no line search is performed. Algorithms from
[`LineSearches.jl`](https://github.com/JuliaNLSolvers/LineSearches.jl/) must be
wrapped in [`LineSearchesJL`](@ref) before being supplied. For a detailed documentation
Expand Down
7 changes: 5 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ using LazyArrays: LazyArrays, ApplyArray, cache
using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Symmetric,
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearch: LineSearch, AbstractLineSearchAlgorithm, AbstractLineSearchCache,
NoLineSearch, RobustNonMonotoneLineSearch
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization,
needs_concrete_A, AbstractFactorization,
Expand Down Expand Up @@ -172,8 +174,9 @@ export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, GeodesicAcce

# Globalization
## Line Search Algorithms
export LineSearchesJL, NoLineSearch, RobustNonMonotoneLineSearch, LiFukushimaLineSearch
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking
export LineSearchesJL, LiFukushimaLineSearch # FIXME: deprecated. use LineSearch.jl directly
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking # FIXME: deprecated
export NoLineSearch, RobustNonMonotoneLineSearch
## Trust Region Algorithms
export RadiusUpdateSchemes

Expand Down
22 changes: 3 additions & 19 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,6 @@ function last_step_accepted(cache::AbstractDescentCache)
return true
end

"""
AbstractNonlinearSolveLineSearchAlgorithm
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
### `__internal_init` specification
```julia
__internal_init(
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F,
fu, u, p, args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN} -->
AbstractNonlinearSolveLineSearchCache
```
"""
abstract type AbstractNonlinearSolveLineSearchAlgorithm end

"""
AbstractNonlinearSolveLineSearchCache
Expand Down Expand Up @@ -512,9 +496,9 @@ SciMLBase.isinplace(::AbstractNonlinearSolveJacobianCache{iip}) where {iip} = ii
abstract type AbstractNonlinearSolveTraceLevel end

# Default Printing
for aType in (AbstractTrustRegionMethod, AbstractNonlinearSolveLineSearchAlgorithm,
AbstractResetCondition, AbstractApproximateJacobianUpdateRule,
AbstractDampingFunction, AbstractNonlinearSolveExtensionAlgorithm)
for aType in (AbstractTrustRegionMethod, AbstractResetCondition,
AbstractApproximateJacobianUpdateRule, AbstractDampingFunction,
AbstractNonlinearSolveExtensionAlgorithm)
@eval function Base.show(io::IO, alg::$(aType))
print(io, "$(nameof(typeof(alg)))()")
end
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ over this.
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
if !(linesearch isa AbstractLineSearchAlgorithm)
Base.depwarn(
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.", :Klement)
Expand Down
7 changes: 3 additions & 4 deletions src/algorithms/pseudo_transient.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
precs = DEFAULT_PRECS, autodiff = nothing)
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)
An implementation of PseudoTransient Method [coffey2003pseudotransient](@cite) that is used
to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping
Expand All @@ -16,8 +15,8 @@ This implementation specifically uses "switched evolution relaxation"
you are going to need more iterations to converge but it can be more stable.
"""
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
precs = DEFAULT_PRECS, autodiff = nothing, alpha_initial = 1e-3)
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
alpha_initial = 1e-3)
descent = DampedNewtonDescent(; linsolve, precs, initial_damping = alpha_initial,
damping_fn = SwitchedEvolutionRelaxation())
return GeneralizedFirstOrderAlgorithm(;
Expand Down
10 changes: 6 additions & 4 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
linesearch = missing, trustregion = missing, descent, update_rule,
reinit_rule, initialization, max_resets::Int = typemax(Int),
max_shrink_times::Int = typemax(Int)) where {concrete_jac, name}
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
if linesearch !== missing && !(linesearch isa AbstractLineSearchAlgorithm)
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.",
:GeneralizedFirstOrderAlgorithm)
Expand Down Expand Up @@ -199,8 +199,8 @@ function SciMLBase.__init(
if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
linesearch_cache = init(
prob, alg.linesearch, fu, u; stats, internalnorm, kwargs...)
GB = :LineSearch
end

Expand Down Expand Up @@ -317,7 +317,9 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
if descent_result.success
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
needs_reset = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
cache.force_reinit = true
Expand Down
20 changes: 15 additions & 5 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function GeneralizedFirstOrderAlgorithm{concrete_jac, name}(;
jacobian_ad !== nothing && ADTypes.mode(jacobian_ad) isa ADTypes.ReverseMode,
jacobian_ad, nothing))

if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
if linesearch !== missing && !(linesearch isa AbstractLineSearchAlgorithm)
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.",
:GeneralizedFirstOrderAlgorithm)
Expand Down Expand Up @@ -199,8 +199,17 @@ function SciMLBase.__init(
if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
linesearch_ad = alg.forward_ad === nothing ?
(alg.reverse_ad === nothing ? alg.jacobian_ad :
alg.reverse_ad) : alg.forward_ad
if linesearch_ad !== nothing && iip && !DI.check_inplace(linesearch_ad)
@warn "$(linesearch_ad) doesn't support in-place problems."
linesearch_ad = nothing
end
linesearch_ad = get_concrete_forward_ad(
linesearch_ad, prob, False; check_forward_mode = false)
linesearch_cache = init(
prob, alg.linesearch, fu, u; stats, autodiff = linesearch_ad, kwargs...)
GB = :LineSearch
end

Expand Down Expand Up @@ -264,8 +273,9 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
cache.make_new_jacobian = true
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
linesearch_failed, α = __internal_solve!(
cache.linesearch_cache, cache.u, δu)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end
if linesearch_failed
cache.retcode = ReturnCode.InternalLineSearchFailed
Expand Down
14 changes: 7 additions & 7 deletions src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ Method.
### Arguments
- `linesearch`: Globalization using a Line Search Method. This needs to follow the
[`NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm`](@ref) interface. This
is not optional currently, but that restriction might be lifted in the future.
- `linesearch`: Globalization using a Line Search Method. This is not optional currently,
but that restriction might be lifted in the future.
- `σ_min`: The minimum spectral parameter allowed. This is used to ensure that the
spectral parameter is not too small.
- `σ_max`: The maximum spectral parameter allowed. This is used to ensure that the
Expand Down Expand Up @@ -119,7 +118,7 @@ end
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
maxtime = nothing, kwargs...)
timer = get_timer_output()
@static_timeit timer "cache construction" begin
u = __maybe_unaliased(prob.u0, alias_u0)
Expand All @@ -130,8 +129,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
fu = evaluate_f(prob, u)
@bb fu_cache = copy(fu)

linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
stats, maxiters, internalnorm, kwargs...)
linesearch_cache = init(prob, alg.linesearch, fu, u; stats, kwargs...)

abstol, reltol, tc_cache = init_termination_cache(
prob, abstol, reltol, fu, u_cache, termination_condition)
Expand Down Expand Up @@ -167,7 +165,9 @@ function __step!(cache::GeneralizedDFSaneCache{iip};
end

@static_timeit cache.timer "linesearch" begin
linesearch_failed, α = __internal_solve!(cache.linesearch_cache, cache.u, cache.du)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, cache.du)
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end

if linesearch_failed
Expand Down
22 changes: 16 additions & 6 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve
radius_update_scheme = RadiusUpdateSchemes.Bastin),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearch.LineSearchesJL(;
method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.NLsolve, autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
Expand Down Expand Up @@ -405,7 +407,9 @@ function FastShortcutNonlinearPolyalg(
else
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearch.LineSearchesJL(;
method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
Expand All @@ -426,7 +430,9 @@ function FastShortcutNonlinearPolyalg(
SimpleKlement(),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearch.LineSearchesJL(;
method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
end
Expand All @@ -439,12 +445,15 @@ function FastShortcutNonlinearPolyalg(
else
# TODO: This number requires a bit rigorous testing
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
algs = (Broyden(; autodiff),
algs = (
Broyden(; autodiff),
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
Klement(; linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearch.LineSearchesJL(;
method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
Expand Down Expand Up @@ -480,7 +489,8 @@ function FastShortcutNLLSPolyalg(
linsolve, precs, disable_geodesic = Val(true), autodiff, kwargs...),
TrustRegion(; concrete_jac, linsolve, precs, autodiff, kwargs...),
GaussNewton(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()),
linesearch = LineSearch.LineSearchesJL(;
method = LineSearches.BackTracking()),
autodiff, kwargs...),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff, kwargs...),
Expand Down
Loading

0 comments on commit c35f0f4

Please sign in to comment.