Skip to content

Commit

Permalink
Merge pull request #203 from SciML/ap/nonlinearsolve
Browse files Browse the repository at this point in the history
[Breaking] Use NonlinearSolve for all root finding needs
  • Loading branch information
ChrisRackauckas authored Feb 22, 2024
2 parents be310b8 + 090a7cb commit 89663bb
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 130 deletions.
26 changes: 13 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "2.37.0"
version = "3.0.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -10,7 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -23,33 +23,34 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[compat]
Aqua = "0.8"
DataInterpolations = "4"
DataInterpolations = "4.6"
DataStructures = "0.18.13"
DiffEqBase = "6.141"
ForwardDiff = "0.10.19"
DiffEqBase = "6.146"
ForwardDiff = "0.10.36"
Functors = "0.4"
LinearAlgebra = "1.10"
Markdown = "1.10"
NLsolve = "4.5"
NonlinearSolve = "3.6"
ODEProblemLibrary = "0.1.5"
OrdinaryDiffEq = "6.68"
Parameters = "0.12"
QuadGK = "2.4"
RecipesBase = "1.1"
RecursiveArrayTools = "2.38, 3"
SciMLBase = "2.9"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.9"
SciMLBase = "2.26"
SciMLSensitivity = "7.49"
StaticArrays = "1.8"
StaticArraysCore = "1.4"
Sundials = "4.19.2"
Test = "1"
Tracker = "0.2.15"
Zygote = "0.6.61"
Tracker = "0.2.26"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -61,5 +62,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote"]

test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote", "NonlinearSolve"]
2 changes: 1 addition & 1 deletion src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DiffEqCallbacks

using DiffEqBase, RecursiveArrayTools, DataStructures, RecipesBase, LinearAlgebra,
StaticArraysCore, NLsolve, ForwardDiff, Functors
StaticArraysCore, NonlinearSolve, ForwardDiff, Functors

import Base.Iterators

Expand Down
40 changes: 18 additions & 22 deletions src/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
if dtcache == dt
if integrator.opts.verbose
@warn("Could not restrict values to domain. Iteration was canceled since ",
"proposed time step dt = ", dt," could not be reduced.")
"proposed time step dt = ", dt, " could not be reduced.")
end
break
end
Expand Down Expand Up @@ -212,12 +212,10 @@ end
# callback definitions

"""
```julia
GeneralDomain(g, u = nothing; nlsolve = NLSOLVEJL_SETUP(), save = true,
abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict(:ftol => 10 * eps()))
```
GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)
A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of
a `PositiveDomain` callback to arbitrary domains. Domains are specified by
Expand All @@ -242,41 +240,39 @@ preferred.
## Keyword Arguments
- `nlsolve`: A nonlinear solver as defined [in the nlsolve format](https://docs.sciml.ai/DiffEqDocs/stable/features/linear_nonlinear/)
which is passed to a `ManifoldProjection`.
- `save`: Whether to do the standard saving (applied after the callback).
- `abstol`: Tolerance up to, which residuals are accepted. Element-wise tolerances
are allowed. If it is not specified, every application of the callback uses the
current absolute tolerances of the integrator.
- `scalefactor`: Factor by which an unaccepted time step is reduced. If it is not
specified, time steps are halved.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`.
- `nlopts`: Optional arguments to nonlinear solver of a `ManifoldProjection` which
can be any of the [NLsolve keywords](https://github.com/JuliaNLSolvers/NLsolve.jl#fine-tunings).
The default value of `ftol = 10*eps()` ensures that convergence is only declared
if the infinite norm of residuals is very small and hence the state vector is very
close to the domain.
If it is not specified, it is determined automatically.
- `kwargs`: All other keyword arguments are passed to `ManifoldProjection`.
- `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in
`ManifoldProjection`. The default is `(; abstol = 10 * eps())`.
## References
Shampine, Lawrence F., Skip Thompson, Jacek Kierzenka and G. D. Byrne.
Non-negative solutions of ODEs. Applied Mathematics and Computation 170
(2005): 556-569.
"""
function GeneralDomain(g, u = nothing; nlsolve = NLSOLVEJL_SETUP(), save = true,
abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict(:ftol => 10 * eps()))
function GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)
_autonomous = SciMLBase._unwrap_val(autonomous)
if u isa Nothing
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, nothing, nothing)
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, nothing, nothing)
else
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, deepcopy(u),
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, deepcopy(u),
deepcopy(u))
end
condition = (u, t, integrator) -> true
CallbackSet(
ManifoldProjection(g; nlsolve = nlsolve, save = false,
autonomous = autonomous, nlopts = nlopts),
ManifoldProjection(
g; save = false, autonomous, isinplace = Val(true), kwargs..., nlsolve_kwargs...),
DiscreteCallback(condition, affect!; save_positions = (false, save)))
end

