Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add integrand interface #497

Merged
merged 28 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
57ba954
add integrand interface
lxvm Sep 16, 2023
21b7895
add InplaceBatchIntegrand
lxvm Sep 16, 2023
6a2038a
format and include
lxvm Sep 17, 2023
3f77759
make the IntegralFunctions
lxvm Sep 19, 2023
f56d654
canonicalize
ChrisRackauckas Sep 19, 2023
e3ee453
Remove error checking on function definition of batch integral
ChrisRackauckas Sep 19, 2023
318e79f
add error test on incorrect integral function dispatches
ChrisRackauckas Sep 19, 2023
783b88e
argument amounts testing
ChrisRackauckas Sep 19, 2023
fcc7edb
some better utils checks
ChrisRackauckas Sep 19, 2023
5a37040
apply format
lxvm Sep 19, 2023
b02e470
fix integralfunction iip
lxvm Sep 19, 2023
95bdb1d
rename integrand_prototype to integral_prototype
lxvm Sep 19, 2023
5675e6f
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
774e4be
fix typos
ChrisRackauckas Sep 21, 2023
bbe691b
revert naming to integrand_prototype
lxvm Sep 21, 2023
c0f4062
wrap integrand with IntegralFunction in IntegralProblem
lxvm Sep 21, 2023
740576a
make integral functions callable
lxvm Sep 21, 2023
3f7d1fb
simplify IntegralProblem definition
lxvm Sep 21, 2023
0deeefb
update docstrings
lxvm Sep 21, 2023
e27965d
apply format
lxvm Sep 21, 2023
5be7d7a
remove output_prototype
lxvm Sep 21, 2023
8ebfe42
add deprecation method
lxvm Sep 21, 2023
e6a0547
Update src/problems/basic_problems.jl
ChrisRackauckas Sep 21, 2023
a3a09d4
Merge branch 'master' into integrands
ChrisRackauckas Sep 21, 2023
619ac07
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
83f933d
Update function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
7740dd4
fix default batch and dispatch
lxvm Sep 21, 2023
a6fd63a
Change version just to run tests
ChrisRackauckas Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <:
"""
$(TYPEDEF)

Base for types defining integrand functions.
"""
abstract type AbstractIntegralFunction{iip} <:
AbstractSciMLFunction{iip} end

"""
$(TYPEDEF)

Base for types defining optimization functions.
"""
abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end
Expand Down Expand Up @@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
BVPFunction{iip, specialize},
IntegralFunction{iip, specialize},
BatchIntegralFunction{iip, specialize}}) where {iip,
specialize}
specialize
end
Expand Down Expand Up @@ -787,7 +797,8 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction,
IntegralFunction, BatchIntegralFunction

export OptimizationFunction

Expand Down
256 changes: 255 additions & 1 deletion src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,184 @@ end

TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2

@doc doc"""
IntegralFunction{iip,specialize,F,T,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <: AbstractIntegralFunction{iip}

A representation of an integrand `f` defined by:

```math
f(u, p)
```

and its related functions, such as its Jacobian and gradient with respect to parameters. For
an in-place form of `f` see the `iip` section below for details on in-place or out-of-place
handling.

```julia
IntegralFunction{iip,specialize}(f, [I];
jac = __has_jac(f) ? f.jac : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
analytic = __has_analytic(f) ? f.analytic : nothing,
syms = __has_syms(f) ? f.syms : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
observed = __has_observed(f) ? f.observed : nothing)
```

Note that only `f` is required, and in the case of inplace integrands a mutable container
`I` to store the result of the integral.

The remaining functions are optional and mainly used for accelerating the usage of `f`:
- `jac`: unused
- `paramjac`: unused
- `analytic`: unused
- `syms`: unused
- `jac_prototype`: unused
- `sparsity`: unused
- `colorvec`: unused
- `observed`: unused
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those arguments can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks


Since most arguments are unused, the following constructor provides the essential behavior:

```julia
IntegralFunction(f, [I]; kws..)
```

If `I` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be
out-of-place.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? This should just be auto-determined.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, when passed it's iip

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I or integral_prototype is presumably an array that could have any element type and dimension specific to an inplace f, so the caller is basically declaring the output of f by passing the container. The C libraries assume everything is Vector{Float64}, with length nout passed to IntegralProblem, so that could be a fallback, but I would prefer it as a type assertion in the C wrappers since the user is already taking their time to write something efficient and in-place.


## iip: In-Place vs Out-Of-Place

Out-of-place functions must be of the form ``f(u, p)`` and in-place functions of the form
``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or
arrays), in-place functions must provide a container `I` that is of the right type for the
final result of the integral, and the result is written to this container in-place. When
in-place forms are used, in-place array operations may be used by algorithms to reduce
allocations. If `I` is not provided, `f` is assumed to be out-of-place and quadrature is
performed assuming immutable return types.

## specialize

This field is currently unused

## Fields

The fields of the IntegralFunction type directly match the names of the inputs.
"""
struct IntegralFunction{iip, specialize, F, T, TJ, TPJ, Ta, S, JP, SP, TCV, O} <:
AbstractIntegralFunction{iip}
f::F
I::T
jac::TJ
paramjac::TPJ
analytic::Ta
syms::S
jac_prototype::JP
sparsity::SP
colorvec::TCV
observed::O
end

TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2

@doc doc"""
BatchIntegralFunction{iip,specialize,F,T,Y,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <: AbstractIntegralFunction{iip}

A representation of an integrand `f` that can be evaluated at multiple points simultaneously
using threads, the gpu, or distributed memory defined by:

```math
f(y, u, p)
```

and its related functions, such as its Jacobian and gradient with respect to parameters. For
an in-place form of `f` see the `iip` section below for details on in-place or out-of-place
handling.

``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose
output must be returned in the corresponding entries of ``y``. In general, the integration
algorithm is allowed to vary the number of evaluation points between subsequent calls to `f`

```julia
BatchIntegralFunction{iip,specialize}(f, y, [I];
max_batch = typemax(Int),
jac = __has_jac(f) ? f.jac : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
analytic = __has_analytic(f) ? f.analytic : nothing,
syms = __has_syms(f) ? f.syms : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
observed = __has_observed(f) ? f.observed : nothing)
```

Note that `f` is required and a `resize`-able buffer `y` to store the output, or range of
`f`, and in the case of inplace integrands a mutable container `I` to store the result of
the integral. These buffers can be reused across multiple compatible integrals to reduce
allocations.

The keyword `max_batch` is used to set a soft limit on the number of points to batch at the
same time so that memory usage is controlled.

The remaining functions are optional and mainly used for accelerating the usage of `f`:
- `jac`: unused
- `paramjac`: unused
- `analytic`: unused
- `syms`: unused
- `jac_prototype`: unused
- `sparsity`: unused
- `colorvec`: unused
- `observed`: unused

Since most arguments are unused, the following constructor provides the essential behavior:

```julia
BatchIntegralFunction(f, y, [I]; max_batch=typemax(Int), kws..)
```

If `I` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be
out-of-place.

## iip: In-Place vs Out-Of-Place

Out-of-place and in-place functions are both of the form ``f(y, u, p)``, but differ in the
element type of ``y``. Since `f` is allowed to return any type (e.g. real or complex numbers
or arrays), in-place functions must provide a container `I` that is of the right type for
the final result of the integral, and the result is written to this container in-place. When
`f` is in-place, the output buffer ``y`` is assumed to have a mutable element type, and the
last dimension of ``y`` should correspond to the batch index. For example, ``y`` would have
to be an `ElasticArray` or a `VectorOfSimilarArrays` of an `ElasticArray`. When in-place
forms are used, in-place array operations may be used by algorithms to reduce allocations.
If `I` is not provided, `f` is assumed to be out-of-place and quadrature is performed
assuming ``y`` is an `AbstractVector` with an immutable element type.

## specialize

This field is currently unused

## Fields

The fields of the BatchIntegralFunction type directly match the names of the inputs.
"""
struct BatchIntegralFunction{iip, specialize, F, Y, T, TJ, TPJ, Ta, S, JP, SP, TCV, O} <:
AbstractIntegralFunction{iip}
f::F
y::Y
I::T
max_batch::Int
jac::TJ
paramjac::TPJ
analytic::Ta
syms::S
jac_prototype::JP
sparsity::SP
colorvec::TCV
observed::O
end

