Skip to content

Commit

Permalink
refactor: migrate to LineSearch.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 28, 2024
1 parent d6d741b commit 980c4c5
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 484 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down Expand Up @@ -80,6 +81,7 @@ Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LazyArrays = "1.8.2, 2"
LeastSquaresOptim = "0.8.5"
LineSearch = "0.1"
LineSearches = "7.2"
LinearAlgebra = "1.10"
LinearSolve = "2.35"
Expand Down
7 changes: 0 additions & 7 deletions docs/src/devdocs/internal_interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ NonlinearSolve.AbstractDampingFunction
NonlinearSolve.AbstractDampingFunctionCache
```

## Line Search

```@docs
NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm
NonlinearSolve.AbstractNonlinearSolveLineSearchCache
```

## Trust Region

```@docs
Expand Down
7 changes: 5 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using LazyArrays: LazyArrays, ApplyArray, cache
using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Symmetric,
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearch: LineSearch, AbstractLineSearchAlgorithm, AbstractLineSearchCache,
NoLineSearch, RobustNonMonotoneLineSearch
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization,
needs_concrete_A, AbstractFactorization,
Expand Down Expand Up @@ -173,8 +175,9 @@ export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, GeodesicAcce

# Globalization
## Line Search Algorithms
export LineSearchesJL, NoLineSearch, RobustNonMonotoneLineSearch, LiFukushimaLineSearch
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking
export LineSearchesJL, LiFukushimaLineSearch # FIXME: deprecated. use LineSearch.jl directly
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking # FIXME: deprecated
export NoLineSearch, RobustNonMonotoneLineSearch
## Trust Region Algorithms
export RadiusUpdateSchemes

Expand Down
22 changes: 3 additions & 19 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,6 @@ function last_step_accepted(cache::AbstractDescentCache)
return true
end

"""
AbstractNonlinearSolveLineSearchAlgorithm
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
### `__internal_init` specification
```julia
__internal_init(
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F,
fu, u, p, args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN} -->
AbstractNonlinearSolveLineSearchCache
```
"""
abstract type AbstractNonlinearSolveLineSearchAlgorithm end

"""
AbstractNonlinearSolveLineSearchCache
Expand Down Expand Up @@ -512,9 +496,9 @@ SciMLBase.isinplace(::AbstractNonlinearSolveJacobianCache{iip}) where {iip} = ii
abstract type AbstractNonlinearSolveTraceLevel end

# Default Printing
for aType in (AbstractTrustRegionMethod, AbstractNonlinearSolveLineSearchAlgorithm,
AbstractResetCondition, AbstractApproximateJacobianUpdateRule,
AbstractDampingFunction, AbstractNonlinearSolveExtensionAlgorithm)
for aType in (AbstractTrustRegionMethod, AbstractResetCondition,
AbstractApproximateJacobianUpdateRule, AbstractDampingFunction,
AbstractNonlinearSolveExtensionAlgorithm)
@eval function Base.show(io::IO, alg::$(aType))
print(io, "$(nameof(typeof(alg)))()")
end
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ over this.
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
if !(linesearch isa AbstractLineSearchAlgorithm)
Base.depwarn(
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.", :Klement)
Expand Down
7 changes: 3 additions & 4 deletions src/algorithms/pseudo_transient.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
precs = DEFAULT_PRECS, autodiff = nothing)
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)
An implementation of PseudoTransient Method [coffey2003pseudotransient](@cite) that is used
to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping
Expand All @@ -16,8 +15,8 @@ This implementation specifically uses "switched evolution relaxation"
you are going to need more iterations to converge but it can be more stable.
"""
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
precs = DEFAULT_PRECS, autodiff = nothing, alpha_initial = 1e-3)
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
alpha_initial = 1e-3)
descent = DampedNewtonDescent(; linsolve, precs, initial_damping = alpha_initial,
damping_fn = SwitchedEvolutionRelaxation())
return GeneralizedFirstOrderAlgorithm(;
Expand Down
10 changes: 6 additions & 4 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
linesearch = missing, trustregion = missing, descent, update_rule,
reinit_rule, initialization, max_resets::Int = typemax(Int),
max_shrink_times::Int = typemax(Int)) where {concrete_jac, name}
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
if linesearch !== missing && !(linesearch isa AbstractLineSearchAlgorithm)
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.",
:GeneralizedFirstOrderAlgorithm)
Expand Down Expand Up @@ -199,8 +199,8 @@ function SciMLBase.__init(
if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
linesearch_cache = init(
prob, alg.linesearch, fu, u; stats, internalnorm, kwargs...)
GB = :LineSearch
end

Expand Down Expand Up @@ -317,7 +317,9 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
if descent_result.success
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
needs_reset = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
cache.force_reinit = true
Expand Down
16 changes: 11 additions & 5 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function GeneralizedFirstOrderAlgorithm{concrete_jac, name}(;
jacobian_ad !== nothing && ADTypes.mode(jacobian_ad) isa ADTypes.ReverseMode,
jacobian_ad, nothing))

