Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Use NonlinearSolve for all root finding needs #203

Merged
merged 12 commits into from
Feb 22, 2024
26 changes: 13 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "2.37.0"
version = "3.0.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -10,7 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -23,33 +23,34 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[compat]
Aqua = "0.8"
DataInterpolations = "4"
DataInterpolations = "4.6"
DataStructures = "0.18.13"
DiffEqBase = "6.141"
ForwardDiff = "0.10.19"
DiffEqBase = "6.146"
ForwardDiff = "0.10.36"
Functors = "0.4"
LinearAlgebra = "1.10"
Markdown = "1.10"
NLsolve = "4.5"
NonlinearSolve = "3.6"
ODEProblemLibrary = "0.1.5"
OrdinaryDiffEq = "6.68"
Parameters = "0.12"
QuadGK = "2.4"
RecipesBase = "1.1"
RecursiveArrayTools = "2.38, 3"
SciMLBase = "2.9"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.9"
SciMLBase = "2.26"
SciMLSensitivity = "7.49"
StaticArrays = "1.8"
StaticArraysCore = "1.4"
Sundials = "4.19.2"
Test = "1"
Tracker = "0.2.15"
Zygote = "0.6.61"
Tracker = "0.2.26"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -61,5 +62,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote"]

test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote", "NonlinearSolve"]
2 changes: 1 addition & 1 deletion src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DiffEqCallbacks

using DiffEqBase, RecursiveArrayTools, DataStructures, RecipesBase, LinearAlgebra,
StaticArraysCore, NLsolve, ForwardDiff, Functors
StaticArraysCore, NonlinearSolve, ForwardDiff, Functors

import Base.Iterators

Expand Down
40 changes: 18 additions & 22 deletions src/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
if dtcache == dt
if integrator.opts.verbose
@warn("Could not restrict values to domain. Iteration was canceled since ",
"proposed time step dt = ", dt," could not be reduced.")
"proposed time step dt = ", dt, " could not be reduced.")
end
break
end
Expand Down Expand Up @@ -212,12 +212,10 @@
# callback definitions

"""
```julia
GeneralDomain(g, u = nothing; nlsolve = NLSOLVEJL_SETUP(), save = true,
abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict(:ftol => 10 * eps()))
```
GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)

A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of
a `PositiveDomain` callback to arbitrary domains. Domains are specified by
Expand All @@ -242,41 +240,39 @@

## Keyword Arguments

- `nlsolve`: A nonlinear solver as defined [in the nlsolve format](https://docs.sciml.ai/DiffEqDocs/stable/features/linear_nonlinear/)
which is passed to a `ManifoldProjection`.
- `save`: Whether to do the standard saving (applied after the callback).
- `abstol`: Tolerance up to, which residuals are accepted. Element-wise tolerances
are allowed. If it is not specified, every application of the callback uses the
current absolute tolerances of the integrator.
- `scalefactor`: Factor by which an unaccepted time step is reduced. If it is not
specified, time steps are halved.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`.
- `nlopts`: Optional arguments to nonlinear solver of a `ManifoldProjection` which
can be any of the [NLsolve keywords](https://github.com/JuliaNLSolvers/NLsolve.jl#fine-tunings).
The default value of `ftol = 10*eps()` ensures that convergence is only declared
if the infinite norm of residuals is very small and hence the state vector is very
close to the domain.
If it is not specified, it is determined automatically.
- `kwargs`: All other keyword arguments are passed to `ManifoldProjection`.
- `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in
`ManifoldProjection`. The default is `(; abstol = 10 * eps())`.

## References

Shampine, Lawrence F., Skip Thompson, Jacek Kierzenka and G. D. Byrne.
Non-negative solutions of ODEs. Applied Mathematics and Computation 170
(2005): 556-569.
"""
function GeneralDomain(g, u = nothing; nlsolve = NLSOLVEJL_SETUP(), save = true,
abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict(:ftol => 10 * eps()))
function GeneralDomain(

Check warning on line 261 in src/domain.jl

View check run for this annotation

Codecov / codecov/patch

src/domain.jl#L261

Added line #L261 was not covered by tests
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)
_autonomous = SciMLBase._unwrap_val(autonomous)