TruncatedStacktraces.@truncate_stacktrace BatchIntegralFunction 1 2

######### Backwards Compatibility Overloads

(f::ODEFunction)(args...) = f.f(args...)
Expand Down Expand Up @@ -3955,6 +4133,80 @@ function BVPFunction(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

function IntegralFunction{iip, specialize}(f, I;
jac = __has_jac(f) ? f.jac : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
analytic = __has_analytic(f) ? f.analytic : nothing,
syms = __has_syms(f) ? f.syms : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
observed = __has_observed(f) ? f.observed : nothing) where {iip, specialize}
IntegralFunction{
iip,
specialize,
typeof(f),
typeof(I),
typeof(jac),
typeof(paramjac),
typeof(analytic),
typeof(syms),
typeof(jac_prototype),
typeof(sparsity),
typeof(colorvec),
typeof(observed),
}(f,
I,
jac,
paramjac,
analytic,
syms,
jac_prototype,
sparsity,
colorvec,
observed)
end

function IntegralFunction{iip}(f, I; kws...) where {iip}
return IntegralFunction{iip, FullSpecialize}(f, I; kws...)
end
IntegralFunction(f; kws...) = IntegralFunction{false}(f, nothing; kws...)
IntegralFunction(f, I; kws...) = IntegralFunction{true}(f, I; kws...)

function BatchIntegralFunction{iip, specialize}(f, y, I;
max_batch::Integer = typemax(Int),
jac = __has_jac(f) ? f.jac : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
analytic = __has_analytic(f) ? f.analytic : nothing,
syms = __has_syms(f) ? f.syms : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
observed = __has_observed(f) ? f.observed : nothing) where {iip, specialize}
BatchIntegralFunction{
iip,
specialize,
typeof(f),
typeof(y),
typeof(I),
typeof(jac),
typeof(paramjac),
typeof(analytic),
typeof(syms),
typeof(jac_prototype),
typeof(sparsity),
typeof(colorvec),
typeof(observed),
}(f, y, I, max_batch, jac, paramjac, analytic, syms, jac_prototype, sparsity, colorvec,
observed)
end

function BatchIntegralFunction{iip}(f, y, I; kws...) where {iip}
return BatchIntegralFunction{iip, FullSpecialize}(f, y, I; kws...)
end
BatchIntegralFunction(f, y; kws...) = BatchIntegralFunction{false}(f, y, nothing; kws...)
BatchIntegralFunction(f, y, I; kws...) = BatchIntegralFunction{true}(f, y, I; kws...)

########## Existence Functions

# Check that field/property exists (may be nothing)
Expand Down Expand Up @@ -4064,7 +4316,9 @@ for S in [:ODEFunction
:NonlinearFunction
:IntervalNonlinearFunction
:IncrementingODEFunction
:BVPFunction]
:BVPFunction
:IntegralFunction
:BatchIntegralFunction]
@eval begin
function ConstructionBase.constructorof(::Type{<:$S{iip}}) where {
iip,
Expand Down
17 changes: 17 additions & 0 deletions test/function_building_error_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,20 @@ BVPFunction(bfoop, bciip, vjp = bvjp)
bvjp(du, u, v, p, t) = [1.0]
BVPFunction(bfiip, bciip, vjp = bvjp)
BVPFunction(bfoop, bciip, vjp = bvjp)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

# IntegralFunction

ioop(u, p) = p * u
iiip(y, u, p) = y .= u * p

IntegralFunction(ioop)
IntegralFunction(iiip, Float64[])

# BatchIntegralFunction

boop(y, u, p) = y .= p .* u
biip(y, u, p) = y .= p .* u # this example is not realistic

BatchIntegralFunction(boop, Float64[])
BatchIntegralFunction(boop, Float64[], max_batch = 20)
BatchIntegralFunction(biip, Float64[], Float64[]) # the 2nd argument should be an ElasticArray