if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
if linesearch !== missing && !(linesearch isa AbstractLineSearchAlgorithm)
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.",
:GeneralizedFirstOrderAlgorithm)
Expand Down Expand Up @@ -199,8 +199,13 @@ function SciMLBase.__init(
if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
linesearch_ad = alg.forward_ad === nothing ?
(alg.reverse_ad === nothing ? alg.jacobian_ad :
alg.reverse_ad) : alg.forward_ad
linesearch_ad = get_concrete_forward_ad(
linesearch_ad, prob, False; check_forward_mode = false)
linesearch_cache = init(
prob, alg.linesearch, fu, u; stats, autodiff = linesearch_ad, kwargs...)
GB = :LineSearch
end

Expand Down Expand Up @@ -264,8 +269,9 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
cache.make_new_jacobian = true
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
linesearch_failed, α = __internal_solve!(
cache.linesearch_cache, cache.u, δu)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end
if linesearch_failed
cache.retcode = ReturnCode.InternalLineSearchFailed
Expand Down
14 changes: 7 additions & 7 deletions src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ Method.
### Arguments
- `linesearch`: Globalization using a Line Search Method. This needs to follow the
[`NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm`](@ref) interface. This
is not optional currently, but that restriction might be lifted in the future.
- `linesearch`: Globalization using a Line Search Method. This is not optional currently,
but that restriction might be lifted in the future.
- `σ_min`: The minimum spectral parameter allowed. This is used to ensure that the
spectral parameter is not too small.
- `σ_max`: The maximum spectral parameter allowed. This is used to ensure that the
Expand Down Expand Up @@ -119,7 +118,7 @@ end
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
maxtime = nothing, kwargs...)
timer = get_timer_output()
@static_timeit timer "cache construction" begin
u = __maybe_unaliased(prob.u0, alias_u0)
Expand All @@ -130,8 +129,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
fu = evaluate_f(prob, u)
@bb fu_cache = copy(fu)

linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
stats, maxiters, internalnorm, kwargs...)
linesearch_cache = init(prob, alg.linesearch, fu, u; stats, kwargs...)

abstol, reltol, tc_cache = init_termination_cache(
prob, abstol, reltol, fu, u_cache, termination_condition)
Expand Down Expand Up @@ -167,7 +165,9 @@ function __step!(cache::GeneralizedDFSaneCache{iip};
end

@static_timeit cache.timer "linesearch" begin
linesearch_failed, α = __internal_solve!(cache.linesearch_cache, cache.u, cache.du)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, cache.du)
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end

if linesearch_failed
Expand Down
9 changes: 6 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ function FastShortcutNonlinearPolyalg(
else
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearchesJL(; method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
Expand All @@ -426,7 +427,8 @@ function FastShortcutNonlinearPolyalg(
SimpleKlement(),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearchesJL(; method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
end
Expand All @@ -444,7 +446,8 @@ function FastShortcutNonlinearPolyalg(
Klement(; linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
linesearch = LineSearchesJL(; method = LineSearches.BackTracking()),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
Expand Down
Loading

0 comments on commit 980c4c5

Please sign in to comment.