diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index b28aa5422..d206ca07d 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <: """ $(TYPEDEF) +Base for types defining integrand functions. +""" +abstract type AbstractIntegralFunction{iip} <: + AbstractSciMLFunction{iip} end + +""" +$(TYPEDEF) + Base for types defining optimization functions. """ abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end @@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize}, RODEFunction{iip, specialize}, NonlinearFunction{iip, specialize}, OptimizationFunction{iip, specialize}, - BVPFunction{iip, specialize}}) where {iip, + BVPFunction{iip, specialize}, + IntegralFunction{iip, specialize}, + BatchIntegralFunction{iip, specialize}}) where {iip, specialize} specialize end @@ -714,7 +724,6 @@ include("ensemble/ensemble_analysis.jl") include("solve.jl") include("interpolation.jl") -include("integrand_interface.jl") include("integrator_interface.jl") include("tabletraits.jl") include("remake.jl") @@ -788,7 +797,8 @@ export remake export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction, DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, - IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction + IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction, + IntegralFunction, BatchIntegralFunction export OptimizationFunction diff --git a/src/integrand_interface.jl b/src/integrand_interface.jl deleted file mode 100644 index 7b7ebc18e..000000000 --- a/src/integrand_interface.jl +++ /dev/null @@ -1,95 +0,0 @@ -""" - InplaceIntegrand(f!, result::AbstractArray) - -Constructor for a `InplaceIntegrand` accepting an integrand of the form `f!(y,x,p)`. The -caller also provides an output array needed to store the result of the quadrature. -Intermediate `y` arrays are allocated during the calculation, and the final result is is -written to `result`, so use the IntegralSolution immediately after the calculation to read -the result, and don't expect it to persist if the same integrand is used for another -calculation. -""" -struct InplaceIntegrand{F, T <: AbstractArray} - # in-place function f!(y, x, p) that takes one x value and outputs an array of results in-place - f!::F - I::T -end - -""" - BatchIntegrand(f!, y::AbstractArray, x::AbstractVector, max_batch=typemax(Int)) - -Constructor for a `BatchIntegrand` accepting an integrand of the form `f!(y,x,p) = y .= f!.(x, Ref(p))` -that can evaluate the integrand at multiple quadrature nodes using, for example, threads, -the GPU, or distributed-memory. The `max_batch` keyword is a soft limit on the number of -nodes passed to the integrand. The buffers `y,x` must both be `resize!`-able since the -number of evaluation points may vary between calls to `f!`. -""" -struct BatchIntegrand{F, Y, X} - # in-place function f!(y, x, p) that takes an array of x values and outputs an array of results in-place - f!::F - y::Y - x::X - max_batch::Int # maximum number of x to supply in parallel - function BatchIntegrand(f!, - y::AbstractVector, - x::AbstractVector, - max_batch::Integer = typemax(Int)) - max_batch > 0 || throw(ArgumentError("maximum batch size must be positive")) - return new{typeof(f!), typeof(y), typeof(x)}(f!, y, x, max_batch) - end -end - -""" - BatchIntegrand(f!, y, x; max_batch=typemax(Int)) - -Constructor for a `BatchIntegrand` with pre-allocated buffers. -""" -function BatchIntegrand(f!, y, x; max_batch::Integer = typemax(Int)) - BatchIntegrand(f!, y, x, max_batch) -end - -""" - BatchIntegrand(f!, y::Type, x::Type=Nothing; max_batch=typemax(Int)) - -Constructor for a `BatchIntegrand` whose range type is known. The domain type is optional. -Array buffers for those types are allocated internally. -""" -function BatchIntegrand(f!, Y::Type, X::Type = Nothing; max_batch::Integer = typemax(Int)) - BatchIntegrand(f!, Y[], X[], max_batch) -end - -""" - InplaceBatchIntegrand(f!, result::AbstractArray, y::AbstractArray, x::AbstractVector, max_batch=typemax(Int)) - -Constructor for a `InplaceBatchIntegrand` accepting an integrand of the form `f!(y,x,p) = y -.= f!.(x, Ref(p))` that can evaluate an inplace, array-valued integrand at multiple -quadrature nodes simultaneously using, for example, threads, the GPU, or distributed-memory. -The `max_batch` keyword is a soft limit on the number of nodes passed to the integrand. The -buffers `y,x` must both be `resize!`-able since the number of evaluation points may vary -between calls to `f!`. In particular, for a resizeable `y` buffer see ElasticArrays.jl . The -solution is written inplace to `result`. -""" -struct InplaceBatchIntegrand{F, T, Y, X} - # in-place function f!(y, x, p) that takes an array of x values and outputs an array of results in-place - f!::F - I::T - y::Y - x::X - max_batch::Int # maximum number of x to supply in parallel - function InplaceBatchIntegrand(f!, - I::AbstractArray, - y::AbstractArray, - x::AbstractVector, - max_batch::Integer = typemax(Int)) - max_batch > 0 || throw(ArgumentError("maximum batch size must be positive")) - return new{typeof(f!), typeof(I), typeof(y), typeof(x)}(f!, I, y, x, max_batch) - end -end - -""" - InplaceBatchIntegrand(f!, result, y, x; max_batch=typemax(Int)) - -Constructor for a `InplaceBatchIntegrand` with pre-allocated buffers. -""" -function InplaceBatchIntegrand(f!, result, y, x; max_batch::Integer = typemax(Int)) - InplaceBatchIntegrand(f!, result, y, x, max_batch) -end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 45d727074..a24cd5238 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2261,6 +2261,184 @@ end TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2 +@doc doc""" + IntegralFunction{iip,specialize,F,T,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <: AbstractIntegralFunction{iip} + +A representation of an integrand `f` defined by: + +```math +f(u, p) +``` + +and its related functions, such as its Jacobian and gradient with respect to parameters. For +an in-place form of `f` see the `iip` section below for details on in-place or out-of-place +handling. + +```julia +IntegralFunction{iip,specialize}(f, [I]; + jac = __has_jac(f) ? f.jac : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + analytic = __has_analytic(f) ? f.analytic : nothing, + syms = __has_syms(f) ? f.syms : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + observed = __has_observed(f) ? f.observed : nothing) +``` + +Note that only `f` is required, and in the case of inplace integrands a mutable container +`I` to store the result of the integral. + +The remaining functions are optional and mainly used for accelerating the usage of `f`: +- `jac`: unused +- `paramjac`: unused +- `analytic`: unused +- `syms`: unused +- `jac_prototype`: unused +- `sparsity`: unused +- `colorvec`: unused +- `observed`: unused + +Since most arguments are unused, the following constructor provides the essential behavior: + +```julia +IntegralFunction(f, [I]; kws..) +``` + +If `I` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be +out-of-place. + +## iip: In-Place vs Out-Of-Place + +Out-of-place functions must be of the form ``f(u, p)`` and in-place functions of the form +``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +arrays), in-place functions must provide a container `I` that is of the right type for the +final result of the integral, and the result is written to this container in-place. When +in-place forms are used, in-place array operations may be used by algorithms to reduce +allocations. If `I` is not provided, `f` is assumed to be out-of-place and quadrature is +performed assuming immutable return types. + +## specialize + +This field is currently unused + +## Fields + +The fields of the IntegralFunction type directly match the names of the inputs. +""" +struct IntegralFunction{iip, specialize, F, T, TJ, TPJ, Ta, S, JP, SP, TCV, O} <: + AbstractIntegralFunction{iip} + f::F + I::T + jac::TJ + paramjac::TPJ + analytic::Ta + syms::S + jac_prototype::JP + sparsity::SP + colorvec::TCV + observed::O +end + +TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2 + +@doc doc""" +BatchIntegralFunction{iip,specialize,F,T,Y,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <: AbstractIntegralFunction{iip} + +A representation of an integrand `f` that can be evaluated at multiple points simultaneously +using threads, the gpu, or distributed memory defined by: + +```math +f(y, u, p) +``` + +and its related functions, such as its Jacobian and gradient with respect to parameters. For +an in-place form of `f` see the `iip` section below for details on in-place or out-of-place +handling. + +``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose +output must be returned in the corresponding entries of ``y``. In general, the integration +algorithm is allowed to vary the number of evaluation points between subsequent calls to `f` + +```julia +BatchIntegralFunction{iip,specialize}(f, y, [I]; + max_batch = typemax(Int), + jac = __has_jac(f) ? f.jac : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + analytic = __has_analytic(f) ? f.analytic : nothing, + syms = __has_syms(f) ? f.syms : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + observed = __has_observed(f) ? f.observed : nothing) +``` + +Note that `f` is required and a `resize`-able buffer `y` to store the output, or range of +`f`, and in the case of inplace integrands a mutable container `I` to store the result of +the integral. These buffers can be reused across multiple compatible integrals to reduce +allocations. + +The keyword `max_batch` is used to set a soft limit on the number of points to batch at the +same time so that memory usage is controlled. + +The remaining functions are optional and mainly used for accelerating the usage of `f`: +- `jac`: unused +- `paramjac`: unused +- `analytic`: unused +- `syms`: unused +- `jac_prototype`: unused +- `sparsity`: unused +- `colorvec`: unused +- `observed`: unused + +Since most arguments are unused, the following constructor provides the essential behavior: + +```julia +BatchIntegralFunction(f, y, [I]; max_batch=typemax(Int), kws..) +``` + +If `I` is present, `f` is interpreted as in-place, and otherwise `f` is assumed to be +out-of-place. + +## iip: In-Place vs Out-Of-Place + +Out-of-place and in-place functions are both of the form ``f(y, u, p)``, but differ in the +element type of ``y``. Since `f` is allowed to return any type (e.g. real or complex numbers +or arrays), in-place functions must provide a container `I` that is of the right type for +the final result of the integral, and the result is written to this container in-place. When +`f` is in-place, the output buffer ``y`` is assumed to have a mutable element type, and the +last dimension of ``y`` should correspond to the batch index. For example, ``y`` would have +to be an `ElasticArray` or a `VectorOfSimilarArrays` of an `ElasticArray`. When in-place +forms are used, in-place array operations may be used by algorithms to reduce allocations. +If `I` is not provided, `f` is assumed to be out-of-place and quadrature is performed +assuming ``y`` is an `AbstractVector` with an immutable element type. + +## specialize + +This field is currently unused + +## Fields + +The fields of the BatchIntegralFunction type directly match the names of the inputs. +""" +struct BatchIntegralFunction{iip, specialize, F, Y, T, TJ, TPJ, Ta, S, JP, SP, TCV, O} <: + AbstractIntegralFunction{iip} + f::F + y::Y + I::T + max_batch::Int + jac::TJ + paramjac::TPJ + analytic::Ta + syms::S + jac_prototype::JP + sparsity::SP + colorvec::TCV + observed::O +end + +TruncatedStacktraces.@truncate_stacktrace BatchIntegralFunction 1 2 + ######### Backwards Compatibility Overloads (f::ODEFunction)(args...) = f.f(args...) @@ -3955,6 +4133,80 @@ function BVPFunction(f, bc; kwargs...) end BVPFunction(f::BVPFunction; kwargs...) = f +function IntegralFunction{iip, specialize}(f, I; + jac = __has_jac(f) ? f.jac : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + analytic = __has_analytic(f) ? f.analytic : nothing, + syms = __has_syms(f) ? f.syms : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + observed = __has_observed(f) ? f.observed : nothing) where {iip, specialize} + IntegralFunction{ + iip, + specialize, + typeof(f), + typeof(I), + typeof(jac), + typeof(paramjac), + typeof(analytic), + typeof(syms), + typeof(jac_prototype), + typeof(sparsity), + typeof(colorvec), + typeof(observed), + }(f, + I, + jac, + paramjac, + analytic, + syms, + jac_prototype, + sparsity, + colorvec, + observed) +end + +function IntegralFunction{iip}(f, I; kws...) where {iip} + return IntegralFunction{iip, FullSpecialize}(f, I; kws...) +end +IntegralFunction(f; kws...) = IntegralFunction{false}(f, nothing; kws...) +IntegralFunction(f, I; kws...) = IntegralFunction{true}(f, I; kws...) + +function BatchIntegralFunction{iip, specialize}(f, y, I; + max_batch::Integer = typemax(Int), + jac = __has_jac(f) ? f.jac : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + analytic = __has_analytic(f) ? f.analytic : nothing, + syms = __has_syms(f) ? f.syms : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + observed = __has_observed(f) ? f.observed : nothing) where {iip, specialize} + BatchIntegralFunction{ + iip, + specialize, + typeof(f), + typeof(y), + typeof(I), + typeof(jac), + typeof(paramjac), + typeof(analytic), + typeof(syms), + typeof(jac_prototype), + typeof(sparsity), + typeof(colorvec), + typeof(observed), + }(f, y, I, max_batch, jac, paramjac, analytic, syms, jac_prototype, sparsity, colorvec, + observed) +end + +function BatchIntegralFunction{iip}(f, y, I; kws...) where {iip} + return BatchIntegralFunction{iip, FullSpecialize}(f, y, I; kws...) +end +BatchIntegralFunction(f, y; kws...) = BatchIntegralFunction{false}(f, y, nothing; kws...) +BatchIntegralFunction(f, y, I; kws...) = BatchIntegralFunction{true}(f, y, I; kws...) + ########## Existence Functions # Check that field/property exists (may be nothing) @@ -4064,7 +4316,9 @@ for S in [:ODEFunction :NonlinearFunction :IntervalNonlinearFunction :IncrementingODEFunction - :BVPFunction] + :BVPFunction + :IntegralFunction + :BatchIntegralFunction] @eval begin function ConstructionBase.constructorof(::Type{<:$S{iip}}) where { iip, diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index c424a6f84..ccbac5bb8 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -601,3 +601,20 @@ BVPFunction(bfoop, bciip, vjp = bvjp) bvjp(du, u, v, p, t) = [1.0] BVPFunction(bfiip, bciip, vjp = bvjp) BVPFunction(bfoop, bciip, vjp = bvjp) + +# IntegralFunction + +ioop(u, p) = p * u +iiip(y, u, p) = y .= u * p + +IntegralFunction(ioop) +IntegralFunction(iiip, Float64[]) + +# BatchIntegralFunction + +boop(y, u, p) = y .= p .* u +biip(y, u, p) = y .= p .* u # this example is not realistic + +BatchIntegralFunction(boop, Float64[]) +BatchIntegralFunction(boop, Float64[], max_batch = 20) +BatchIntegralFunction(biip, Float64[], Float64[]) # the 2nd argument should be an ElasticArray