Skip to content

Commit

Permalink
Merge pull request #116 from SciML/ap/ls
Browse files Browse the repository at this point in the history
Add Line Search to (L)Broyden
  • Loading branch information
Vaibhavdixit02 authored Jan 14, 2024
2 parents 6eb690e + 34289ef commit eae355d
Show file tree
Hide file tree
Showing 18 changed files with 301 additions and 104 deletions.
13 changes: 9 additions & 4 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.2.1"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -17,17 +18,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"

[weakdeps]
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"

[compat]
ADTypes = "0.2.6"
ArrayInterface = "7"
ConcreteStructs = "0.2"
DiffEqBase = "6.126"
FastClosures = "0.3"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LinearAlgebra = "1.9"
Expand All @@ -36,4 +40,5 @@ PrecompileTools = "1"
Reexport = "1"
SciMLBase = "2.7"
StaticArraysCore = "1.4"
StaticArrays = "1"
julia = "1.9"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module SimpleNonlinearSolveStaticArraysExt

using SimpleNonlinearSolve

@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true

end
10 changes: 5 additions & 5 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@ module SimpleNonlinearSolve
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations

@recompile_invalidations begin
using ADTypes,
ArrayInterface, ConcreteStructs, DiffEqBase, Reexport, LinearAlgebra, SciMLBase
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
ForwardDiff, Reexport, LinearAlgebra, SciMLBase

import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode,
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
using FiniteDiff, ForwardDiff
NONLINEARSOLVE_DEFAULT_NORM
import ForwardDiff: Dual
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
end

Expand All @@ -26,6 +25,7 @@ abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm e
@inline __is_extension_loaded(::Val) = false

include("utils.jl")
include("linesearch.jl")

## Nonlinear Solvers
include("nlsolve/raphson.jl")
Expand Down
1 change: 0 additions & 1 deletion lib/SimpleNonlinearSolve/src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}
sol.original)
end

# Handle Ambiguities
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@eval begin
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/bracketing/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
left, right = prob.tspan
fl, fr = f(left), f(right)

abstol = _get_tolerance(abstol,
abstol = __get_tolerance(nothing, abstol,
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))

if iszero(fl)
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/bracketing/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
fl, fr = f(left), f(right)
ϵ = eps(convert(typeof(fl), 1))

abstol = _get_tolerance(abstol,
abstol = __get_tolerance(nothing, abstol,
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))

if iszero(fl)
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/bracketing/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
left, right = prob.tspan
fl, fr = f(left), f(right)

abstol = _get_tolerance(abstol,
abstol = __get_tolerance(nothing, abstol,
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))

if iszero(fl)
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/bracketing/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
left, right = prob.tspan
fl, fr = f(left), f(right)

abstol = _get_tolerance(abstol,
abstol = __get_tolerance(nothing, abstol,
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))

if iszero(fl)
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/bracketing/ridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
left, right = prob.tspan
fl, fr = f(left), f(right)

abstol = _get_tolerance(abstol,
abstol = __get_tolerance(nothing, abstol,
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))

if iszero(fl)
Expand Down
128 changes: 128 additions & 0 deletions lib/SimpleNonlinearSolve/src/linesearch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# This is a copy of the version in NonlinearSolve.jl. Temporarily kept here till we move
# line searches into a dedicated package.
@kwdef @concrete struct LiFukushimaLineSearch
lambda_0 = 1
beta = 0.5
sigma_1 = 0.001
sigma_2 = 0.001
eta = 0.1
rho = 0.1
nan_maxiters = missing
maxiters::Int = 100
end

@concrete mutable struct LiFukushimaLineSearchCache{T <: Union{Nothing, Int}}
ϕ
λ₀
β
σ₁
σ₂
η
ρ
α
nan_maxiters::T
maxiters::Int
end

@concrete struct StaticLiFukushimaLineSearchCache
f
p
λ₀
β
σ₁
σ₂
η
ρ
maxiters::Int
end

(alg::LiFukushimaLineSearch)(prob, fu, u) = __generic_init(alg, prob, fu, u)
function (alg::LiFukushimaLineSearch)(prob, fu::Union{Number, SArray},
u::Union{Number, SArray})
(alg.nan_maxiters === missing || alg.nan_maxiters === nothing) &&
return __static_init(alg, prob, fu, u)
@warn "`LiFukushimaLineSearch` with NaN checking is not non-allocating" maxlog=1
return __generic_init(alg, prob, fu, u)
end

function __generic_init(alg::LiFukushimaLineSearch, prob, fu, u)
@bb u_cache = similar(u)
@bb fu_cache = similar(fu)
T = promote_type(eltype(fu), eltype(u))

ϕ = @closure (u, δu, α) -> begin
@bb @. u_cache = u + α * δu
return NONLINEARSOLVE_DEFAULT_NORM(__eval_f(prob, fu_cache, u_cache))
end

nan_maxiters = ifelse(alg.nan_maxiters === missing, 5, alg.nan_maxiters)

return LiFukushimaLineSearchCache(ϕ, T(alg.lambda_0), T(alg.beta), T(alg.sigma_1),
T(alg.sigma_2), T(alg.eta), T(alg.rho), T(true), nan_maxiters, alg.maxiters)
end

