Skip to content

Commit

Permalink
Finish ManifoldProjection
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 20, 2024
1 parent 2f56a5f commit c16671b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 73 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "3.0.0"
version = "2.38.0" # Make it 3.0.0 before releasing. This is needed to allow resolver to work for test dependencies

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down Expand Up @@ -49,6 +49,7 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -60,4 +61,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote"]
test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote", "NonlinearSolve"]
108 changes: 47 additions & 61 deletions src/manifold.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,28 @@
# Base.@pure function determine_chunksize(u, alg::DiffEqBase.DEAlgorithm)
# determine_chunksize(u, get_chunksize(alg))
# end
# Base.@pure function determine_chunksize(u, CS)
# if CS != 0
# return CS
# else
# return ForwardDiff.pickchunksize(length(u))
# end
# end

# struct NLSOLVEJL_SETUP{CS, AD} end
# Base.@pure function NLSOLVEJL_SETUP(; chunk_size = 0, autodiff = true)
# NLSOLVEJL_SETUP{chunk_size, autodiff}()
# end
# (::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res = NLsolve.nlsolve(f, u0; kwargs...); res.zero)
# function (p::NLSOLVEJL_SETUP{CS, AD})(::Type{Val{:init}}, f, u0_prototype) where {CS, AD}
# AD ? autodiff = :forward : autodiff = :central
# OnceDifferentiable(f, u0_prototype, u0_prototype, autodiff,
# ForwardDiff.Chunk(determine_chunksize(u0_prototype, CS)))
# end

# wrapper for non-autonomous functions
mutable struct NonAutonomousFunction{iip, F, autonomous}
f::F
t::Any
end

(f::NonAutonomousFunction{true, F, true})(res, u, p) where {F} = f.f(res, u, p)

Check warning on line 7 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L7

Added line #L7 was not covered by tests
(f::NonAutonomousFunction{true, F, false})(res, u, p) where {F} = f.f(res, u, p, p.t)
(f::NonAutonomousFunction{true, F, false})(res, u, p) where {F} = f.f(res, u, p, f.t)

(f::NonAutonomousFunction{false, F, true})(u, p) where {F} = f.f(u, p)
(f::NonAutonomousFunction{false, F, false})(u, p) where {F} = f.f(u, p, p.t)
(f::NonAutonomousFunction{false, F, false})(u, p) where {F} = f.f(u, p, f.t)

Check warning on line 11 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L10-L11

Added lines #L10 - L11 were not covered by tests

SciMLBase.isinplace(::NonAutonomousFunction{iip}) where {iip} = iip

Check warning on line 13 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L13

Added line #L13 was not covered by tests