Expand Down
181 changes: 102 additions & 79 deletions src/manifold.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,28 @@
Base.@pure function determine_chunksize(u, alg::DiffEqBase.DEAlgorithm)
determine_chunksize(u, get_chunksize(alg))
end
Base.@pure function determine_chunksize(u, CS)
if CS != 0
return CS
else
return ForwardDiff.pickchunksize(length(u))
end
end

struct NLSOLVEJL_SETUP{CS, AD} end
Base.@pure function NLSOLVEJL_SETUP(; chunk_size = 0, autodiff = true)
NLSOLVEJL_SETUP{chunk_size, autodiff}()
end
(::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res = NLsolve.nlsolve(f, u0; kwargs...); res.zero)
function (p::NLSOLVEJL_SETUP{CS, AD})(::Type{Val{:init}}, f, u0_prototype) where {CS, AD}
AD ? autodiff = :forward : autodiff = :central
OnceDifferentiable(f, u0_prototype, u0_prototype, autodiff,
ForwardDiff.Chunk(determine_chunksize(u0_prototype, CS)))
end

# wrapper for non-autonomous functions
mutable struct NonAutonomousFunction{F, autonomous}
mutable struct NonAutonomousFunction{iip, F, autonomous}
f::F
t::Any
p::Any
end
(p::NonAutonomousFunction{F, true})(res, u) where {F} = p.f(res, u, p.p)
(p::NonAutonomousFunction{F, false})(res, u) where {F} = p.f(res, u, p.p, p.t)

(f::NonAutonomousFunction{true, F, true})(res, u, p) where {F} = f.f(res, u, p)
(f::NonAutonomousFunction{true, F, false})(res, u, p) where {F} = f.f(res, u, p, f.t)

(f::NonAutonomousFunction{false, F, true})(u, p) where {F} = f.f(u, p)
(f::NonAutonomousFunction{false, F, false})(u, p) where {F} = f.f(u, p, f.t)

SciMLBase.isinplace(::NonAutonomousFunction{iip}) where {iip} = iip

