Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add integrand interface #497

Merged
merged 28 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
57ba954
add integrand interface
lxvm Sep 16, 2023
21b7895
add InplaceBatchIntegrand
lxvm Sep 16, 2023
6a2038a
format and include
lxvm Sep 17, 2023
3f77759
make the IntegralFunctions
lxvm Sep 19, 2023
f56d654
canonicalize
ChrisRackauckas Sep 19, 2023
e3ee453
Remove error checking on function definition of batch integral
ChrisRackauckas Sep 19, 2023
318e79f
add error test on incorrect integral function dispatches
ChrisRackauckas Sep 19, 2023
783b88e
argument amounts testing
ChrisRackauckas Sep 19, 2023
fcc7edb
some better utils checks
ChrisRackauckas Sep 19, 2023
5a37040
apply format
lxvm Sep 19, 2023
b02e470
fix integralfunction iip
lxvm Sep 19, 2023
95bdb1d
rename integrand_prototype to integral_prototype
lxvm Sep 19, 2023
5675e6f
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
774e4be
fix typos
ChrisRackauckas Sep 21, 2023
bbe691b
revert naming to integrand_prototype
lxvm Sep 21, 2023
c0f4062
wrap integrand with IntegralFunction in IntegralProblem
lxvm Sep 21, 2023
740576a
make integral functions callable
lxvm Sep 21, 2023
3f7d1fb
simplify IntegralProblem definition
lxvm Sep 21, 2023
0deeefb
update docstrings
lxvm Sep 21, 2023
e27965d
apply format
lxvm Sep 21, 2023
5be7d7a
remove output_prototype
lxvm Sep 21, 2023
8ebfe42
add deprecation method
lxvm Sep 21, 2023
e6a0547
Update src/problems/basic_problems.jl
ChrisRackauckas Sep 21, 2023
a3a09d4
Merge branch 'master' into integrands
ChrisRackauckas Sep 21, 2023
619ac07
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
83f933d
Update function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
7740dd4
fix default batch and dispatch
lxvm Sep 21, 2023
a6fd63a
Change version just to run tests
ChrisRackauckas Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <:
"""
$(TYPEDEF)

Base for types defining integrand functions.
"""
abstract type AbstractIntegralFunction{iip} <:
AbstractSciMLFunction{iip} end

"""
$(TYPEDEF)

Base for types defining optimization functions.
"""
abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end
Expand Down Expand Up @@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
BVPFunction{iip, specialize},
IntegralFunction{iip, specialize},
BatchIntegralFunction{iip, specialize}}) where {iip,
specialize}
specialize
end
Expand Down Expand Up @@ -787,7 +797,8 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction,
IntegralFunction, BatchIntegralFunction

export OptimizationFunction

Expand Down
88 changes: 53 additions & 35 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,26 +335,16 @@ which are `Number`s or `AbstractVector`s with the same geometry as `u`.
### Constructors

```
IntegralProblem{iip}(f,lb,ub,p=NullParameters();
nout=1, batch = 0, kwargs...)
IntegralProblem(f,domain,p=NullParameters(); kwargs...)
IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...)
```

- f: the integrand, callable function `y = f(u,p)` for out-of-place or `f(y,u,p)` for in-place.
- f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an
`IntegralFunction` or `BatchIntegralFunction` for inplace and batching optimizations.
- domain: an object representing an integration domain, i.e. the tuple `(lb, ub)`.
- lb: Either a number or vector of lower bounds.
- ub: Either a number or vector of upper bounds.
- p: The parameters associated with the problem.
- nout: The output size of the function f. Defaults to 1, i.e., a scalar valued function.
If `nout > 1` f is a vector valued function .
- batch: The preferred number of points to batch. This allows user-side parallelization
of the integrand. If `batch == 0` no batching is performed.
If `batch > 0` both `u` and `y` get an additional dimension added to it.
This means that:
if `f` is a multi variable function each `u[:,i]` is a different point to evaluate `f` at,
if `f` is a single variable function each `u[i]` is a different point to evaluate `f` at,
if `f` is a vector valued function each `y[:,i]` is the evaluation of `f` at a different point,
if `f` is a scalar valued function `y[i]` is the evaluation of `f` at a different point.
Note that batch is a suggestion for the number of points,
and it is not necessarily true that batch is the same as batchsize in all algorithms.
- kwargs: Keyword arguments copied to the solvers.

Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at
Expand All @@ -364,30 +354,58 @@ compile time whether the integrator function is in-place.

