diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index 96f83ddcf..561fbba6a 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -448,7 +448,7 @@ struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinpla p::P kwargs::K @add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain, - p = NullParameters(); + p = NullParameters(); nout = nothing, batch = nothing, kwargs...) where {iip} warn_paramtype(p) new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f, @@ -465,17 +465,18 @@ function IntegralProblem(f::AbstractIntegralFunction, IntegralProblem{isinplace(f)}(f, domain, p; kwargs...) end -@deprecate IntegralProblem(f::AbstractIntegralFunction, +@deprecate IntegralProblem{iip}(f::AbstractIntegralFunction, lb::Union{Number,AbstractVector{<:Number}}, ub::Union{Number,AbstractVector{<:Number}}, - p = NullParameters(); kwargs...) IntegralProblem(f, (lb, ub), p; kwargs...) + p = NullParameters(); kwargs...) where {iip} IntegralProblem{iip}(f, (lb, ub), p; kwargs...) -function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...) +IntegralProblem(f, args...; kwargs...) = IntegralProblem{isinplace(f, 3)}(f, args...; kwargs...) +function IntegralProblem{iip}(f, args...; nout = nothing, batch = nothing, kwargs...) where {iip} if nout !== nothing || batch !== nothing @warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details." end - g = if isinplace(f, 3) + g = if iip if batch === nothing output_prototype = nout === nothing ? 
Array{Float64, 0}(undef) : Vector{Float64}(undef, nout) IntegralFunction(f, output_prototype) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 1027aaea3..604b1b85d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2088,12 +2088,12 @@ present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out- ## iip: In-Place vs Out-Of-Place Out-of-place functions must be of the form ``y = f(u, p)`` and in-place functions of the form -``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +``f(y, u, p)``, where `y` is a number or array containing the output. Since `f` is allowed to return any type (e.g. real or complex numbers or arrays), in-place functions must provide a container `integrand_prototype` that is of the -right type for the variable ``y``, and the result is written to this container in-place. +right type and size for the variable ``y``, and the result is written to this container in-place. When in-place forms are used, in-place array operations, i.e. broadcasting, may be used by algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is assumed -to be out-of-place and quadrature is performed assuming immutable return types. +to be out-of-place. 
## specialize @@ -2112,51 +2112,66 @@ end TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2 @doc doc""" -BatchIntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} + BatchIntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} -A representation of an integrand `f` that can be evaluated at multiple points simultaneously -using threads, the gpu, or distributed memory defined by: +A batched representation of a (non-batched) integrand `f(u, p)` that can be +evaluated at multiple points simultaneously using threads, the GPU, or +distributed memory defined by: ```math -y = f(u, p) +by = bf(bu, p) ``` -``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose -output must be returned as an array whose last "batching" dimension corresponds to integrand -evaluations at the different points in ``u``. In general, the integration algorithm is -allowed to vary the number of evaluation points between subsequent calls to `f`. +Here we prefix variables with `b` to indicate they are batched variables, which +implies that they are arrays whose **last** dimension is reserved for batching +different evaluation points, or function values, and may be of a variable +length. ``bu`` is an array whose elements correspond to distinct evaluation +points to `f`, and `bf` is a function to evaluate `f` 'point-wise' so that +`f(bu[..., i], p) == bf(bu, p)[..., i]`. For example, a simple batching implementation +of a scalar, univariate function is via broadcasting: `bf(bu, p) = f.(bu, Ref(p))`, +although this interface exists in order to allow user parallelization. +In general, the integration algorithm is allowed to vary the number of +evaluation points between subsequent calls to `bf`. -For an in-place form of `f` see the `iip` section below for details on in-place or -out-of-place handling. +For an in-place form of `bf` see the `iip` section below for details on in-place +or out-of-place handling.
```julia -BatchIntegralFunction{iip,specialize}(f, [integrand_prototype]; +BatchIntegralFunction{iip,specialize}(bf, [integrand_prototype]; max_batch=typemax(Int)) ``` -Note that only `f` is required, and in the case of inplace integrands a mutable container -`integrand_prototype` to store a batch of integrand evaluations, with a last "batching" -dimension. +Note that only `bf` is required, and in the case of inplace integrands a mutable +container `integrand_prototype` to store a batch of integrand evaluations, with +a last "batching" dimension. -The keyword `max_batch` is used to set a soft limit on the number of points to batch at the -same time so that memory usage is controlled. +The keyword `max_batch` is used to set a soft limit on the number of points to +batch at the same time so that memory usage is controlled. -If `integrand_prototype` is present, `f` is interpreted as in-place, and otherwise `f` is -assumed to be out-of-place. +If `integrand_prototype` is present, `bf` is interpreted as in-place, and +otherwise `bf` is assumed to be out-of-place. ## iip: In-Place vs Out-Of-Place -Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form -``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or -arrays), in-place functions must provide a container `integrand_prototype` of the right type -for ``y``. The only assumption that is enforced is that the last axes of `the `y`` and ``u`` -arrays are the same length and correspond to distinct batched points. The algorithm will -then allocate arrays `similar` to ``y`` to pass to the integrand. Since the algorithm may -vary the number of points to batch, the length of the batching dimension of ``y`` may vary -between subsequent calls to `f`. To reduce allocations, views of ``y`` may also be passed to -the integrand. In the out-of-place case, the algorithm may infer the type -of ``y`` by passing `f` an empty array of input points. 
When in-place forms are used, -in-place array operations may be used by algorithms to reduce allocations. If -`integrand_prototype` is not provided, `f` is assumed to be out-of-place. +Out-of-place functions must be of the form `by = bf(bu, p)` and in-place +functions of the form `bf(by, bu, p)` where `by` is a batch array containing the +output. Since the algorithm may vary the number of points to batch, the batching +dimension can be of any length, including zero, and since `bf` is allowed to +return arrays of any type (e.g. real or complex) or size, in-place functions +must provide a container `integrand_prototype` of the desired type and size for +`by`. If `integrand_prototype` is not provided, `bf` is assumed to be +out-of-place. + +In the out-of-place case, we require `f(bu[..., i], p) == bf(bu, p)[..., i]`, +and certain algorithms, such as those implemented in C, may infer the type or +shape of `by` by calling `bf` with an empty array of input points, i.e. `bu` +with `size(bu)[end] == 0`. The resulting `by` is then expected to have the +same type and the same `size(by)[begin:end-1]` in all subsequent calls. + +When the in-place form is used, we require that calling `bf(by, bu, p)` leaves +each slice `by[..., i]` equal to what the in-place call +`f(by[..., i], bu[..., i], p)` would write into it, and that +`size(by)[begin:end-1] == size(integrand_prototype)[begin:end-1]`. The algorithm should always pass the +integrand `by` arrays that are `similar` to `integrand_prototype`, and may use +views and in-place array operations to reduce allocations. ## specialize @@ -2164,7 +2179,8 @@ This field is currently unused ## Fields -The fields of the BatchIntegralFunction type directly match the names of the inputs. +The fields of the BatchIntegralFunction type are `f`, corresponding to `bf` +above, and `integrand_prototype`.
""" struct BatchIntegralFunction{iip, specialize, F, T} <: AbstractIntegralFunction{iip} diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 78eca9a90..5dad581cf 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -461,15 +461,27 @@ NonlinearFunction(nfiip, vjp = nvjp) NonlinearFunction(nfoop, vjp = nvjp) # Integrals -intf(u) = 1.0 -@test_throws SciMLBase.TooFewArgumentsError IntegralProblem(intf, (0.0, 1.0)) +intfew(u) = 1.0 +@test_throws SciMLBase.TooFewArgumentsError IntegralProblem(intfew, (0.0, 1.0)) +@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(intfew) +@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(intfew, zeros(3)) +@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(intfew) +@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(intfew, zeros(3)) intf(u, p) = 1.0 p = 2.0 - -IntegralProblem(intf, (0.0, 1.0)) -IntegralProblem(intf, (0.0, 1.0), p) -IntegralProblem(intf, ([0.0], [1.0])) -IntegralProblem(intf, ([0.0], [1.0]), p) +intfiip(y, u, p) = y .= 1.0 + +for (f, kws, iip) in ( + (intf, (;), false), + (IntegralFunction(intf), (;), false), + (intfiip, (; nout=3), true), + (IntegralFunction(intfiip, zeros(3)), (;), true), +), domain in (((0.0, 1.0),), (([0.0], [1.0]),), (0.0, 1.0), ([0.0], [1.0],)) + IntegralProblem(f, domain...; kws...) + IntegralProblem(f, domain..., p; kws...) + IntegralProblem{iip}(f, domain...; kws...) + IntegralProblem{iip}(f, domain..., p; kws...) +end x = [1.0, 2.0] y = rand(2, 2)