Skip to content

Commit

Permalink
Merge pull request #99 from FHoltorf/TR
Browse files Browse the repository at this point in the history
SimpleTrustRegion bug
  • Loading branch information
ChrisRackauckas authored Nov 23, 2023
2 parents 2ad0b37 + 8fba768 commit 31b8a57
Show file tree
Hide file tree
Showing 21 changed files with 118 additions and 117 deletions.
10 changes: 5 additions & 5 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ function __init__()
end

@views function SciMLBase.__solve(prob::NonlinearProblem,
alg::BatchedBroyden;
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
alg::BatchedBroyden;
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)

u, f, reconstruct = _construct_batched_problem_structure(prob)
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...)
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing,
args...; kwargs...)
args...; kwargs...)
SciMLBase.solve(prob, ITP(), args...; kwargs...)
end

Expand Down
36 changes: 18 additions & 18 deletions lib/SimpleNonlinearSolve/src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...) where {iip, T, V, P}
iip,
<:Dual{T, V, P}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {iip, T, V, P}
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
Expand All @@ -50,9 +50,9 @@ end
# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:Dual{T, V, P}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
<:Dual{T, V, P}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
Expand All @@ -61,13 +61,13 @@ for Alg in [Bisection]
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{
<:Dual{T,
V,
P},
}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
<:AbstractArray{
<:Dual{T,
V,
P},
}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
Expand Down
6 changes: 3 additions & 3 deletions lib/SimpleNonlinearSolve/src/alefeld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
struct Alefeld <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem,
alg::Alefeld, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
alg::Alefeld, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
a, b = prob.tspan
c = a - (b - a) / (f(b) - f(a)) * f(a)
Expand Down
12 changes: 6 additions & 6 deletions lib/SimpleNonlinearSolve/src/batched/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ Base.@kwdef struct BatchedSimpleDFSane{T, F, TC <: NLSolveTerminationCondition}
end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::BatchedSimpleDFSane,
args...;
abstol = nothing,
reltol = nothing,
maxiters = 100,
kwargs...)
alg::BatchedSimpleDFSane,
args...;
abstol = nothing,
reltol = nothing,
maxiters = 100,
kwargs...)
iip = isinplace(prob)

u, f, reconstruct = _construct_batched_problem_structure(prob)
Expand Down
12 changes: 6 additions & 6 deletions lib/SimpleNonlinearSolve/src/batched/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ alg_autodiff(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} =
diff_type(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = FDT

function BatchedSimpleNewtonRaphson(; chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
autodiff = Val{true}(),
diff_type = Val{:forward},
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
return BatchedSimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size),
SciMLBase._unwrap_val(autodiff),
SciMLBase._unwrap_val(diff_type), typeof(termination_condition)}(termination_condition)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
iip &&
@assert alg_autodiff(alg) "Inplace BatchedSimpleNewtonRaphson currently only supports autodiff."
Expand Down
6 changes: 3 additions & 3 deletions lib/SimpleNonlinearSolve/src/batched/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ function _construct_batched_problem_structure(prob)
end

function _construct_batched_problem_structure(u0::AbstractArray{T, N},
f,
p,
::Val{iip}) where {T, N, iip}
f,
p,
::Val{iip}) where {T, N, iip}
# Reconstruct `u`
reconstruct = N == 2 ? identity : Base.Fix2(reshape, size(u0))
# Standardize `u`
Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ function Bisection(; exact_left = false, exact_right = false)
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...;
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)
Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ A non-allocating Brent method
struct Brent <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
a, b = prob.tspan
fa, fb = f(a), f(b)
Expand Down
8 changes: 4 additions & 4 deletions lib/SimpleNonlinearSolve/src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ struct Broyden{TC <: NLSolveTerminationCondition} <:
end

