Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revisiting Boundary Value Problems #477

Merged
merged 16 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "1.98.1"
version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
113 changes: 64 additions & 49 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ $(TYPEDEF)
"""
struct StandardBVProblem end

"""
$(TYPEDEF)
"""
struct TwoPointBVProblem end

@doc doc"""

Defines an BVP problem.
Expand All @@ -17,7 +22,7 @@ condition ``u_0`` which define an ODE:
\frac{du}{dt} = f(u,p,t)
```

along with an implicit function `bc!` which defines the residual equation, where
along with an implicit function `bc` which defines the residual equation, where

```math
bc(u,p,t) = 0
Expand All @@ -36,22 +41,27 @@ u(t_f) = b
### Constructors

```julia
TwoPointBVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
```

or if we have an initial guess function `initialGuess(t)` for the given BVP,
we can pass the initial guess to the problem constructors:

```julia
TwoPointBVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
```

For any BVP problem type, `bc!` is the inplace function:
For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be
out-of-place.

If the bvp is a StandardBVProblem (also known as a Multi-Point BV Problem) it must define
either of the following functions

```julia
bc!(residual, u, p, t)
residual = bc(u, p, t)
```

where `residual` computed from the current `u`. `u` is an array of solution values
Expand All @@ -61,6 +71,16 @@ time points, and for shooting type methods `u=sol` the ODE solution.
Note that all features of the `ODESolution` are present in this form.
In both cases, the size of the residual matches the size of the initial condition.

If the bvp is a TwoPointBVProblem it must define either of the following functions

```julia
bc!((resid_a, resid_b), (u_a, u_b), p)
resid_a, resid_b = bc((u_a, u_b), p)
```

where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are
the solution values at the two endpoints, and `p` are the parameters.

Parameters are optional, and if not given, then a `NullParameters()` singleton
will be used which will throw nice errors if you try to index non-existent
parameters. Any extra keyword arguments are passed on to the solvers. For example,
Expand Down Expand Up @@ -88,16 +108,20 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
problem_type::PT
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan,
p = NullParameters(),
problem_type = StandardBVProblem();
kwargs...) where {iip}
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan,
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p),
typeof(f), typeof(f.bc),
typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p,
problem_type, kwargs)
prob_type = TP ? TwoPointBVProblem() : StandardBVProblem()
# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
else
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type,
kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
Expand All @@ -107,52 +131,43 @@ end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
end

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...)
function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
end

"""
$(TYPEDEF)
"""
struct TwoPointBVPFunction{bF}
bc::bF
# This is mostly a fake stuct and isn't used anywhere
# But we need it for function calls like TwoPointBVProblem{iip}(...) = ...
struct TwoPointBVPFunction{iip} end

@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true)
@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip}
return BVPFunction{iip}(args...; kwargs..., twopoint=true)
end
TwoPointBVPFunction(; bc = error("No argument bc")) = TwoPointBVPFunction(bc)
(f::TwoPointBVPFunction)(residual, ua, ub, p) = f.bc(residual, ua, ub, p)
(f::TwoPointBVPFunction)(residual, u, p) = f.bc(residual, u[1], u[end], p)

"""
$(TYPEDEF)
"""
struct TwoPointBVProblem{iip} end
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
TwoPointBVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...)
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
kwargs...)
end
function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();
kwargs...) where {iip}
BVProblem{iip}(f, TwoPointBVPFunction(bc), u0, tspan, p; kwargs...)
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
end

# Allow previous timeseries solution
function TwoPointBVProblem(f::AbstractODEFunction,
bc,
sol::T,
tspan::Tuple,
p = NullParameters()) where {T <: AbstractTimeseriesSolution}
TwoPointBVProblem(f, bc, sol.u, tspan, p)
function TwoPointBVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple,
p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution}
return TwoPointBVProblem(f, bc, sol.u, tspan, p; kwargs...)
end
# Allow initial guess function for the initial guess
function TwoPointBVProblem(f::AbstractODEFunction,
bc,
initialGuess,
tspan::AbstractVector,
p = NullParameters();
kwargs...)
function TwoPointBVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
p = NullParameters(); kwargs...)
u0 = [initialGuess(i) for i in tspan]
TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p)
return TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p; kwargs...)
end
104 changes: 45 additions & 59 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2124,8 +2124,7 @@ TruncatedStacktraces.@truncate_stacktrace OptimizationFunction 1 2
"""
$(TYPEDEF)
"""
abstract type AbstractBVPFunction{iip} <:
AbstractDiffEqFunction{iip} end
abstract type AbstractBVPFunction{iip, twopoint} <: AbstractDiffEqFunction{iip} end