function __static_init(alg::LiFukushimaLineSearch, prob, fu, u)
T = promote_type(eltype(fu), eltype(u))
return StaticLiFukushimaLineSearchCache(prob.f, prob.p, T(alg.lambda_0), T(alg.beta),
T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho), alg.maxiters)
end

function (cache::LiFukushimaLineSearchCache)(u, δu)
T = promote_type(eltype(u), eltype(δu))
ϕ = @closure α -> cache.ϕ(u, δu, α)
fx_norm = ϕ(T(0))

# Non-Blocking exit if the norm is NaN or Inf
DiffEqBase.NAN_CHECK(fx_norm) && return cache.α

# Early Terminate based on Eq. 2.7
du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
fxλ_norm = ϕ(cache.α)
fxλ_norm cache.ρ * fx_norm - cache.σ₂ * du_norm^2 && return cache.α

λ₂, λ₁ = cache.λ₀, cache.λ₀
fxλp_norm = ϕ(λ₂)

if cache.nan_maxiters !== nothing
if DiffEqBase.NAN_CHECK(fxλp_norm)
nan_converged = false
for _ in 1:(cache.nan_maxiters)
λ₁, λ₂ = λ₂, cache.β * λ₂
fxλp_norm = ϕ(λ₂)
nan_converged = DiffEqBase.NAN_CHECK(fxλp_norm)::Bool
nan_converged && break
end
nan_converged || return cache.α
end
end

for i in 1:(cache.maxiters)
fxλp_norm = ϕ(λ₂)
converged = fxλp_norm (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
converged && return λ₂
λ₁, λ₂ = λ₂, cache.β * λ₂
end

return cache.α
end

function (cache::StaticLiFukushimaLineSearchCache)(u, δu)
T = promote_type(eltype(u), eltype(δu))

# Early Terminate based on Eq. 2.7
fx_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u, cache.p))
du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
fxλ_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u .+ δu, cache.p))
fxλ_norm cache.ρ * fx_norm - cache.σ₂ * du_norm^2 && return T(true)

λ₂, λ₁ = cache.λ₀, cache.λ₀

for i in 1:(cache.maxiters)
fxλp_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u .+ λ₂ .* δu, cache.p))
converged = fxλp_norm (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
converged && return λ₂
λ₁, λ₂ = λ₂, cache.β * λ₂
end

return T(true)
end
46 changes: 41 additions & 5 deletions lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,54 @@
"""
SimpleBroyden()
SimpleBroyden(; linesearch = Val(false), alpha = nothing)
A low-overhead implementation of Broyden. This method is non-allocating on scalar
and static array problems.
### Keyword Arguments
- `linesearch`: If `linesearch` is `Val(true)`, then we use the `LiFukushimaLineSearch`
[1] line search else no line search is used. For advanced customization of the line
search, use the [`Broyden`](@ref) algorithm in `NonlinearSolve.jl`.
- `alpha`: Scale the initial jacobian initialization with `alpha`. If it is `nothing`, we
will compute the scaling using `2 * norm(fu) / max(norm(u), true)`.
### References
[1] Li, Dong-Hui, and Masao Fukushima. "A derivative-free line search and global convergence
of Broyden-like method for nonlinear equations." Optimization methods and software 13.3
(2000): 181-201.
"""
struct SimpleBroyden <: AbstractSimpleNonlinearSolveAlgorithm end
@concrete struct SimpleBroyden{linesearch} <: AbstractSimpleNonlinearSolveAlgorithm
alpha
end

function SimpleBroyden(; linesearch = Val(false), alpha = nothing)
return SimpleBroyden{_unwrap_val(linesearch)}(alpha)
end

__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
fx = _get_fx(prob, x)
T = promote_type(eltype(x), eltype(fx))

@bb xo = copy(x)
@bb δx = copy(x)
@bb δf = copy(fx)
@bb fprev = copy(fx)

J⁻¹ = __init_identity_jacobian(fx, x)
if alg.alpha === nothing
fx_norm = NONLINEARSOLVE_DEFAULT_NORM(fx)
x_norm = NONLINEARSOLVE_DEFAULT_NORM(x)
init_α = ifelse(fx_norm 1e-5, max(x_norm, T(true)) / (2 * fx_norm), T(true))
else
init_α = inv(alg.alpha)
end

J⁻¹ = __init_identity_jacobian(fx, x, init_α)
@bb J⁻¹δf = copy(x)
@bb xᵀJ⁻¹ = copy(x)
@bb δJ⁻¹n = copy(x)
Expand All @@ -26,9 +57,15 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
termination_condition)

ls_cache = __get_linesearch(alg) === Val(true) ?
LiFukushimaLineSearch()(prob, fx, x) : nothing

for _ in 1:maxiters
@bb δx = J⁻¹ × vec(fprev)
@bb @. x = xo - δx
@bb δx .*= -1

α = ls_cache === nothing ? true : ls_cache(xo, δx)
@bb @. x = xo + α * δx
fx = __eval_f(prob, fx, x)
@bb @. δf = fx - fprev

Expand All @@ -37,7 +74,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
tc_sol !== nothing && return tc_sol

@bb J⁻¹δf = J⁻¹ × vec(δf)
@bb δx .*= -1
d = dot(δx, J⁻¹δf)
@bb xᵀJ⁻¹ = transpose(J⁻¹) × vec(δx)

Expand Down
Loading

0 comments on commit eae355d

Please sign in to comment.