Check warning on line 265 in src/domain.jl

View check run for this annotation

Codecov / codecov/patch

src/domain.jl#L265

Added line #L265 was not covered by tests
if u isa Nothing
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, nothing, nothing)
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, nothing, nothing)

Check warning on line 267 in src/domain.jl

View check run for this annotation

Codecov / codecov/patch

src/domain.jl#L267

Added line #L267 was not covered by tests
else
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, deepcopy(u),
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, deepcopy(u),

Check warning on line 269 in src/domain.jl

View check run for this annotation

Codecov / codecov/patch

src/domain.jl#L269

Added line #L269 was not covered by tests
deepcopy(u))
end
condition = (u, t, integrator) -> true
CallbackSet(
ManifoldProjection(g; nlsolve = nlsolve, save = false,
autonomous = autonomous, nlopts = nlopts),
ManifoldProjection(
g; save = false, autonomous, isinplace = Val(true), kwargs..., nlsolve_kwargs...),
DiscreteCallback(condition, affect!; save_positions = (false, save)))
end

Expand Down
181 changes: 102 additions & 79 deletions src/manifold.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,28 @@
Base.@pure function determine_chunksize(u, alg::DiffEqBase.DEAlgorithm)
determine_chunksize(u, get_chunksize(alg))
end
Base.@pure function determine_chunksize(u, CS)
if CS != 0
return CS
else
return ForwardDiff.pickchunksize(length(u))
end
end

struct NLSOLVEJL_SETUP{CS, AD} end
Base.@pure function NLSOLVEJL_SETUP(; chunk_size = 0, autodiff = true)
NLSOLVEJL_SETUP{chunk_size, autodiff}()
end
(::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res = NLsolve.nlsolve(f, u0; kwargs...); res.zero)
function (p::NLSOLVEJL_SETUP{CS, AD})(::Type{Val{:init}}, f, u0_prototype) where {CS, AD}
AD ? autodiff = :forward : autodiff = :central
OnceDifferentiable(f, u0_prototype, u0_prototype, autodiff,
ForwardDiff.Chunk(determine_chunksize(u0_prototype, CS)))
end

# wrapper for non-autonomous functions
mutable struct NonAutonomousFunction{F, autonomous}
mutable struct NonAutonomousFunction{iip, F, autonomous}
f::F
t::Any
p::Any
end
(p::NonAutonomousFunction{F, true})(res, u) where {F} = p.f(res, u, p.p)
(p::NonAutonomousFunction{F, false})(res, u) where {F} = p.f(res, u, p.p, p.t)

(f::NonAutonomousFunction{true, F, true})(res, u, p) where {F} = f.f(res, u, p)

Check warning on line 7 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L7

Added line #L7 was not covered by tests
(f::NonAutonomousFunction{true, F, false})(res, u, p) where {F} = f.f(res, u, p, f.t)

(f::NonAutonomousFunction{false, F, true})(u, p) where {F} = f.f(u, p)
(f::NonAutonomousFunction{false, F, false})(u, p) where {F} = f.f(u, p, f.t)

Check warning on line 11 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L10-L11

Added lines #L10 - L11 were not covered by tests

SciMLBase.isinplace(::NonAutonomousFunction{iip}) where {iip} = iip

Check warning on line 13 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L13

Added line #L13 was not covered by tests