"""
ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, nlopts = (;),
resid_prototype = nothing)
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)
In many cases, you may want to declare a manifold on which a solution lives.
Mathematically, a manifold `M` is defined by a function `g` as the set of points
where `g(u) = 0`. An embedded manifold can be a lower dimensional object which
constrains the solution. For example, `g(u) = E(u) - C` where `E` is the energy
of the system in state `u`, meaning that the energy must be constant (energy
preservation). Thus by defining the manifold the solution should live on, you
can retain desired properties of the solution.
In many cases, you may want to declare a manifold on which a solution lives. Mathematically,
a manifold `M` is defined by a function `g` as the set of points where `g(u) = 0`. An
embedded manifold can be a lower dimensional object which constrains the solution. For
example, `g(u) = E(u) - C` where `E` is the energy of the system in state `u`, meaning that
the energy must be constant (energy preservation). Thus by defining the manifold the
solution should live on, you can retain desired properties of the solution.
`ManifoldProjection` projects the solution of the differential equation to the chosen
manifold `g`, conserving a property while conserving the order. It is a consequence of
Expand All @@ -57,67 +34,76 @@ properties.
## Arguments
- `g`: The residual function for the manifold. This is an inplace function of form
`g(resid, u)` or `g(resid, u, p, t)` which writes to the residual `resid` the
difference from the manifold components. Here, it is assumed that `resid` is of
the same shape as `u`.
- `g`: The residual function for the manifold.
* This is an inplace function of form `g(resid, u, p)` or `g(resid, u, p, t)` which
writes to the residual `resid` the difference from the manifold components. Here, it
is assumed that `resid` is of the same shape as `u`.
* If `isinplace = Val(false)`, then `g` should be a function of the form `g(u, p)` or
`g(u, p, t)` which returns the residual.
## Keyword Arguments
- `nlsolve`: A nonlinear solver as defined in the
[NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/)
- `save`: Whether to do the standard saving (applied after the callback)
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u)`.
- `nlopts`: Optional arguments to nonlinear solver which can be any of the
[NonlinearSolve.jl keywords](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/).
- `nlls`: If the problem is a nonlinear least squares problem. `nlls = Val(false)`
generates a `NonlinearProblem` which is typically faster than
`NonlinearLeastSquaresProblem`, but is only applicable if the residual size is same as
the state size.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or
`g(u, p)`. Specify it as `Val(::Bool)` to ensure this function call is type stable.
- `residual_prototype`: This needs to be specified if `nlls = Val(true)` and
`inplace = Val(true)` are specified together, else it is taken to be same as `u`.
- `kwargs`: All other keyword arguments are passed to
[NonlinearSolve.jl](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/).
### Saveat Warning
Note that the `ManifoldProjection` callback modifies the endpoints of the integration intervals
and thus breaks assumptions of internal interpolations. Because of this, the values for given by
saveat will not be order-matching. However, the interpolation error can be proportional to the
change by the projection, so if the projection is making small changes then one is still safe.
However, if there are large changes from each projection, you should consider only saving at
stopping/projection times. To do this, set `tstops` to the same values as `saveat`. There is a
performance hit by doing so because now the integrator is forced to stop at every saving point,
but this is guerenteed to match the order of the integrator even with the ManifoldProjection.
Note that the `ManifoldProjection` callback modifies the endpoints of the integration
intervals and thus breaks assumptions of internal interpolations. Because of this, the
values for given by saveat will not be order-matching. However, the interpolation error can
be proportional to the change by the projection, so if the projection is making small
changes then one is still safe. However, if there are large changes from each projection,
you should consider only saving at stopping/projection times. To do this, set `tstops` to
the same values as `saveat`. There is a performance hit by doing so because now the
integrator is forced to stop at every saving point, but this is guerenteed to match the
order of the integrator even with the ManifoldProjection.
## References
Ernst Hairer, Christian Lubich, Gerhard Wanner. Geometric Numerical Integration:
Structure-Preserving Algorithms for Ordinary Differential Equations. Berlin ;
New York :Springer, 2002.
"""
mutable struct ManifoldProjection{iip, nlls, autonomous, F, NL, NO, R}
mutable struct ManifoldProjection{iip, nlls, autonomous, F, NL, R, K}
g::F
nlcache::Any
nlsolve::NL
nlopts::NO
resid_prototype::R
kwargs::K

function ManifoldProjection{iip, nlls, autonomous}(
g, nlsolve, nlopts, resid_prototype) where {iip, nlls, autonomous}
g, nlsolve, resid_prototype, kwargs) where {iip, nlls, autonomous}
# replace residual function if it is time-dependent
_g = NonAutonomousFunction{iip, typeof(g), autonomous}(g, 0)
return new{iip, nlls, autonomous, typeof(_g), typeof(nlsolve),
typeof(nlopts), typeof(resid_prototype)}(
_g, nothing, nlsolve, nlopts, resid_prototype)
typeof(resid_prototype), typeof(kwargs)}(
_g, nothing, nlsolve, resid_prototype, kwargs)
end
end

# Now make `affect!` for this:
function (p::ManifoldProjection{iip, nlls, autonomous, NL})(integrator) where {iip, nlls,
autonomous, NL}
# update current time if residual function is time-dependent
if !autonomous
p.g.t = integrator.t
end
!autonomous && (p.g.t = integrator.t)

# solve the nonlinear problem
reinit!(p.nlcache, integrator.u; p = integrator.p)
sol = solve!(p.nlcache)

if !SciMLBase.successful_retcode(sol) && (nlls && sol.retcode != ReturnCode.Stalled)
if !SciMLBase.successful_retcode(sol)
SciMLBase.terminate!(integrator, sol.retcode)
return
end
Expand All @@ -136,16 +122,16 @@ function Manifold_initialize(
else
NonlinearProblem(nlfunc, u, integrator.p)

Check warning on line 123 in src/manifold.jl

View check run for this annotation

Codecov / codecov/patch

src/manifold.jl#L123

Added line #L123 was not covered by tests
end
affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.nlopts...)
affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.kwargs...)
u_modified!(integrator, false)
end