The fields match the names of the constructor arguments.
"""
struct IntegralProblem{isinplace, P, F, B, K} <: AbstractIntegralProblem{isinplace}
struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinplace}
f::F
lb::B
ub::B
nout::Int
domain::T
p::P
batch::Int
kwargs::K
@add_kwonly function IntegralProblem{iip}(f, lb, ub, p = NullParameters();
nout = 1,
batch = 0, kwargs...) where {iip}
@assert typeof(lb)==typeof(ub) "Type of lower and upper bound must match"
@add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain,
p = NullParameters();
kwargs...) where {iip}
warn_paramtype(p)
new{iip, typeof(p), typeof(f), typeof(lb), typeof(kwargs)}(f,
lb, ub, nout, p,
batch, kwargs)
Comment on lines -381 to -382
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a deprecation path where if nout and batch are supplied we throw a warning and just define the prototypes as Arrays appropriately sized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks for bringing this up. While implementing it I realized that the BatchIntegralFunction can let the user pass a two-argument out-of-place form and a 3-argument in-place form, which was allowed before and would match IntegralFunction. Then I'll remove the output_prototype field, since we can still query the output type of an out-of-place BatchIntegralFunction by calling the function on an empty vector of input points. The details of allocating an output_prototype may differ across libraries, but we have a mechanism to get the output type for both iip and oop forms, so this buffer can be correctly allocated by the solver, and solves our previous issue.

new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f,
domain, p, kwargs)
end
end

TruncatedStacktraces.@truncate_stacktrace IntegralProblem 1 4

function IntegralProblem(f, lb, ub, args...;
function IntegralProblem(f::AbstractIntegralFunction,
domain,
p = NullParameters();
kwargs...)
IntegralProblem{isinplace(f, 3)}(f, lb, ub, args...; kwargs...)
IntegralProblem{isinplace(f)}(f, domain, p; kwargs...)
end

function IntegralProblem(f::AbstractIntegralFunction,
lb::B,
ub::B,
p = NullParameters();
kwargs...) where {B}
IntegralProblem(f, (lb, ub), p; kwargs...)
end

# deprecation methods, which assume integrands return Float64 values (same as C libraries)
function IntegralProblem{iip}(f, args...; nout = 1, batch = 0, kwargs...) where {iip}
@warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s"
g = if iip
output_prototype = Vector{Float64}(undef, nout)
if batch == 0
IntegralFunction(f, output_prototype)
else
BatchIntegralFunction(f, output_prototype, max_batch=batch)
end
else
if batch == 0
IntegralFunction(f)
else
BatchIntegralFunction(f, max_batch=batch)
end
end
IntegralProblem(g, args...; kwargs...)
end
function IntegralProblem(f, args...; kwargs...)
IntegralProblem{isinplace(f, 3)}(f, args...; kwargs...)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
end

struct QuadratureProblem end
Expand All @@ -405,8 +423,8 @@ Sampled integral problems are defined as:
```math
\sum_i w_i y_i
```
where `y_i` are sampled values of the integrand, and `w_i` are weights
assigned by a quadrature rule, which depend on sampling points `x`.
where `y_i` are sampled values of the integrand, and `w_i` are weights
assigned by a quadrature rule, which depend on sampling points `x`.

## Problem Type

Expand All @@ -415,10 +433,10 @@ assigned by a quadrature rule, which depend on sampling points `x`.
```
SampledIntegralProblem(y::AbstractArray, x::AbstractVector; dim=ndims(y), kwargs...)
```
- y: The sampled integrand, must be a subtype of `AbstractArray`.
It is assumed that the values of `y` along dimension `dim`
- y: The sampled integrand, must be a subtype of `AbstractArray`.
It is assumed that the values of `y` along dimension `dim`
correspond to the integrand evaluated at sampling points `x`
- x: Sampling points, must be a subtype of `AbstractVector`.
- x: Sampling points, must be a subtype of `AbstractVector`.
- dim: Dimension along which to integrate. Defaults to the last dimension of `y`.
- kwargs: Keyword arguments copied to the solvers.

Expand All @@ -434,7 +452,7 @@ struct SampledIntegralProblem{Y, X, D, K} <: AbstractIntegralProblem{false}
@add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector;
dim = ndims(y),
kwargs...)
@assert dim <= ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
@assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
@assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension."
@assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension."
new{typeof(y), typeof(x), Val{dim}, typeof(kwargs)}(y, x, Val(dim), kwargs)
Expand Down
Loading