"""
```julia
ManifoldProjection(g; nlsolve = NLSOLVEJL_SETUP(), save = true)
```
In many cases, you may want to declare a manifold on which a solution lives.
Mathematically, a manifold `M` is defined by a function `g` as the set of points
where `g(u)=0`. An embedded manifold can be a lower dimensional object which
constrains the solution. For example, `g(u)=E(u)-C` where `E` is the energy
of the system in state `u`, meaning that the energy must be constant (energy
preservation). Thus by defining the manifold the solution should live on, you
can retain desired properties of the solution.
ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)
In many cases, you may want to declare a manifold on which a solution lives. Mathematically,
a manifold `M` is defined by a function `g` as the set of points where `g(u) = 0`. An
embedded manifold can be a lower dimensional object which constrains the solution. For
example, `g(u) = E(u) - C` where `E` is the energy of the system in state `u`, meaning that
the energy must be constant (energy preservation). Thus by defining the manifold the
solution should live on, you can retain desired properties of the solution.
`ManifoldProjection` projects the solution of the differential equation to the chosen
manifold `g`, conserving a property while conserving the order. It is a consequence of
Expand All @@ -52,80 +34,121 @@ properties.
## Arguments
- `g`: The residual function for the manifold. This is an inplace function of form
`g(resid, u)` or `g(resid, u, p, t)` which writes to the residual `resid` the
difference from the manifold components. Here, it is assumed that `resid` is of
the same shape as `u`.
- `g`: The residual function for the manifold.
* This is an inplace function of form `g(resid, u, p)` or `g(resid, u, p, t)` which
writes to the residual `resid` the difference from the manifold components. Here, it
is assumed that `resid` is of the same shape as `u`.
* If `isinplace = Val(false)`, then `g` should be a function of the form `g(u, p)` or
`g(u, p, t)` which returns the residual.
## Keyword Arguments
- `nlsolve`: A nonlinear solver as defined [in the nlsolve format](https://docs.sciml.ai/DiffEqDocs/stable/features/linear_nonlinear/)
- `nlsolve`: A nonlinear solver as defined in the
[NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/)
- `save`: Whether to do the standard saving (applied after the callback)
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u)`.
- `nlopts`: Optional arguments to nonlinear solver which can be any of the [NLsolve keywords](https://github.com/JuliaNLSolvers/NLsolve.jl#fine-tunings).
- `nlls`: If the problem is a nonlinear least squares problem. `nlls = Val(false)`
generates a `NonlinearProblem` which is typically faster than
`NonlinearLeastSquaresProblem`, but is only applicable if the residual size is same as
the state size.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or
`g(u, p)`. Specify it as `Val(::Bool)` to ensure this function call is type stable.
- `residual_prototype`: This needs to be specified if `nlls = Val(true)` and
`inplace = Val(true)` are specified together, else it is taken to be same as `u`.
- `kwargs`: All other keyword arguments are passed to
[NonlinearSolve.jl](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/).
### Saveat Warning
Note that the `ManifoldProjection` callback modifies the endpoints of the integration intervals
and thus breaks assumptions of internal interpolations. Because of this, the values for given by
saveat will not be order-matching. However, the interpolation error can be proportional to the
change by the projection, so if the projection is making small changes then one is still safe.
However, if there are large changes from each projection, you should consider only saving at
stopping/projection times. To do this, set `tstops` to the same values as `saveat`. There is a
performance hit by doing so because now the integrator is forced to stop at every saving point,
but this is guerenteed to match the order of the integrator even with the ManifoldProjection.
Note that the `ManifoldProjection` callback modifies the endpoints of the integration
intervals and thus breaks assumptions of internal interpolations. Because of this, the
values for given by saveat will not be order-matching. However, the interpolation error can
be proportional to the change by the projection, so if the projection is making small
changes then one is still safe. However, if there are large changes from each projection,
you should consider only saving at stopping/projection times. To do this, set `tstops` to
the same values as `saveat`. There is a performance hit by doing so because now the
integrator is forced to stop at every saving point, but this is guerenteed to match the
order of the integrator even with the ManifoldProjection.
## References
Ernst Hairer, Christian Lubich, Gerhard Wanner. Geometric Numerical Integration:
Structure-Preserving Algorithms for Ordinary Differential Equations. Berlin ;
New York :Springer, 2002.
"""
mutable struct ManifoldProjection{autonomous, F, NL, O}
mutable struct ManifoldProjection{iip, nlls, autonomous, F, NL, R, K}
g::F
nl_rhs::Any
nlcache::Any
nlsolve::NL
nlopts::O
resid_prototype::R
kwargs::K

