diff --git a/Project.toml b/Project.toml index 492af0df1..b15c4569b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.0.0" +version = "1.99.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 011dfd690..d206ca07d 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <: """ $(TYPEDEF) +Base for types defining integrand functions. +""" +abstract type AbstractIntegralFunction{iip} <: + AbstractSciMLFunction{iip} end + +""" +$(TYPEDEF) + Base for types defining optimization functions. """ abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end @@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize}, RODEFunction{iip, specialize}, NonlinearFunction{iip, specialize}, OptimizationFunction{iip, specialize}, - BVPFunction{iip, specialize}}) where {iip, + BVPFunction{iip, specialize}, + IntegralFunction{iip, specialize}, + BatchIntegralFunction{iip, specialize}}) where {iip, specialize} specialize end @@ -787,7 +797,8 @@ export remake export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction, DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, - IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction + IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction, + IntegralFunction, BatchIntegralFunction export OptimizationFunction diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index cd057eba5..aece5fa7b 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -335,26 +335,16 @@ which are `Number`s or `AbstractVector`s with the same geometry as `u`. ### Constructors ``` -IntegralProblem{iip}(f,lb,ub,p=NullParameters(); - nout=1, batch = 0, kwargs...) +IntegralProblem(f,domain,p=NullParameters(); kwargs...) +IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...) ``` -- f: the integrand, callable function `y = f(u,p)` for out-of-place or `f(y,u,p)` for in-place. +- f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an + `IntegralFunction` or `BatchIntegralFunction` for inplace and batching optimizations. +- domain: an object representing an integration domain, i.e. the tuple `(lb, ub)`. - lb: Either a number or vector of lower bounds. - ub: Either a number or vector of upper bounds. - p: The parameters associated with the problem. -- nout: The output size of the function f. Defaults to 1, i.e., a scalar valued function. - If `nout > 1` f is a vector valued function . -- batch: The preferred number of points to batch. This allows user-side parallelization - of the integrand. If `batch == 0` no batching is performed. - If `batch > 0` both `u` and `y` get an additional dimension added to it. - This means that: - if `f` is a multi variable function each `u[:,i]` is a different point to evaluate `f` at, - if `f` is a single variable function each `u[i]` is a different point to evaluate `f` at, - if `f` is a vector valued function each `y[:,i]` is the evaluation of `f` at a different point, - if `f` is a scalar valued function `y[i]` is the evaluation of `f` at a different point. - Note that batch is a suggestion for the number of points, - and it is not necessarily true that batch is the same as batchsize in all algorithms. - kwargs: Keyword arguments copied to the solvers. Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at @@ -364,30 +354,58 @@ compile time whether the integrator function is in-place. The fields match the names of the constructor arguments. """ -struct IntegralProblem{isinplace, P, F, B, K} <: AbstractIntegralProblem{isinplace} +struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinplace} f::F - lb::B - ub::B - nout::Int + domain::T p::P - batch::Int kwargs::K - @add_kwonly function IntegralProblem{iip}(f, lb, ub, p = NullParameters(); - nout = 1, - batch = 0, kwargs...) where {iip} - @assert typeof(lb)==typeof(ub) "Type of lower and upper bound must match" + @add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain, + p = NullParameters(); + kwargs...) where {iip} warn_paramtype(p) - new{iip, typeof(p), typeof(f), typeof(lb), typeof(kwargs)}(f, - lb, ub, nout, p, - batch, kwargs) + new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f, + domain, p, kwargs) end end TruncatedStacktraces.@truncate_stacktrace IntegralProblem 1 4 -function IntegralProblem(f, lb, ub, args...; +function IntegralProblem(f::AbstractIntegralFunction, + domain, + p = NullParameters(); kwargs...) - IntegralProblem{isinplace(f, 3)}(f, lb, ub, args...; kwargs...) + IntegralProblem{isinplace(f)}(f, domain, p; kwargs...) +end + +function IntegralProblem(f::AbstractIntegralFunction, + lb::B, + ub::B, + p = NullParameters(); + kwargs...) where {B} + IntegralProblem(f, (lb, ub), p; kwargs...) +end + +function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...) + if nout !== nothing || batch !== nothing + @warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details." + end + + max_batch = batch === nothing ? 0 : batch + g = if isinplace(f, 3) + output_prototype = Vector{Float64}(undef, nout === nothing ? 1 : nout) + if max_batch == 0 + IntegralFunction(f, output_prototype) + else + BatchIntegralFunction(f, output_prototype, max_batch=max_batch) + end + else + if max_batch == 0 + IntegralFunction(f) + else + BatchIntegralFunction(f, max_batch=max_batch) + end + end + IntegralProblem(g, args...; kwargs...) end struct QuadratureProblem end @@ -405,8 +423,8 @@ Sampled integral problems are defined as: ```math \sum_i w_i y_i ``` -where `y_i` are sampled values of the integrand, and `w_i` are weights -assigned by a quadrature rule, which depend on sampling points `x`. +where `y_i` are sampled values of the integrand, and `w_i` are weights +assigned by a quadrature rule, which depend on sampling points `x`. ## Problem Type @@ -415,10 +433,10 @@ assigned by a quadrature rule, which depend on sampling points `x`. ``` SampledIntegralProblem(y::AbstractArray, x::AbstractVector; dim=ndims(y), kwargs...) ``` -- y: The sampled integrand, must be a subtype of `AbstractArray`. - It is assumed that the values of `y` along dimension `dim` +- y: The sampled integrand, must be a subtype of `AbstractArray`. + It is assumed that the values of `y` along dimension `dim` correspond to the integrand evaluated at sampling points `x` -- x: Sampling points, must be a subtype of `AbstractVector`. +- x: Sampling points, must be a subtype of `AbstractVector`. - dim: Dimension along which to integrate. Defaults to the last dimension of `y`. - kwargs: Keyword arguments copied to the solvers. @@ -434,7 +452,7 @@ struct SampledIntegralProblem{Y, X, K} <: AbstractIntegralProblem{false} @add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector; dim = ndims(y), kwargs...) - @assert dim <= ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`" + @assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`" @assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension." @assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension." new{typeof(y), typeof(x), typeof(kwargs)}(y, x, dim, kwargs) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 50dfef684..c74808df2 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -203,6 +203,27 @@ function Base.showerror(io::IO, e::NonconformingFunctionsError) printstyled(io, e.nonconforming; bold = true, color = :red) end +const INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE = """ + Nonconforming functions detected. If an integrand function `f` is defined + as out-of-place (`f(u,p)`), then no integrand_prototype can be passed into the + function constructor. Likewise if `f` is defined as in-place (`f(out,u,p)`), then + an integrand_prototype is required. Either change the use of the function + constructor or define the appropriate dispatch for `f`. + """ + +struct IntegrandMismatchFunctionError <: Exception + iip::Bool + integrand_passed::Bool +end + +function Base.showerror(io::IO, e::IntegrandMismatchFunctionError) + println(io, INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE) + print(io, "Mismatch: IIP=") + printstyled(io, e.iip; bold = true, color = :red) + print(io, ", Integrand passed=") + printstyled(io, e.integrand_passed; bold = true, color = :red) +end + """ $(TYPEDEF) """ @@ -2259,11 +2280,122 @@ end TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2 +@doc doc""" + IntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} + +A representation of an integrand `f` defined by: + +```math +f(u, p) +``` + +For an in-place form of `f` see the `iip` section below for details on in-place or +out-of-place handling. + +```julia +IntegralFunction{iip,specialize}(f, [integrand_prototype]) +``` + +Note that only `f` is required, and in the case of inplace integrands a mutable container +`integrand_prototype` to store the result of the integrand. If `integrand_prototype` is +present, `f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place. + +## iip: In-Place vs Out-Of-Place + +Out-of-place functions must be of the form ``y = f(u, p)`` and in-place functions of the form +``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +arrays), in-place functions must provide a container `integrand_prototype` that is of the +right type for the variable ``y``, and the result is written to this container in-place. +When in-place forms are used, in-place array operations, i.e. broadcasting, may be used by +algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is assumed +to be out-of-place and quadrature is performed assuming immutable return types. + +## specialize + +This field is currently unused + +## Fields + +The fields of the IntegralFunction type directly match the names of the inputs. +""" +struct IntegralFunction{iip, specialize, F, T} <: + AbstractIntegralFunction{iip} + f::F + integrand_prototype::T +end + +TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2 + +@doc doc""" +BatchIntegralFunction{iip,specialize,F,T} <: AbstractIntegralFunction{iip} + +A representation of an integrand `f` that can be evaluated at multiple points simultaneously +using threads, the gpu, or distributed memory defined by: + +```math +y = f(u, p) +``` + +``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose +output must be returned as an array whose last "batching" dimension corresponds to integrand +evaluations at the different points in ``u``. In general, the integration algorithm is +allowed to vary the number of evaluation points between subsequent calls to `f`. + +For an in-place form of `f` see the `iip` section below for details on in-place or +out-of-place handling. + +```julia +BatchIntegralFunction{iip,specialize}(f, [integrand_prototype]; + max_batch=typemax(Int)) +``` +Note that only `f` is required, and in the case of inplace integrands a mutable container +`integrand_prototype` to store the result of the integrand of one integrand, without a last +"batching" dimension. + +The keyword `max_batch` is used to set a soft limit on the number of points to batch at the +same time so that memory usage is controlled. + +If `integrand_prototype` is present, `f` is interpreted as in-place, and otherwise `f` is +assumed to be out-of-place. + +## iip: In-Place vs Out-Of-Place + +Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form +``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or +arrays), in-place functions must provide a container `integrand_prototype` of the right type +for a single integrand evaluation. The integration algorithm will then allocate a ``y`` +array with the same element type as `integrand_prototype` and an additional last "batching" +dimension to store multiple integrand evaluations. In the out-of-place case, the algorithm +may infer the type of ``y`` by passing `f` an empty array of input points. This means ``y`` +is a vector in the out-of-place case, or a matrix/array in the in-place case. The number of +batched points may vary between subsequent calls to `f`. When in-place forms are used, +in-place array operations may be used by algorithms to reduce allocations. If +`integrand_prototype` is not provided, `f` is assumed to be out-of-place. + +## specialize + +This field is currently unused + +## Fields + +The fields of the BatchIntegralFunction type directly match the names of the inputs. +""" +struct BatchIntegralFunction{iip, specialize, F, T} <: + AbstractIntegralFunction{iip} + f::F + integrand_prototype::T + max_batch::Int +end + +TruncatedStacktraces.@truncate_stacktrace BatchIntegralFunction 1 2 + ######### Backwards Compatibility Overloads (f::ODEFunction)(args...) = f.f(args...) (f::NonlinearFunction)(args...) = f.f(args...) (f::IntervalNonlinearFunction)(args...) = f.f(args...) +(f::IntegralFunction)(args...) = f.f(args...) +(f::BatchIntegralFunction)(args...) = f.f(args...) function (f::DynamicalODEFunction)(u, p, t) ArrayPartition(f.f1(u.x[1], u.x[2], p, t), f.f2(u.x[1], u.x[2], p, t)) @@ -3941,6 +4073,64 @@ function BVPFunction(f, bc; twopoint::Bool=false, kwargs...) end BVPFunction(f::BVPFunction; kwargs...) = f +function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} + IntegralFunction{iip, specialize, typeof(f), typeof(integrand_prototype)}(f, + integrand_prototype) +end + +function IntegralFunction{iip}(f, integrand_prototype) where {iip} + return IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) +end +function IntegralFunction(f) + calculated_iip = isinplace(f, 3, "integral", true) + if calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, false)) + end + IntegralFunction{false}(f, nothing) +end +function IntegralFunction(f, integrand_prototype) + calcuated_iip = isinplace(f, 3, "integral", true) + if !calcuated_iip + throw(IntegrandMismatchFunctionError(calcuated_iip, true)) + end + IntegralFunction{true}(f, integrand_prototype) +end + +function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; + max_batch::Integer = typemax(Int)) where {iip, specialize} + BatchIntegralFunction{ + iip, + specialize, + typeof(f), + typeof(integrand_prototype), + }(f, + integrand_prototype, + max_batch) +end + +function BatchIntegralFunction{iip}(f, + integrand_prototype; + kwargs...) where {iip} + return BatchIntegralFunction{iip, FullSpecialize}(f, + integrand_prototype; + kwargs...) +end + +function BatchIntegralFunction(f; kwargs...) + calculated_iip = isinplace(f, 3, "batchintegral", true) + if calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, false)) + end + BatchIntegralFunction{false}(f, nothing; kwargs...) +end +function BatchIntegralFunction(f, integrand_prototype; kwargs...) + calculated_iip = isinplace(f, 3, "batchintegral", true) + if !calculated_iip + throw(IntegrandMismatchFunctionError(calculated_iip, true)) + end + BatchIntegralFunction{true}(f, integrand_prototype; kwargs...) +end + ########## Existence Functions # Check that field/property exists (may be nothing) @@ -4050,7 +4240,9 @@ for S in [:ODEFunction :NonlinearFunction :IntervalNonlinearFunction :IncrementingODEFunction - :BVPFunction] + :BVPFunction + :IntegralFunction + :BatchIntegralFunction] @eval begin function ConstructionBase.constructorof(::Type{<:$S{iip}}) where { iip, diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 1db0fb468..269985a39 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -614,4 +614,39 @@ bvjp(u, v, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp) bvjp(du, u, v, p, t) = [1.0] BVPFunction(bfiip, bciip, vjp = bvjp) + @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp) + +# IntegralFunction + +ioop(u, p) = p * u +iiip(y, u, p) = y .= u * p +i1(u) = u +itoo(y, u, p, a) = y .= u * p + +IntegralFunction(ioop) +IntegralFunction(iiip, Float64[]) + +@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[]) +@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(iiip) +@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1) +@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo) +@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo, Float64[]) + +# BatchIntegralFunction + +boop(u, p) = p .* u +biip(y, u, p) = y .= p .* u +bi1(u) = u +bitoo(y, u, p, a) = y .= p .* u + +BatchIntegralFunction(boop) +BatchIntegralFunction(boop, max_batch = 20) +BatchIntegralFunction(biip, Float64[]) +BatchIntegralFunction(biip, Float64[], max_batch = 20) + +@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(boop, Float64[]) +@test_throws SciMLBase.IntegrandMismatchFunctionError BatchIntegralFunction(biip) +@test_throws SciMLBase.TooFewArgumentsError BatchIntegralFunction(bi1) +@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo) +@test_throws SciMLBase.TooManyArgumentsError BatchIntegralFunction(bitoo, Float64[])