"""
```julia
ManifoldProjection(g; nlsolve = NLSOLVEJL_SETUP(), save = true)
```

In many cases, you may want to declare a manifold on which a solution lives.
Mathematically, a manifold `M` is defined by a function `g` as the set of points
where `g(u)=0`. An embedded manifold can be a lower dimensional object which
constrains the solution. For example, `g(u)=E(u)-C` where `E` is the energy
of the system in state `u`, meaning that the energy must be constant (energy
preservation). Thus by defining the manifold the solution should live on, you
can retain desired properties of the solution.
ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)

In many cases, you may want to declare a manifold on which a solution lives. Mathematically,
a manifold `M` is defined by a function `g` as the set of points where `g(u) = 0`. An
embedded manifold can be a lower dimensional object which constrains the solution. For
example, `g(u) = E(u) - C` where `E` is the energy of the system in state `u`, meaning that
the energy must be constant (energy preservation). Thus by defining the manifold the
solution should live on, you can retain desired properties of the solution.

`ManifoldProjection` projects the solution of the differential equation to the chosen
manifold `g`, conserving a property while conserving the order. It is a consequence of
Expand All @@ -52,80 +34,121 @@

## Arguments

- `g`: The residual function for the manifold. This is an inplace function of form
`g(resid, u)` or `g(resid, u, p, t)` which writes to the residual `resid` the
difference from the manifold components. Here, it is assumed that `resid` is of
the same shape as `u`.
- `g`: The residual function for the manifold.

* This is an inplace function of form `g(resid, u, p)` or `g(resid, u, p, t)` which
writes to the residual `resid` the difference from the manifold components. Here, it
is assumed that `resid` is of the same shape as `u`.
* If `isinplace = Val(false)`, then `g` should be a function of the form `g(u, p)` or
`g(u, p, t)` which returns the residual.

## Keyword Arguments

- `nlsolve`: A nonlinear solver as defined [in the nlsolve format](https://docs.sciml.ai/DiffEqDocs/stable/features/linear_nonlinear/)
- `nlsolve`: A nonlinear solver as defined in the
[NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/)
- `save`: Whether to do the standard saving (applied after the callback)
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u)`.
- `nlopts`: Optional arguments to nonlinear solver which can be any of the [NLsolve keywords](https://github.com/JuliaNLSolvers/NLsolve.jl#fine-tunings).
- `nlls`: If the problem is a nonlinear least squares problem. `nlls = Val(false)`
generates a `NonlinearProblem` which is typically faster than
`NonlinearLeastSquaresProblem`, but is only applicable if the residual size is same as
the state size.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or
`g(u, p)`. Specify it as `Val(::Bool)` to ensure this function call is type stable.
- `residual_prototype`: This needs to be specified if `nlls = Val(true)` and
`inplace = Val(true)` are specified together, else it is taken to be same as `u`.
- `kwargs`: All other keyword arguments are passed to
[NonlinearSolve.jl](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/).

### Saveat Warning

Note that the `ManifoldProjection` callback modifies the endpoints of the integration intervals
and thus breaks assumptions of internal interpolations. Because of this, the values for given by
saveat will not be order-matching. However, the interpolation error can be proportional to the
change by the projection, so if the projection is making small changes then one is still safe.
However, if there are large changes from each projection, you should consider only saving at
stopping/projection times. To do this, set `tstops` to the same values as `saveat`. There is a
performance hit by doing so because now the integrator is forced to stop at every saving point,
but this is guerenteed to match the order of the integrator even with the ManifoldProjection.
Note that the `ManifoldProjection` callback modifies the endpoints of the integration
intervals and thus breaks assumptions of internal interpolations. Because of this, the
values for given by saveat will not be order-matching. However, the interpolation error can
be proportional to the change by the projection, so if the projection is making small
changes then one is still safe. However, if there are large changes from each projection,
you should consider only saving at stopping/projection times. To do this, set `tstops` to
the same values as `saveat`. There is a performance hit by doing so because now the
integrator is forced to stop at every saving point, but this is guerenteed to match the
order of the integrator even with the ManifoldProjection.

## References

Ernst Hairer, Christian Lubich, Gerhard Wanner. Geometric Numerical Integration:
Structure-Preserving Algorithms for Ordinary Differential Equations. Berlin ;
New York :Springer, 2002.
"""
mutable struct ManifoldProjection{autonomous, F, NL, O}
mutable struct ManifoldProjection{iip, nlls, autonomous, F, NL, R, K}
g::F
nl_rhs::Any
nlcache::Any
nlsolve::NL
nlopts::O
resid_prototype::R
kwargs::K