function ManifoldProjection{autonomous}(g, nlsolve, nlopts) where {autonomous}
function ManifoldProjection{iip, nlls, autonomous}(
g, nlsolve, resid_prototype, kwargs) where {iip, nlls, autonomous}
# replace residual function if it is time-dependent
# since NLsolve only accepts functions with two arguments
_g = NonAutonomousFunction{typeof(g), autonomous}(g, 0, 0)
new{autonomous, typeof(_g), typeof(nlsolve), typeof(nlopts)}(_g, _g, nlsolve,
nlopts)
_g = NonAutonomousFunction{iip, typeof(g), autonomous}(g, 0)
return new{iip, nlls, autonomous, typeof(_g), typeof(nlsolve),
typeof(resid_prototype), typeof(kwargs)}(
_g, nothing, nlsolve, resid_prototype, kwargs)
end
end

# Now make `affect!` for this:
function (p::ManifoldProjection{autonomous, NL})(integrator) where {autonomous, NL}
function (p::ManifoldProjection{
iip, nlls, autonomous, NL})(integrator) where {iip, nlls,
autonomous, NL}
# update current time if residual function is time-dependent
if !autonomous
p.g.t = integrator.t
end
p.g.p = integrator.p
!autonomous && (p.g.t = integrator.t)

integrator.u .= p.nlsolve(p.nl_rhs, integrator.u; p.nlopts...)
end
# solve the nonlinear problem
reinit!(p.nlcache, integrator.u; p = integrator.p)
sol = solve!(p.nlcache)

function Manifold_initialize(cb, u::Number, t, integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, [u])
u_modified!(integrator, false)
if !SciMLBase.successful_retcode(sol)
SciMLBase.terminate!(integrator, sol.retcode)
return
end

copyto!(integrator.u, sol.u)
end

function Manifold_initialize(cb, u, t, integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, u)
return Manifold_initialize(cb.affect!, u, t, integrator)
end
function Manifold_initialize(
affect!::ManifoldProjection{iip, nlls}, u, t, integrator) where {iip, nlls}
nlfunc = NonlinearFunction{iip}(affect!.g; affect!.resid_prototype)
nlprob = if nlls
NonlinearLeastSquaresProblem(nlfunc, u, integrator.p)
else
NonlinearProblem(nlfunc, u, integrator.p)
end
affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.kwargs...)
u_modified!(integrator, false)
end

function ManifoldProjection(g; nlsolve = NLSOLVEJL_SETUP(), save = true,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict{Symbol, Any}())
affect! = ManifoldProjection{autonomous}(g, nlsolve, nlopts)
# Since this is applied to every point, we can reasonably assume that the solution is close
# to the initial guess, so we would want to use NewtonRaphson / RobustMultiNewton instead of
# the default one.
function ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)
# `nothing` is a valid solver, so this need to be `missing`
_nlls = SciMLBase._unwrap_val(nlls)
_nlsolve = nlsolve === missing ? (_nlls ? GaussNewton() : NewtonRaphson()) : nlsolve
iip = SciMLBase._unwrap_val(isinplace)
if autonomous === nothing
if iip
autonomous = maximum(SciMLBase.numargs(g)) == 3
else
autonomous = maximum(SciMLBase.numargs(g)) == 2
end
end
affect! = ManifoldProjection{iip, _nlls, SciMLBase._unwrap_val(autonomous)}(
g, _nlsolve, resid_prototype, kwargs)
condition = (u, t, integrator) -> true
save_positions = (false, save)
DiscreteCallback(condition, affect!;
initialize = Manifold_initialize,
save_positions = save_positions)
return DiscreteCallback(condition, affect!; initialize = Manifold_initialize,
save_positions = (false, save))
end

export ManifoldProjection
2 changes: 1 addition & 1 deletion src/saving.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ function linearize_period(t₀, t₁, u₀, u₁, integ, ilsc, caches, u_mask,
# Sanity check that we don't accidentally infinitely recurse
if t₁ - t₀ < dtmin
@debug("Linearization failure",
t₁, t₀, string(u₀), string(u₁), string(u_mask),dtmin)
t₁, t₀, string(u₀), string(u₁), string(u_mask), dtmin)
throw(ArgumentError("Linearization failed, fell below linearization subdivision threshold"))
end

Expand Down
Loading

0 comments on commit 89663bb

Please sign in to comment.