# Since this is applied to every point, we can reasonably assume that the solution is close
# to the initial guess, so we would want to use NewtonRaphson / RobustMultiNewton instead of
# the default one.
function ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, nlopts = (;),
resid_prototype = nothing)
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)
# `nothing` is a valid solver, so this need to be `missing`
_nlls = SciMLBase._unwrap_val(nlls)
_nlsolve = nlsolve === missing ? (_nlls ? GaussNewton() : NewtonRaphson()) : nlsolve
Expand All @@ -158,7 +144,7 @@ function ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
end
end
affect! = ManifoldProjection{iip, _nlls, autonomous}(
g, _nlsolve, nlopts, resid_prototype)
g, _nlsolve, resid_prototype, kwargs)
condition = (u, t, integrator) -> true
return DiscreteCallback(condition, affect!; initialize = Manifold_initialize,
save_positions = (false, save))
Expand Down
37 changes: 27 additions & 10 deletions test/manifold_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, Test, DiffEqBase, DiffEqCallbacks, RecursiveArrayTools
using OrdinaryDiffEq, Test, DiffEqBase, DiffEqCallbacks, RecursiveArrayTools, NonlinearSolve

u0 = ones(2, 2)
f = function (du, u, p, t)
Expand All @@ -23,29 +23,29 @@ sol = solve(prob, Vern7())
@test !(sol[end][1]^2 + sol[end][2]^2 2)

# autodiff=true
@inferred ManifoldProjection(g; autonomous = Val(true), resid_prototype = zeros(2))
@inferred ManifoldProjection(g; autonomous = Val(false), resid_prototype = zeros(2))
cb = ManifoldProjection(g; resid_prototype = zeros(2))
@test isautonomous(cb.affect!)
solve(prob, Vern7(), callback = cb)
@time sol = solve(prob, Vern7(), callback = cb)
@test sol[end][1]^2 + sol[end][2]^2 2

cb_t = ManifoldProjection(g_t)
cb_t = ManifoldProjection(g_t; resid_prototype = zeros(2))
@test !isautonomous(cb_t.affect!)
solve(prob, Vern7(), callback = cb_t)
@time sol_t = solve(prob, Vern7(), callback = cb_t)
@test sol_t.u == sol.u && sol_t.t == sol.t

# autodiff=false
cb_false = ManifoldProjection(g,
nlsolve = DiffEqCallbacks.NLSOLVEJL_SETUP(autodiff = false))
cb_false = ManifoldProjection(
g; nlsolve = GaussNewton(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2))
@test isautonomous(cb_false.affect!)
solve(prob, Vern7(), callback = cb_false)
sol = solve(prob, Vern7(), callback = cb_false)
@test sol[end][1]^2 + sol[end][2]^2 2

cb_t_false = ManifoldProjection(g_t,
nlsolve = DiffEqCallbacks.NLSOLVEJL_SETUP(autodiff = false))
nlsolve = GaussNewton(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2))
@test !isautonomous(cb_t_false.affect!)
solve(prob, Vern7(), callback = cb_t_false)
sol_t = solve(prob, Vern7(), callback = cb_t_false)
Expand All @@ -61,10 +61,27 @@ sol = solve(prob, Vern7(), callback = cb)
sol = solve(prob, Vern7(), callback = cb_t)
@test sol[end][1]^2 + sol[end][2]^2 2

# does not work since Calculus.jl (on which NLsolve.jl depends)
# implements only Jacobians of vectors
sol = solve(prob, Vern7(), callback = cb_false)
sol[end][1]^2 + sol[end][2]^2 2
@test sol[end][1]^2 + sol[end][2]^2 2

sol = solve(prob, Vern7(), callback = cb_t_false)
sol[end][1]^2 + sol[end][2]^2 2
@test sol[end][1]^2 + sol[end][2]^2 2

# Test termination if cannot project to manifold
function g_unsat(resid, u, p)
resid[1] = u[2]^2 + u[1]^2 - 1000
resid[2] = u[2]^2 + u[1]^2 - 20
end

cb_unsat = ManifoldProjection(g_unsat; resid_prototype = zeros(2))
sol = solve(prob, Vern7(), callback = cb_unsat)
@test !SciMLBase.successful_retcode(sol)
@test last(sol.t) != 100.0

# Tests for OOP Manifold Projection
function g_oop(u, p)
return [u[2]^2 + u[1]^2 - 2
u[3]^2 + u[4]^2 - 2]
end

g_t(resid, u, p, t) = g(resid, u, p)

0 comments on commit c16671b

Please sign in to comment.