function Broyden(; batched = false,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
if batched
@assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden."
return BatchedBroyden(termination_condition)
Expand All @@ -29,7 +29,7 @@ function Broyden(; batched = false,
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
Expand Down
18 changes: 9 additions & 9 deletions lib/SimpleNonlinearSolve/src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ struct SimpleDFSane{T, TC} <: AbstractSimpleNonlinearSolveAlgorithm
end

function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing),
batched::Bool = false,
max_inner_iterations = 1000)
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing),
batched::Bool = false,
max_inner_iterations = 1000)
if batched
return BatchedSimpleDFSane(; σₘᵢₙ = σ_min,
σₘₐₓ = σ_max,
Expand All @@ -98,8 +98,8 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane,
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
kwargs...)
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
struct Falsi <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)
Expand Down
8 changes: 4 additions & 4 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ and static array problems.
"""
struct SimpleHalley{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
function SimpleHalley(; chunk_size = Val{0}(), autodiff = Val{true}(),
diff_type = Val{:forward})
diff_type = Val{:forward})
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
SciMLBase._unwrap_val(diff_type)}()
end
end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::SimpleHalley, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
alg::SimpleHalley, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = f(x)
Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ struct ITP{T} <: AbstractBracketingAlgorithm
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP,
args...; abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
maxiters = 1000, kwargs...)
args...; abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan # a and b
fl, fr = f(left), f(right)
Expand Down
6 changes: 3 additions & 3 deletions lib/SimpleNonlinearSolve/src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ This method is non-allocating on scalar problems.
struct Klement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::Klement, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
alg::Klement, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fₙ = f(x)
Expand Down
18 changes: 9 additions & 9 deletions lib/SimpleNonlinearSolve/src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ struct LBroyden{batched, TC <: NLSolveTerminationCondition} <:
threshold::Int

function LBroyden(; batched = false, threshold::Int = 27,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
return new{batched, typeof(termination_condition)}(termination_condition, threshold)
end
end

@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden{batched}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
kwargs...) where {batched}
abstol = nothing, reltol = nothing, maxiters = 1000,
kwargs...) where {batched}
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
threshold = min(maxiters, alg.threshold)
Expand Down Expand Up @@ -116,26 +116,26 @@ function _init_lbroyden_state(batched::Bool, x, threshold)
end

function _rmatvec(U::AbstractMatrix, Vᵀ::AbstractMatrix,
x::Union{<:AbstractVector, <:Number})
x::Union{<:AbstractVector, <:Number})
length(U) == 0 && return x
return -x .+ vec((x' * Vᵀ) * U)
end

function _rmatvec(U::AbstractArray{T1, 3}, Vᵀ::AbstractArray{T2, 3},
x::AbstractMatrix) where {T1, T2}
x::AbstractMatrix) where {T1, T2}
length(U) == 0 && return x
Vᵀx = sum(Vᵀ .* reshape(x, size(x, 1), 1, size(x, 2)); dims = 1)
return -x .+ _drdims_sum(U .* permutedims(Vᵀx, (2, 1, 3)); dims = 1)
end

function _matvec(U::AbstractMatrix, Vᵀ::AbstractMatrix,
x::Union{<:AbstractVector, <:Number})
x::Union{<:AbstractVector, <:Number})
length(U) == 0 && return x
return -x .+ vec(Vᵀ * (U * x))
end

function _matvec(U::AbstractArray{T1, 3}, Vᵀ::AbstractArray{T2, 3},
x::AbstractMatrix) where {T1, T2}
x::AbstractMatrix) where {T1, T2}
length(U) == 0 && return x
xUᵀ = sum(reshape(x, size(x, 1), 1, size(x, 2)) .* permutedims(U, (2, 1, 3)); dims = 1)
return -x .+ _drdims_sum(xUᵀ .* Vᵀ; dims = 2)
Expand Down
19 changes: 10 additions & 9 deletions lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ and static array problems.
struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end

function SimpleNewtonRaphson(; batched = false,
chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
termination_condition = missing)
chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
termination_condition = missing)
if !ismissing(termination_condition) && !batched
throw(ArgumentError("`termination_condition` is currently only supported for batched problems"))
end
Expand All @@ -63,10 +63,10 @@ end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem,NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = float(prob.u0)
Expand All @@ -76,7 +76,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem,NonlinearLeastSquaresPro
error("SimpleNewtonRaphson currently only supports out-of-place nonlinear problems")
end

if prob isa NonlinearLeastSquaresProblem && !(typeof(prob.u0) <: Union{Number, AbstractVector})
if prob isa NonlinearLeastSquaresProblem &&
!(typeof(prob.u0) <: Union{Number, AbstractVector})
error("SimpleGaussNewton only supports Number and AbstactVector types. Please convert any problem of AbstractArray into one with u0 as AbstractVector")
end

Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/ridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ A non-allocating ridder method
struct Ridder <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)
Expand Down
Loading

0 comments on commit 31b8a57

Please sign in to comment.