diff --git a/Project.toml b/Project.toml index c1f21c7..cf5a675 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.15" +version = "0.1.16" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/dfsane.jl b/src/dfsane.jl index e898f06..b5e2b82 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -1,9 +1,12 @@ """ -```julia -SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, - M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, - nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2) -``` + SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, + M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, + nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), + batched::Bool = false, + max_inner_iterations::Int = 1000) A low-overhead implementation of the df-sane method for solving large-scale nonlinear systems of equations. For in depth information about all the parameters and the algorithm, @@ -39,8 +42,16 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_ ``f_1=||F(x_1)||^{nexp}``, `k` is the iteration number, `x` is the current `x`-value and `F` the current residual. Should satisfy ``η_k > 0`` and ``∑ₖ ηₖ < ∞``. Defaults to ``||F||^2 / k^2``. +- `termination_condition`: a `NLSolveTerminationCondition` that determines when the solver + should terminate. Defaults to `NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, reltol = nothing)`. +- `batched`: if `true`, the algorithm will use a batched version of the algorithm that treats each + column of `x` as a separate problem. This can be useful nonlinear problems involing neural + networks. Defaults to `false`. +- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the + algorithm. Used exclusively in `batched` mode. Defaults to `1000`. """ -struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm +struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm σ_min::T σ_max::T σ_1::T @@ -50,23 +61,49 @@ struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm τ_max::T nexp::Int η_strategy::Function + termination_condition::TC + max_inner_iterations::Int function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, - nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2) - new{typeof(σ_min)}(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp, η_strategy) + nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), + batched::Bool = false, + max_inner_iterations = 1000) + return new{batched, typeof(σ_min), typeof(termination_condition)}(σ_min, + σ_max, + σ_1, + M, + γ, + τ_min, + τ_max, + nexp, + η_strategy, + termination_condition, + max_inner_iterations) end end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, +function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - kwargs...) + kwargs...) where {batched} + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) + + if batched + batch_size = size(x, 2) + end + T = eltype(x) σ_min = float(alg.σ_min) σ_max = float(alg.σ_max) - σ_k = float(alg.σ_1) + σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1) + M = alg.M γ = float(alg.γ) τ_min = float(alg.τ_min) @@ -74,74 +111,125 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, nexp = alg.nexp η_strategy = alg.η_strategy + batched && @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays" + if SciMLBase.isinplace(prob) error("SimpleDFSane currently only supports out-of-place nonlinear problems") end atol = abstol !== nothing ? abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) - rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + (tc.abstol !== nothing ? tc.abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) + rtol = reltol !== nothing ? reltol : + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + error("SimpleDFSane currently doesn't support SAFE_BEST termination modes") + end + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + termination_condition = tc(storage) function ff(x) F = f(x) - f_k = norm(F)^nexp + f_k = if batched + sum(abs2, F; dims = 1) .^ (nexp / 2) + else + norm(F)^nexp + end return f_k, F end + function generate_history(f_k, M) + if batched + history = similar(f_k, (M, length(f_k))) + history .= reshape(f_k, 1, :) + return history + else + return fill(f_k, M) + end + end + f_k, F_k = ff(x) α_1 = convert(T, 1.0) f_1 = f_k - history_f_k = fill(f_k, M) + history_f_k = generate_history(f_k, M) for k in 1:maxiters - iszero(F_k) && - return SciMLBase.build_solution(prob, alg, x, F_k; - retcode = ReturnCode.Success) - # Spectral parameter range check - if abs(σ_k) > σ_max - σ_k = sign(σ_k) * σ_max - elseif abs(σ_k) < σ_min - σ_k = sign(σ_k) * σ_min + if batched + @. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) + else + σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) end # Line search direction - d = -σ_k * F_k + d = -σ_k .* F_k η = η_strategy(f_1, k, x, F_k) - f̄ = maximum(history_f_k) + f̄ = batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k) α_p = α_1 α_m = α_1 - x_new = x + α_p * d + x_new = @. x + α_p * d + f_new, F_new = ff(x_new) + + inner_iterations = 0 while true - if f_new ≤ f̄ + η - γ * α_p^2 * f_k - break + inner_iterations += 1 + + if batched + criteria = @. f̄ + η - γ * α_p^2 * f_k + # NOTE: This is simply a heuristic, ideally we check using `all` but that is + # typically very expensive for large problems + (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break + else + criteria = f̄ + η - γ * α_p^2 * f_k + f_new ≤ criteria && break end - α_tp = α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k) - x_new = x - α_m * d + α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k) + x_new = @. x - α_m * d f_new, F_new = ff(x_new) - if f_new ≤ f̄ + η - γ * α_m^2 * f_k - break + if batched + # NOTE: This is simply a heuristic, ideally we check using `all` but that is + # typically very expensive for large problems + (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break + else + f_new ≤ criteria && break end - α_tm = α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k) - α_p = min(τ_max * α_p, max(α_tp, τ_min * α_p)) - α_m = min(τ_max * α_m, max(α_tm, τ_min * α_m)) - x_new = x + α_p * d + α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k) + α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p) + α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m) + x_new = @. x + α_p * d f_new, F_new = ff(x_new) + + # NOTE: The original algorithm runs till either condition is satisfied, however, + # for most batched problems like neural networks we only care about + # approximate convergence + batched && (inner_iterations ≥ alg.max_inner_iterations) && break end - if isapprox(x_new, x, atol = atol, rtol = rtol) - return SciMLBase.build_solution(prob, alg, x_new, F_new; + if termination_condition(F_new, x_new, x, atol, rtol) + return SciMLBase.build_solution(prob, + alg, + x_new, + F_new; retcode = ReturnCode.Success) end + # Update spectral parameter - s_k = x_new - x - y_k = F_new - F_k - σ_k = (s_k' * s_k) / (s_k' * y_k) + s_k = @. x_new - x + y_k = @. F_new - F_k + + if batched + σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5)) + else + σ_k = (s_k' * s_k) / (s_k' * y_k) + end # Take step x = x_new @@ -149,7 +237,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, f_k = f_new # Store function value - history_f_k[k % M + 1] = f_new + if batched + history_f_k[k % M + 1, :] .= vec(f_new) + else + history_f_k[k % M + 1] = f_new + end end return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters) end diff --git a/test/basictests.jl b/test/basictests.jl index e1d6f9b..aea1731 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -9,6 +9,8 @@ const BATCHED_BROYDEN_SOLVERS = Broyden[] const BROYDEN_SOLVERS = Broyden[] const BATCHED_LBROYDEN_SOLVERS = LBroyden[] const LBROYDEN_SOLVERS = LBroyden[] +const BATCHED_DFSANE_SOLVERS = SimpleDFSane[] +const DFSANE_SOLVERS = SimpleDFSane[] for mode in instances(NLSolveTerminationMode.T) if mode ∈ @@ -23,6 +25,8 @@ for mode in instances(NLSolveTerminationMode.T) push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition)) push!(LBROYDEN_SOLVERS, LBroyden(; batched = false, termination_condition)) push!(BATCHED_LBROYDEN_SOLVERS, LBroyden(; batched = true, termination_condition)) + push!(DFSANE_SOLVERS, SimpleDFSane(; batched = false, termination_condition)) + push!(BATCHED_DFSANE_SOLVERS, SimpleDFSane(; batched = true, termination_condition)) end # SimpleNewtonRaphson @@ -484,11 +488,13 @@ sol = solve(probN, Broyden(batched = true)) @test abs.(sol.u) ≈ sqrt.(p) -for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS...) - sol = solve(probN, alg) +for alg in (BATCHED_BROYDEN_SOLVERS..., + BATCHED_LBROYDEN_SOLVERS..., + BATCHED_DFSANE_SOLVERS...) + sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3) @test sol.retcode == ReturnCode.Success - @test abs.(sol.u) ≈ sqrt.(p) + @test abs.(sol.u)≈sqrt.(p) atol=1e-3 rtol=1e-3 end ## User specified Jacobian