@doc doc"""
BVPFunction{iip,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip,specialize}
Expand Down Expand Up @@ -2230,11 +2229,9 @@ For more details on this argument, see the ODEFunction documentation.

The fields of the BVPFunction type directly match the names of the inputs.
"""
struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP,
BCJP, SP, TW, TWt,
TPJ,
S, S2, S3, O, TCV, BCTCV,
SYS} <: AbstractBVPFunction{iip}
struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
JP, BCJP, BCRP, SP, TW, TWt, TPJ, S, S2, S3, O, TCV, BCTCV,
SYS} <: AbstractBVPFunction{iip, twopoint}
f::F
bc::BF
mass_matrix::TMM
Expand All @@ -2246,6 +2243,7 @@ struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP,
vjp::VJP
jac_prototype::JP
bcjac_prototype::BCJP
bcresid_prototype::BCRP
sparsity::SP
Wfact::TW
Wfact_t::TWt
Expand Down Expand Up @@ -3648,9 +3646,8 @@ function NonlinearFunction{iip, specialize}(f;
nothing,
sys = __has_sys(f) ? f.sys : nothing,
resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing) where {
iip,
specialize,
}
iip, specialize}

if mass_matrix === I && typeof(f) <: Tuple
mass_matrix = ((I for i in 1:length(f))...,)
end
Expand Down Expand Up @@ -3814,35 +3811,28 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
cons_expr, sys)
end

function BVPFunction{iip, specialize}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
I,
function BVPFunction{iip, specialize, twopoint}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
bcjac = __has_jac(bc) ? bc.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ?
f.jac_prototype :
nothing,
bcjac_prototype = __has_jac_prototype(bc) ?
bc.jac_prototype :
nothing,
sparsity = __has_sparsity(f) ? f.sparsity :
jac_prototype,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
bcresid_prototype = nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = __has_syms(f) ? f.syms : nothing,
indepsym = __has_indepsym(f) ? f.indepsym : nothing,
paramsyms = __has_paramsyms(f) ? f.paramsyms :
nothing,
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing,
observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize}
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint}
if mass_matrix === I && typeof(f) <: Tuple
mass_matrix = ((I for i in 1:length(f))...,)
end
Expand Down Expand Up @@ -3882,7 +3872,7 @@ function BVPFunction{iip, specialize}(f, bc;
_bccolorvec = bccolorvec
end

bciip = isinplace(bc, 4, "bc", iip)
bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip)
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
Expand All @@ -3892,66 +3882,62 @@ function BVPFunction{iip, specialize}(f, bc;
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip

nonconforming = (jaciip,
tgradiip,
jvpiip,
vjpiip,
Wfactiip,
Wfact_tiip,
nonconforming = (bciip, jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
paramjaciip) .!= iip
bc_nonconforming = bcjaciip .!= bciip
if any(nonconforming)
nonconforming = findall(nonconforming)
functions = ["jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
functions = ["bc", "jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
"paramjac"][nonconforming]
throw(NonconformingFunctionsError(functions))
end

if twopoint
if iip && (bcresid_prototype === nothing || length(bcresid_prototype) != 2)
error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction")
end
if bcresid_prototype !== nothing && length(bcresid_prototype) == 2
bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2])
end
end

if any(bc_nonconforming)
bc_nonconforming = findall(bc_nonconforming)
functions = ["bcjac"][bc_nonconforming]
throw(NonconformingFunctionsError(functions))
end

if specialize === NoSpecialize
BVPFunction{iip, specialize, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any, Any,
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, typeof(syms), typeof(indepsym), typeof(paramsyms),
Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix,
analytic,
tgrad,
jac, bcjac, jvp, vjp,
jac_prototype,
bcjac_prototype,
sparsity, Wfact,
Wfact_t,
paramjac, syms,
indepsym, paramsyms,
observed,
analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype,
bcjac_prototype, bcresid_prototype,
sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed,
_colorvec, _bccolorvec, sys)
else
BVPFunction{iip, specialize, typeof(f), typeof(bc), typeof(mass_matrix),
typeof(analytic),
typeof(tgrad),
typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(bcjac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms),
typeof(observed),
BVPFunction{iip, specialize, twopoint, typeof(f), typeof(bc), typeof(mass_matrix),
typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp),
typeof(vjp), typeof(jac_prototype),
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms),
typeof(indepsym), typeof(paramsyms), typeof(observed),
typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic,
tgrad, jac, bcjac, jvp, vjp,
jac_prototype, bcjac_prototype, sparsity,
jac_prototype, bcjac_prototype, bcresid_prototype, sparsity,
Wfact, Wfact_t, paramjac,
syms, indepsym, paramsyms, observed,
_colorvec, _bccolorvec, sys)
end
end

function BVPFunction{iip}(f, bc; kwargs...) where {iip}
BVPFunction{iip, FullSpecialize}(f, bc; kwargs...)
function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...)
end
BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f
function BVPFunction(f, bc; kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize}(f, bc; kwargs...)
function BVPFunction(f, bc; twopoint::Bool=false, kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

Expand Down
2 changes: 1 addition & 1 deletion test/downstream/ensemble_bvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ tspan = (0.0, pi / 2)
p = [rand()]
bvp = BVProblem(ode!, bc!, initial_guess, tspan, p)
ensemble_prob = EnsembleProblem(bvp, prob_func = prob_func)
sim = solve(ensemble_prob, GeneralMIRK4(), trajectories = 10, dt = 0.1)
sim = solve(ensemble_prob, MIRK4(), trajectories = 10, dt = 0.1)
Loading