function ManifoldProjection{autonomous}(g, nlsolve, nlopts) where {autonomous}
function ManifoldProjection{iip, nlls, autonomous}(
g, nlsolve, resid_prototype, kwargs) where {iip, nlls, autonomous}
# replace residual function if it is time-dependent
# since NLsolve only accepts functions with two arguments
_g = NonAutonomousFunction{typeof(g), autonomous}(g, 0, 0)
new{autonomous, typeof(_g), typeof(nlsolve), typeof(nlopts)}(_g, _g, nlsolve,
nlopts)
_g = NonAutonomousFunction{iip, typeof(g), autonomous}(g, 0)
return new{iip, nlls, autonomous, typeof(_g), typeof(nlsolve),
typeof(resid_prototype), typeof(kwargs)}(
_g, nothing, nlsolve, resid_prototype, kwargs)
end
end

# Now make `affect!` for this:
function (p::ManifoldProjection{autonomous, NL})(integrator) where {autonomous, NL}
function (p::ManifoldProjection{
iip, nlls, autonomous, NL})(integrator) where {iip, nlls,
autonomous, NL}
# update current time if residual function is time-dependent
if !autonomous
p.g.t = integrator.t
end
p.g.p = integrator.p
!autonomous && (p.g.t = integrator.t)

integrator.u .= p.nlsolve(p.nl_rhs, integrator.u; p.nlopts...)
end
# solve the nonlinear problem
reinit!(p.nlcache, integrator.u; p = integrator.p)
sol = solve!(p.nlcache)

function Manifold_initialize(cb, u::Number, t, integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, [u])
u_modified!(integrator, false)
if !SciMLBase.successful_retcode(sol)
SciMLBase.terminate!(integrator, sol.retcode)
return
end

copyto!(integrator.u, sol.u)

Check warning on line 112 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L112

Added line #L112 was not covered by tests
end

function Manifold_initialize(cb, u, t, integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, u)
return Manifold_initialize(cb.affect!, u, t, integrator)
end
function Manifold_initialize(
affect!::ManifoldProjection{iip, nlls}, u, t, integrator) where {iip, nlls}
nlfunc = NonlinearFunction{iip}(affect!.g; affect!.resid_prototype)
nlprob = if nlls
NonlinearLeastSquaresProblem(nlfunc, u, integrator.p)
else
NonlinearProblem(nlfunc, u, integrator.p)

Check warning on line 124 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L124

Added line #L124 was not covered by tests
end
affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.kwargs...)
u_modified!(integrator, false)
end

function ManifoldProjection(g; nlsolve = NLSOLVEJL_SETUP(), save = true,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict{Symbol, Any}())
affect! = ManifoldProjection{autonomous}(g, nlsolve, nlopts)
# Since this is applied to every point, we can reasonably assume that the solution is close
# to the initial guess, so we would want to use NewtonRaphson / RobustMultiNewton instead of
# the default one.
function ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)
# `nothing` is a valid solver, so this need to be `missing`
_nlls = SciMLBase._unwrap_val(nlls)
_nlsolve = nlsolve === missing ? (_nlls ? GaussNewton() : NewtonRaphson()) : nlsolve
iip = SciMLBase._unwrap_val(isinplace)
if autonomous === nothing
if iip
autonomous = maximum(SciMLBase.numargs(g)) == 3
else
autonomous = maximum(SciMLBase.numargs(g)) == 2

Check warning on line 144 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L144

Added line #L144 was not covered by tests
end
end
affect! = ManifoldProjection{iip, _nlls, SciMLBase._unwrap_val(autonomous)}(
g, _nlsolve, resid_prototype, kwargs)
condition = (u, t, integrator) -> true
save_positions = (false, save)
DiscreteCallback(condition, affect!;
initialize = Manifold_initialize,
save_positions = save_positions)
return DiscreteCallback(condition, affect!; initialize = Manifold_initialize,
save_positions = (false, save))
end

export ManifoldProjection
2 changes: 1 addition & 1 deletion src/saving.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ function linearize_period(t₀, t₁, u₀, u₁, integ, ilsc, caches, u_mask,
# Sanity check that we don't accidentally infinitely recurse
if t₁ - t₀ < dtmin
@debug("Linearization failure",
t₁, t₀, string(u₀), string(u₁), string(u_mask),dtmin)
t₁, t₀, string(u₀), string(u₁), string(u_mask), dtmin)
throw(ArgumentError("Linearization failed, fell below linearization subdivision threshold"))
end

Expand Down
Loading
Loading