Line | Exclusive | Inclusive | Code |
---|---|---|---|
"""
    EvalFunc(f)

Callable wrapper around a function `f`: calling the wrapper forwards every
positional argument to `f`. `init_call` and `solve_call` check whether a
problem's `f.f` is an `EvalFunc` and, if so, dispatch through `Base.invokelatest`.
"""
struct EvalFunc{F} <: Function
    f::F
end

function (wrapper::EvalFunc)(args...)
    inner = wrapper.f
    return inner(args...)
end
5 | |||
# Union of the problem families that (as the name says) carry no `tspan`.
# Declared `const`: a plain global binding is untyped and re-assignable, so every
# use of it would need a dynamic global lookup instead of a fixed type.
const NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem,
    AbstractIntegralProblem, AbstractSteadyStateProblem,
    AbstractJumpProblem}
9 | |||
"""
    has_kwargs(prob::AbstractDEProblem)
    has_kwargs(::Type)

Return `true` if the problem (or problem type) has a `kwargs` field, i.e. it
stores keyword arguments to be merged into `init`/`solve` calls.
"""
has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob))

# `hasfield` is constant-folded for a known type, so the discouraged `Base.@pure`
# annotation is unnecessary. Unlike `fieldnames`, it also returns `false`
# (instead of throwing) for abstract types.
__has_kwargs(::Type{T}) where {T} = hasfield(T, :kwargs)
has_kwargs(::Type{T}) where {T} = __has_kwargs(T)
13 | |||
# Keyword arguments recognized as common solver options. Any other keyword is
# reported as unrecognized (see the `showerror` method for `CommonKwargError`).
# The tuple is interpolated verbatim into the unrecognized-keyword messages, so
# the order below is the order users see.
const allowedkeywords = (
    :dense, :saveat, :save_idxs, :tstops, :tspan, :d_discontinuities,
    :save_everystep, :save_on, :save_start, :save_end, :initialize_save,
    :adaptive, :abstol, :reltol, :dt, :dtmax, :dtmin, :force_dtmin,
    :internalnorm, :controller, :gamma, :beta1, :beta2, :qmax, :qmin,
    :qsteady_min, :qsteady_max, :qoldinit, :failfactor, :calck, :alias_u0,
    :maxiters, :callback, :isoutofdomain, :unstable_check, :verbose,
    :merge_callbacks, :progress, :progress_steps, :progress_name,
    :progress_message, :progress_id, :timeseries_errors, :dense_errors,
    :weak_timeseries_errors, :weak_dense_errors, :wrap, :calculate_error,
    :initializealg, :alg, :save_noise, :delta, :seed, :alg_hints,
    :kwargshandle, :trajectories, :batch_size, :sensealg, :advance_to_tstop,
    :stop_at_next_tstop, :u0, :p,
    # These two are from the default algorithm handling
    :default_set, :second_time,
    # This is for DiffEqDevTools
    :prob_choice,
    # Jump problems
    :alias_jump,
    # This is for copying/deepcopying noise in StochasticDiffEq
    :alias_noise,
    # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves
    :batch,
    # Shooting method in BVP needs to differentiate between these two categories
    :nlsolve_kwargs, :odesolve_kwargs,
    # For solvers which internally use linsolve
    :linsolve_kwargs,
    # Solvers internally using EnsembleProblem
    :ensemblealg,
    # Fine-grained control of tracing (storing and logging) during solve
    :show_trace, :trace_level, :store_trace,
    # Termination condition for solvers
    :termination_condition)
100 | |||
# Error text for keyword arguments outside `allowedkeywords`; used by the
# `showerror` method for `CommonKwargError` below.
const KWARGERROR_MESSAGE = """
Unrecognized keyword arguments found.
The only allowed keyword arguments to `solve` are:
$allowedkeywords

See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details.
"""

# Warning variant: the error text above followed by instructions on how to
# switch `kwargshandle` to an error or to silence.
const KWARGWARN_MESSAGE = KWARGERROR_MESSAGE *
                          "\nSet kwargshandle=KeywordArgError for an error message.\n" *
                          "Set kwargshandle=KeywordArgSilent to ignore this message.\n"

# Exception carrying the keyword arguments of a rejected call; `showerror`
# highlights the ones that are not in `allowedkeywords`.
struct CommonKwargError <: Exception
    kwargs::Any
end

function Base.showerror(io::IO, e::CommonKwargError)
    println(io, KWARGERROR_MESSAGE)
    given = collect(keys(e.kwargs))
    unrecognized = filter(k -> k ∉ allowedkeywords, given)
    print(io, "Unrecognized keyword arguments: ")
    printstyled(io, unrecognized; bold = true, color = :red)
    print(io, "\n\n")
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
133 | |||
# Policy for unrecognized keyword arguments passed to `init`/`solve`.
# NOTE: the enum *type* `KeywordArgError` is itself used as a value:
# `init_call`/`solve_call` default `kwargshandle` to `KeywordArgError`, which the
# kwarg messages describe as producing an error. The two instances request a
# warning (`KeywordArgWarn`) or silence (`KeywordArgSilent`).
@enum KeywordArgError begin
    KeywordArgWarn
    KeywordArgSilent
end
135 | |||
# Explanation shown when an in-place function is paired with an immutable
# (Number/SArray) initial condition.
const INCOMPATIBLE_U0_MESSAGE = """
Initial condition incompatible with functional form.
Detected an in-place function with an initial condition of type Number or SArray.
This is incompatible because Numbers cannot be mutated, i.e.
`x = 2.0; y = 2.0; x .= y` will error.

If using a immutable initial condition type, please use the out-of-place form.
I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`.

If your differential equation function was defined with multiple dispatches and one is
in-place, then the automatic detection will choose in-place. In this case, override the
choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`.

For a longer discussion on mutability vs immutability and in-place vs out-of-place, see:
https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation
"""

struct IncompatibleInitialConditionError <: Exception
end

function Base.showerror(io::IO, ::IncompatibleInitialConditionError)
    println(io, INCOMPATIBLE_U0_MESSAGE, TruncatedStacktraces.VERBOSE_MSG)
end
159 | |||
# Explanation shown when no algorithm is given and default algorithm
# selection is unavailable.
const NO_DEFAULT_ALGORITHM_MESSAGE = """
Default algorithm choices require DifferentialEquations.jl.
Please specify an algorithm (e.g., `solve(prob, Tsit5())` or
`init(prob, Tsit5())` for an ODE) or import DifferentialEquations
directly.

You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/
and its associated pages.
"""

struct NoDefaultAlgorithmError <: Exception
end

function Base.showerror(io::IO, ::NoDefaultAlgorithmError)
    println(io, NO_DEFAULT_ALGORITHM_MESSAGE, TruncatedStacktraces.VERBOSE_MSG)
end
176 | |||
# Explanation shown when neither the problem nor the call provides a time span.
const NO_TSPAN_MESSAGE = "No tspan is set in the problem or chosen in the init/solve call\n"

struct NoTspanError <: Exception
end

function Base.showerror(io::IO, ::NoTspanError)
    println(io, NO_TSPAN_MESSAGE, TruncatedStacktraces.VERBOSE_MSG)
end
187 | |||
# Explanation shown when the time span contains NaN (infinite endpoints are
# allowed, NaN is not).
const NAN_TSPAN_MESSAGE = """
NaN tspan is set in the problem or chosen in the init/solve call.
Note that -Inf and Inf values are allowed in the timespan for solves
which are terminated via callbacks, however NaN values are not allowed
since the direction of time is undetermined.
"""

struct NaNTspanError <: Exception
end

function Base.showerror(io::IO, ::NaNTspanError)
    println(io, NAN_TSPAN_MESSAGE, TruncatedStacktraces.VERBOSE_MSG)
end
201 | |||
# Explanation shown when the second argument to `solve` is not an algorithm.
const NON_SOLVER_MESSAGE = """
The arguments to solve are incorrect.
The second argument must be a solver choice, `solve(prob,alg)`
where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`.

Please double check the arguments being sent to the solver.

You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/
and its associated pages.
"""

struct NonSolverError <: Exception
end

function Base.showerror(io::IO, ::NonSolverError)
    println(io, NON_SOLVER_MESSAGE, TruncatedStacktraces.VERBOSE_MSG)
end
219 | |||
# Explanation shown when the noise-rate prototype and the noise process disagree
# on the number of noise terms.
const NOISE_SIZE_MESSAGE = """
Noise sizes are incompatible. The expected number of noise terms in the defined
`noise_rate_prototype` does not match the number of noise terms in the defined
`AbstractNoiseProcess`. Please ensure that
size(prob.noise_rate_prototype,2) == length(prob.noise.W[1]).

Note: Noise process definitions require that users specify `u0`, and this value is
directly used in the definition. For example, if `noise = WienerProcess(0.0,0.0)`,
then the noise process is a scalar with `u0=0.0`. If `noise = WienerProcess(0.0,[0.0])`,
then the noise process is a vector with `u0=0.0`. If `noise_rate_prototype = zeros(2,4)`,
then the noise process must be a 4-dimensional process, for example
`noise = WienerProcess(0.0,zeros(4))`. This error is a sign that the user definition
of `noise_rate_prototype` and `noise` are not aligned in this manner and the definitions should
be double checked.
"""

# `prototypesize` is reported as `size(prob.noise_rate_prototype,2)` and
# `noisesize` as `length(prob.noise.W[1])`.
struct NoiseSizeIncompatabilityError <: Exception
    prototypesize::Int
    noisesize::Int
end

function Base.showerror(io::IO, e::NoiseSizeIncompatabilityError)
    details = string("size(prob.noise_rate_prototype,2) = ", e.prototypesize, "\n",
        "length(prob.noise.W[1]) = ", e.noisesize)
    println(io, NOISE_SIZE_MESSAGE)
    println(io, details)
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
247 | |||
# Explanation shown when a solver is used with a problem type it cannot handle.
const PROBSOLVER_PAIRING_MESSAGE = """
Incompatible problem+solver pairing.
For example, this can occur if an ODE solver is passed with an SDEProblem.
Solvers are only capable of handling specific problem types. Please double
check that the chosen pairing is capable for handling the given problems.
"""

struct ProblemSolverPairingError <: Exception
    prob::Any
    alg::Any
end

function Base.showerror(io::IO, e::ProblemSolverPairingError)
    probtype = SciMLBase.__parameterless_type(typeof(e.prob))
    algtype = SciMLBase.__parameterless_type(typeof(e.alg))
    supported = compatible_problem_types(e.prob, e.alg)
    println(io, PROBSOLVER_PAIRING_MESSAGE)
    println(io, "Problem type: $(probtype)")
    println(io, "Solver type: $(algtype)")
    println(io, "Problem types compatible with the chosen solver: $(supported)")
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
268 | |||
"""
    compatible_problem_types(prob, alg)

Return the problem type(s) that the abstract algorithm family of `alg` can
handle; used in the `ProblemSolverPairingError` message. Returns `nothing` when
`alg` belongs to none of the recognized families. `prob` is not inspected.
"""
function compatible_problem_types(prob, alg)
    alg isa AbstractODEAlgorithm && return ODEProblem
    # StochasticDelayDiffEq.jl just uses the SDE alg, hence SDDEProblem here
    alg isa AbstractSDEAlgorithm && return (SDEProblem, SDDEProblem)
    alg isa AbstractDDEAlgorithm && return DDEProblem
    alg isa AbstractDAEAlgorithm && return DAEProblem
    alg isa AbstractSteadyStateAlgorithm && return SteadyStateProblem
    return nothing
end
282 | |||
# Explanation shown when an AD approach that needs compiler transforms is used
# with a solver that does not support them.
const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """
Incompatible solver + automatic differentiation pairing.
The chosen automatic differentiation algorithm requires the ability
for compiler transforms on the code which is only possible on pure-Julia
solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods
which require this ability include:

- Direct use of ForwardDiff.jl on the solver
- `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint`
sensealg choices for adjoint differentiation.

Either switch the choice of solver to a pure Julia method, or change the automatic
differentiation method to one that does not require such transformations.

For more details on automatic differentiation, adjoint, and sensitivity analysis
of differential equations, see the documentation page:

https://diffeq.sciml.ai/stable/analysis/sensitivity/
"""

struct DirectAutodiffError <: Exception
end

function Base.showerror(io::IO, ::DirectAutodiffError)
    print(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE, "\n",
        TruncatedStacktraces.VERBOSE_MSG, "\n")
end
309 | |||
# Explanation shown by `solve_call` when an `Array` initial condition has a
# non-concrete (unitless) element type; the offending type is appended.
const NONCONCRETE_ELTYPE_MESSAGE = """
Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the ArrayPartition
from RecursiveArrayTools.jl. For example:

```julia
using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!
```

Element type:
"""

struct NonConcreteEltypeError <: Exception
    eltype::Any
end

function Base.showerror(io::IO, e::NonConcreteEltypeError)
    println(io, NONCONCRETE_ELTYPE_MESSAGE, e.eltype, TruncatedStacktraces.VERBOSE_MSG)
end
343 | |||
# Explanation shown by `solve_call` when an `Array` initial condition has an
# element type that is neither a `Number` nor an `Enum`; the offending type is
# appended after "Element type:".
# The paragraph about promoting element types / Unitful.jl (copied from the
# non-concrete eltype message) was removed because it interrupted the example.
# The `VectorOfArray` example now passes a single vector of arrays; the
# two-argument constructor expects `dims` as its second argument.
const NONNUMBER_ELTYPE_MESSAGE = """
Non-Number element type inside of an `Array` detected.
Arrays with non-number element types, such as
`Array{Array{Float64}}`, are not supported by the
solvers.

If you are trying to use an array of arrays structure,
look at the tools in RecursiveArrayTools.jl. For example:

```julia
using RecursiveArrayTools
u0 = ArrayPartition([1.0,2.0],[3.0,4.0])
u0 = VectorOfArray([[1.0,2.0],[3.0,4.0]])
```

are both initial conditions which would be compatible with
the solvers. Or use ComponentArrays.jl for more complex
nested structures.

Element type:
"""

struct NonNumberEltypeError <: Exception
    eltype::Any
end

function Base.showerror(io::IO, e::NonNumberEltypeError)
    print(io, NONNUMBER_ELTYPE_MESSAGE)
    print(io, e.eltype)
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
382 | |||
# Explanation shown when a solver without generic-number support receives a
# non-standard number type for the initial condition or time span.
const GENERIC_NUMBER_TYPE_ERROR_MESSAGE = """
Non-standard number type (i.e. not Float32, Float64,
ComplexF32, or ComplexF64) detected as the element type
for the initial condition or time span. These generic
number types are only compatible with the pure Julia
solvers which support generic programming, such as
OrdinaryDiffEq.jl. The chosen solver does not support
this functionality. Please double check that the initial
condition and time span types are correct, and check that
the chosen solver was correct.
"""

struct GenericNumberTypeError <: Exception
    alg::Any
    uType::Any
    tType::Any
end

function Base.showerror(io::IO, e::GenericNumberTypeError)
    println(io, GENERIC_NUMBER_TYPE_ERROR_MESSAGE)
    for (label, value) in (("Solver", e.alg), ("u0 type", e.uType))
        println(io, "$label: $value")
    end
    print(io, "Timespan type: $(e.tType)")
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
408 | |||
# Explanation shown when a complex initial condition is used with an algorithm
# that does not support complex numbers.
const COMPLEX_SUPPORT_ERROR_MESSAGE = """
Complex number type (i.e. ComplexF32, or ComplexF64)
detected as the element type for the initial condition
with an algorithm that does not support complex numbers.
Please check that the initial condition type is correct.
If complex number support is needed, try different solvers
such as those from OrdinaryDiffEq.jl.
"""

struct ComplexSupportError <: Exception
    alg::Any
end

function Base.showerror(io::IO, e::ComplexSupportError)
    solver_line = "Solver: $(e.alg)"
    print(io, COMPLEX_SUPPORT_ERROR_MESSAGE, "\n", solver_line, "\n")
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
427 | |||
# Explanation shown when the time span has a complex element type.
const COMPLEX_TSPAN_ERROR_MESSAGE = """
Complex number type (i.e. ComplexF32, or ComplexF64)
detected as the element type for the independent variable
(i.e. time span). Please check that the tspan type is correct.
No solvers support complex time spans. If this is required,
please open an issue.
"""

struct ComplexTspanError <: Exception
end

function Base.showerror(io::IO, ::ComplexTspanError)
    print(io, COMPLEX_TSPAN_ERROR_MESSAGE, "\n", TruncatedStacktraces.VERBOSE_MSG, "\n")
end
442 | |||
# Explanation shown when a tuple is used as the state. The "to:" example is
# the StaticArrays equivalent of the tuple example. It negates `u[1]`, matching
# `ftup`'s `b,-a`, and passes `fsa` (not `ftup`) to `ODEProblem`.
const TUPLE_STATE_ERROR_MESSAGE = """
Tuple type used as a state. Since a tuple does not have vector
properties, it will not work as a state type in equation solvers.
Instead, change your equation from using tuple constructors `()`
to static array constructors `SA[]`. For example, change:

```julia
function ftup((a,b),p,t)
    return b,-a
end
u0 = (1.0,2.0)
tspan = (0.0,1.0)
ODEProblem(ftup,u0,tspan)
```

to:

```julia
using StaticArrays
function fsa(u,p,t)
    SA[u[2],-u[1]]
end
u0 = SA[1.0,2.0]
tspan = (0.0,1.0)
ODEProblem(fsa,u0,tspan)
```

This will be safer and fast for small ODEs. For more information, see:
https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Further-Optimizations-of-Small-Non-Stiff-ODEs-with-StaticArrays
"""

struct TupleStateError <: Exception end

function Base.showerror(io::IO, e::TupleStateError)
    println(io, TUPLE_STATE_ERROR_MESSAGE)
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
480 | |||
# Explanation shown when the mass matrix size does not match `length(u0)`.
const MASS_MATRIX_ERROR_MESSAGE = """
Mass matrix size is incompatible with initial condition
sizing. The mass matrix must represent the `vec`
form of the initial condition `u0`, i.e.
`size(mm,1) == size(mm,2) == length(u)`
"""

# `sz` is reported as `size(prob.f.mass_matrix,1)` and `len` as `length(u0)`.
struct IncompatibleMassMatrixError <: Exception
    sz::Int
    len::Int
end

function Base.showerror(io::IO, e::IncompatibleMassMatrixError)
    println(io, MASS_MATRIX_ERROR_MESSAGE)
    print(io, "size(prob.f.mass_matrix,1): ")
    println(io, e.sz)
    print(io, "length(u0): ")
    # Written to `io` (not stdout) so the value appears in the rendered error.
    println(io, e.len)
    println(io, TruncatedStacktraces.VERBOSE_MSG)
end
501 | |||
"""
    init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...)

Merge keyword arguments stored on `_prob` with the call-site ones, validate them
with `checkkwargs`, and dispatch to `__init`.

- A `kwargshandle` stored on the problem takes precedence over the call-site
  value; when neither is given, `KeywordArgError` is used.
- With `merge_callbacks = true`, a `callback` stored on the problem and one given
  at the call site are combined into a `CallbackSet`; for every other key the
  call-site value overrides the problem's.
- `ODEProblem`/`DAEProblem`s with `u0 === nothing` get a null integrator.
- Problems whose `f.f` is an `EvalFunc` are initialized via `Base.invokelatest`.
"""
function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
        kwargs...)
    if has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle)
        kwargshandle = _prob.kwargs[:kwargshandle]
    elseif kwargshandle === nothing
        kwargshandle = KeywordArgError
    end

    if has_kwargs(_prob)
        prob_kwargs = _prob.kwargs
        if merge_callbacks && haskey(prob_kwargs, :callback) && haskey(kwargs, :callback)
            given = values(kwargs)
            rest = NamedTuple{Base.diff_names(Base._nt_names(given), (:callback,))}(given)
            combined = DiffEqBase.CallbackSet(prob_kwargs[:callback], given.callback)
            kwargs = merge(rest, (callback = combined,))
        end
        # Call-site keywords come last so they override the problem's.
        if !isempty(prob_kwargs)
            kwargs = merge(values(prob_kwargs), kwargs)
        end
    end

    checkkwargs(kwargshandle; kwargs...)

    if _prob isa Union{ODEProblem, DAEProblem} && isnothing(_prob.u0)
        return build_null_integrator(_prob, args...; kwargs...)
    end
    if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) &&
       _prob.f.f isa EvalFunc
        return Base.invokelatest(__init, _prob, args...; kwargs...)
    end
    return __init(_prob, args...; kwargs...)
end
531 | |||
# Entry point for differential-equation and nonlinear problems: explicit
# `u0`/`p` override the problem's values, and `sensealg` falls back to the one
# stored in `prob.kwargs`; the rest is handled by `init_up`.
function init(prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing,
        u0 = nothing, p = nothing, kwargs...)
    if isnothing(sensealg) && haskey(prob.kwargs, :sensealg)
        sensealg = prob.kwargs[:sensealg]
    end
    isnothing(u0) && (u0 = prob.u0)
    isnothing(p) && (p = prob.p)
    return init_up(prob, sensealg, u0, p, args...; kwargs...)
end

# Jump problems skip the `u0`/`p`/`sensealg` handling and go straight to `init_call`.
init(prob::AbstractJumpProblem, args...; kwargs...) = init_call(prob, args...; kwargs...)
547 | |||
# Concretize the problem and hand it to `init_call`.
# - No usable algorithm (`nothing` or not an `AbstractDEAlgorithm`): the problem
#   is concretized (adaptive unless it is a `DiscreteProblem`) and `args` are
#   forwarded untouched, leaving algorithm choice to the default handling.
# - Otherwise the algorithm is prepared with `prepare_alg`, the problem/algorithm
#   pairing is checked, and the prepared algorithm replaces the first positional
#   argument.
# NOTE(review): `sensealg` is not used in this method; presumably it exists for
# overloads from AD packages — confirm.
function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...)
    alg = extract_alg(args, kwargs, prob.kwargs)

    if !(alg isa AbstractDEAlgorithm)  # also covers `alg === nothing`
        _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
            p = p, kwargs...)
        return init_call(_prob, args...; kwargs...)
    end

    _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
    _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
    check_prob_alg_pairing(_prob, alg) # alg for improved inference
    trailing = length(args) > 1 ? Base.tail(args) : ()
    return init_call(_prob, _alg, trailing...; kwargs...)
end
565 | |||
"""
    solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...)

Merge keyword arguments stored on `_prob` with the call-site ones, validate them
with `checkkwargs`, sanity-check `u0`, and dispatch to `__solve`.

- A `kwargshandle` stored on the problem takes precedence over the call-site
  value; when neither is given, `KeywordArgError` is used.
- With `merge_callbacks = true`, a `callback` stored on the problem and one given
  at the call site are combined into a `CallbackSet`; for every other key the
  call-site value overrides the problem's.
- An `Array` `u0` must have a concrete unitless element type
  (`NonConcreteEltypeError`) whose element type is a `Number` or an `Enum`
  (`NonNumberEltypeError`).
- A `u0` of `nothing` yields a null solution via `build_null_solution`.
- Problems whose `f.f` is an `EvalFunc` are solved via `Base.invokelatest`.
"""
function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
        kwargs...)
    if has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle)
        kwargshandle = _prob.kwargs[:kwargshandle]
    elseif kwargshandle === nothing
        kwargshandle = KeywordArgError
    end

    if has_kwargs(_prob)
        prob_kwargs = _prob.kwargs
        if merge_callbacks && haskey(prob_kwargs, :callback) && haskey(kwargs, :callback)
            given = values(kwargs)
            rest = NamedTuple{Base.diff_names(Base._nt_names(given), (:callback,))}(given)
            combined = DiffEqBase.CallbackSet(prob_kwargs[:callback], given.callback)
            kwargs = merge(rest, (callback = combined,))
        end
        # Call-site keywords come last so they override the problem's.
        if !isempty(prob_kwargs)
            kwargs = merge(values(prob_kwargs), kwargs)
        end
    end

    checkkwargs(kwargshandle; kwargs...)

    if isdefined(_prob, :u0)
        state0 = _prob.u0
        if state0 isa Array
            unitless_T = RecursiveArrayTools.recursive_unitless_eltype(state0)
            isconcretetype(unitless_T) || throw(NonConcreteEltypeError(unitless_T))
            # Allow Enums for FunctionMaps, make into a trait in the future
            elT = eltype(state0)
            (elT <: Number || elT <: Enum) || throw(NonNumberEltypeError(elT))
        end
        state0 === nothing && return build_null_solution(_prob, args...; kwargs...)
    end

    if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) &&
       _prob.f.f isa EvalFunc
        return Base.invokelatest(__solve, _prob, args...; kwargs...)
    end
    return __solve(_prob, args...; kwargs...)
end
609 | |||
"""
    NullODEIntegrator{IIP, ProbType, T, SolType, F, P}

Integrator returned by `init` (via `init_call`) for an `ODEProblem`/`DAEProblem`
whose `u0` is `nothing`. `build_null_integrator` computes the solution up front;
`du` and `u` are created as empty `Float64` vectors, and only `t` is advanced by
`step!`/`solve!`.
"""
mutable struct NullODEIntegrator{IIP, ProbType, T, SolType, F, P} <:
               AbstractODEIntegrator{Nothing, IIP, Nothing, T}
    du::Vector{Float64}  # created empty: there is no state
    u::Vector{Float64}   # created empty: there is no state
    t::T                 # current time
    prob::ProbType
    sol::SolType         # solution computed in `build_null_integrator`
    f::F
    p::P
end
# Solve the stateless problem immediately and wrap the result in a
# `NullODEIntegrator` positioned at the start of `tspan`.
function build_null_integrator(prob::AbstractDEProblem, args...;
        kwargs...)
    sol = solve(prob, args...; kwargs...)
    integ_type = NullODEIntegrator{
        isinplace(prob), typeof(prob), eltype(prob.tspan), typeof(sol),
        typeof(prob.f), typeof(prob.p)}
    t0 = prob.tspan[1]
    return integ_type(Float64[], Float64[], t0, prob, sol, prob.f, prob.p)
end
# Finish a null integrator: there is no state to evolve, so just move the
# integrator's time to the last time point of its precomputed solution.
function solve!(integ::NullODEIntegrator)
    final_t = integ.sol.t[end]
    integ.t = final_t
    return nothing
end
# Advance a null integrator's time. With `dt`, `integ.t` is incremented by `dt`.
# Without it, the integrator jumps to the last saved time of its precomputed
# solution, matching `solve!`. `stop_at_tdt` is currently unused.
function step!(integ::NullODEIntegrator, dt = nothing, stop_at_tdt = false)
    if isnothing(dt)
        # `integ.sol[end]` would be the last *state*; the time lives in `sol.t`.
        integ.t = integ.sol.t[end]
    else
        integ.t += dt
    end
    return nothing
end
645 | |||
# Build a successful solution for a problem with no state (`u0 === nothing`).
# Only the time points carry information; every saved state is `Float64[]`.
# Time points:
# - `saveat` is a number: `tspan[1]:saveat:tspan[2]`;
# - `saveat` is empty (the default `()`, or e.g. `Float64[]`): the start and/or
#   end of `tspan`, according to `save_start` / `save_end`;
# - otherwise `saveat` is used as given.
# `save_on`, extra positional arguments and other keywords are ignored.
function build_null_solution(prob::AbstractDEProblem, args...;
        saveat = (),
        save_everystep = true,
        save_on = true,
        save_start = save_everystep || isempty(saveat) ||
                     saveat isa Number || prob.tspan[1] in saveat,
        save_end = true,
        kwargs...)
    ts = if saveat isa Number
        prob.tspan[1]:saveat:prob.tspan[2]
    elseif isempty(saveat)
        # Any empty `saveat` means "not given", consistent with the
        # `save_start` default above, which also tests `isempty(saveat)`.
        if save_start && save_end
            [prob.tspan[1], prob.tspan[2]]
        elseif save_start
            [prob.tspan[1]]
        elseif save_end
            [prob.tspan[2]]
        else
            eltype(prob.tspan)[]
        end
    else
        saveat
    end

    timeseries = [Float64[] for _ in 1:length(ts)]

    return build_solution(prob, nothing, ts, timeseries, retcode = ReturnCode.Success)
end
674 | |||
# Null (stateless) solution for steady-state and nonlinear problems. The
# save-related keywords are accepted for signature compatibility with the
# time-dependent method but are not used. `save_start` defaults to `true`
# instead of consulting `prob.tspan`, which these problem types do not have:
# the old default errored when `save_everystep = false` and a non-empty `saveat`
# were passed.
function build_null_solution(prob::Union{SteadyStateProblem, NonlinearProblem}, args...;
        saveat = (),
        save_everystep = true,
        save_on = true,
        save_start = true,
        save_end = true,
        kwargs...)
    return SciMLBase.build_solution(prob, nothing, Float64[], nothing;
        retcode = ReturnCode.Success)
end
686 | |||
687 | """ | ||
688 | ```julia | ||
689 | solve(prob::AbstractDEProblem, alg::Union{AbstractDEAlgorithm,Nothing}; kwargs...) | ||
690 | ``` | ||
691 | |||
692 | ## Arguments | ||
693 | |||
694 | The only positional argument is `alg` which is optional. By default, `alg = nothing`. | ||
695 | If `alg = nothing`, then `solve` dispatches to the DifferentialEquations.jl automated | ||
696 | algorithm selection (if `using DifferentialEquations` was done, otherwise it will | ||
697 | error with a `MethodError`). | ||
698 | |||
699 | ## Keyword Arguments | ||
700 | |||
701 | The DifferentialEquations.jl universe has a large set of common arguments available | ||
702 | for the `solve` function. These arguments apply to `solve` on any problem type and | ||
703 | are only limited by limitations of the specific implementations. | ||
704 | |||
705 | Many of the defaults depend on the algorithm or the package the algorithm derives | ||
706 | from. Not all of the interface is provided by every algorithm. | ||
707 | For more detailed information on the defaults and the available options | ||
708 | for specific algorithms / packages, see the manual pages for the solvers of specific | ||
709 | problems. To see whether a specific package is compatible with the use of a | ||
710 | given option, see the [Solver Compatibility Chart](https://docs.sciml.ai/DiffEqDocs/stable/basics/compatibility_chart/#Solver-Compatibility-Chart) | ||
711 | |||
712 | ### Default Algorithm Hinting | ||
713 | |||
714 | To help choose the default algorithm, the keyword argument `alg_hints` is | ||
715 | provided to `solve`. `alg_hints` is a `Vector{Symbol}` which describe the | ||
716 | problem at a high level to the solver. The options are: | ||
717 | |||
718 | * `:auto` vs `:nonstiff` vs `:stiff` - Denotes the equation as nonstiff/stiff. | ||
719 | `:auto` allow the default handling algorithm to choose stiffness detection | ||
720 | algorithms. The default handling defaults to using `:auto`. | ||
721 | |||
722 | Currently unused options include: | ||
723 | |||
724 | * `:interpolant` - Denotes that a high-precision interpolation is important. | ||
725 | * `:memorybound` - Denotes that the solver will be memory bound. | ||
726 | |||
727 | This functionality is derived via the benchmarks in | ||
728 | [SciMLBenchmarks.jl](https://github.com/SciML/SciMLBenchmarks.jl) | ||
729 | |||
730 | #### SDE Specific Alghints | ||
731 | |||
732 | * `:additive` - Denotes that the underlying SDE has additive noise. | ||
733 | * `:stratonovich` - Denotes that the solution should adhere to the Stratonovich | ||
734 | interpretation. | ||
735 | |||
736 | ### Output Control | ||
737 | |||
738 | These arguments control the output behavior of the solvers. It defaults to maximum | ||
739 | output to give the best interactive user experience, but can be reduced all the | ||
740 | way to only saving the solution at the final timepoint. | ||
741 | |||
742 | The following options are all related to output control. See the "Examples" | ||
743 | section at the end of this page for some example usage. | ||
744 | |||
745 | * `dense`: Denotes whether to save the extra pieces required for dense (continuous) | ||
746 | output. Default is `save_everystep && isempty(saveat)` for algorithms which have | ||
747 | the ability to produce dense output, i.e. by default it's `true` unless the user | ||
748 | has turned off saving on steps or has chosen a `saveat` value. If `dense=false`, | ||
749 | the solution still acts like a function, and `sol(t)` is a linear interpolation | ||
750 | between the saved time points. | ||
751 | * `saveat`: Denotes specific times to save the solution at, during the solving | ||
752 | phase. The solver will save at each of the timepoints in this array in the | ||
753 | most efficient manner available to the solver. If only `saveat` is given, then | ||
754 | the arguments `save_everystep` and `dense` are `false` by default. | ||
755 | If `saveat` is given a number, then it will automatically expand to | ||
756 | `tspan[1]:saveat:tspan[2]`. For methods where interpolation is not possible, | ||
757 | `saveat` may be equivalent to `tstops`. The default value is `[]`. | ||
758 | * `save_idxs`: Denotes the indices for the components of the equation to save. | ||
759 | Defaults to saving all indices. For example, if you are solving a 3-dimensional ODE, | ||
760 | and given `save_idxs = [1, 3]`, only the first and third components of the | ||
761 | solution will be outputted. | ||
762 | In this case, the output solution will be two-dimensional. | ||
763 | * `tstops`: Denotes *extra* times that the timestepping algorithm must step to. | ||
764 | This should be used to help the solver deal with discontinuities and | ||
765 | singularities, since stepping exactly at the time of the discontinuity will | ||
766 | improve accuracy. If a method cannot change timesteps (fixed timestep | ||
767 | multistep methods), then `tstops` will use an interpolation, | ||
768 | matching the behavior of `saveat`. If a method cannot change timesteps and | ||
769 | also cannot interpolate, then `tstops` must be a multiple of `dt` or else an | ||
770 | error will be thrown. Default is `[]`. | ||
771 | * `d_discontinuities`: Denotes locations of discontinuities in low-order derivatives. | ||
772 | This will force FSAL algorithms which assume derivative continuity to re-evaluate | ||
773 | the derivatives at the point of discontinuity. The default is `[]`. | ||
774 | * `save_everystep`: Saves the result at every step. | ||
775 | Default is true if `isempty(saveat)`. | ||
776 | * `save_on`: Denotes whether intermediate solutions are saved. This overrides the | ||
777 | settings of `dense`, `saveat` and `save_everystep` and is used by some applications | ||
778 | to manually turn off saving temporarily. Everyday use of the solvers should leave | ||
779 | this unchanged. Defaults to `true`. | ||
780 | * `save_start`: Denotes whether the initial condition should be included in | ||
781 | the solution type as the first timepoint. Defaults to `true`. | ||
782 | * `save_end`: Denotes whether the final timepoint is forced to be saved, | ||
783 | regardless of the other saving settings. Defaults to `true`. | ||
784 | * `initialize_save`: Denotes whether to save after the callback initialization | ||
785 | phase (when `u_modified=true`). Defaults to `true`. | ||
786 | |||
787 | Note that `dense` requires `save_everystep=true` and an empty `saveat`. If you need | ||
788 | additional saving while keeping dense output, see | ||
789 | [the SavingCallback in the Callback Library](https://docs.sciml.ai/DiffEqCallbacks/stable/output_saving/#DiffEqCallbacks.SavingCallback). | ||
790 | |||
791 | ### Stepsize Control | ||
792 | |||
793 | These arguments control the timestepping routines. | ||
794 | |||
795 | #### Basic Stepsize Control | ||
796 | |||
797 | These are the standard options for controlling stepping behavior. Error estimates | ||
798 | do the comparison | ||
799 | |||
800 | ```math | ||
801 | err_{scaled} = err/(abstol + max(uprev,u)*reltol) | ||
802 | ``` | ||
803 | |||
804 | The scaled error is guaranteed to be `<1` for a given local error estimate | ||
805 | (note: error estimates are local unless the method specifies otherwise). `abstol` | ||
806 | controls the non-scaling error and thus can be thought of as the error around zero. | ||
807 | `reltol` scales with the size of the dependent variables and so one can interpret | ||
808 | `reltol=1e-3` as roughly being (locally) correct to 3 digits. Note tolerances can | ||
809 | be specified element-wise by passing a vector whose size matches `u0`. | ||
810 | |||
811 | * `adaptive`: Turns on adaptive timestepping for appropriate methods. Default | ||
812 | is true. | ||
813 | * `abstol`: Absolute tolerance in adaptive timestepping. This is the tolerance | ||
814 | on local error estimates, not necessarily the global error (though these quantities | ||
815 | are related). Defaults to `1e-6` on deterministic equations (ODEs/DDEs/DAEs) and `1e-2` | ||
816 | on stochastic equations (SDEs/RODEs). | ||
817 | * `reltol`: Relative tolerance in adaptive timestepping. This is the tolerance | ||
818 | on local error estimates, not necessarily the global error (though these quantities | ||
819 | are related). Defaults to `1e-3` on deterministic equations (ODEs/DDEs/DAEs) and `1e-2` | ||
820 | on stochastic equations (SDEs/RODEs). | ||
821 | * `dt`: Sets the initial stepsize. This is also the stepsize for fixed | ||
822 | timestep methods. Defaults to an automatic choice if the method is adaptive. | ||
823 | * `dtmax`: Maximum dt for adaptive timestepping. Defaults are | ||
824 | package-dependent. | ||
825 | * `dtmin`: Minimum dt for adaptive timestepping. Defaults are | ||
826 | package-dependent. | ||
827 | * `force_dtmin`: Declares whether to continue, forcing the minimum `dt` usage. | ||
828 | Default is `false`, which has the solver throw a warning and exit early when | ||
829 | encountering the minimum `dt`. Setting this true allows the solver to continue, | ||
830 | never letting `dt` go below `dtmin` (and ignoring error tolerances in those | ||
831 | cases). Note that `true` is not compatible with most interop packages. | ||
832 | |||
833 | #### Fixed Stepsize Usage | ||
834 | |||
835 | Note that if a method does not have adaptivity, the following rules apply: | ||
836 | |||
837 | * If `dt` is set, then the algorithm will step with size `dt` each iteration. | ||
838 | * If `tstops` and `dt` are both set, then the algorithm will step with either a | ||
839 | size `dt`, or use a smaller step to hit the `tstops` point. | ||
840 | * If `tstops` is set without `dt`, then the algorithm will step directly to | ||
841 | each value in `tstops` | ||
842 | * If neither `dt` nor `tstops` are set, the solver will throw an error. | ||
843 | |||
844 | #### [Advanced Adaptive Stepsize Control](https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/) | ||
845 | |||
846 | These arguments control more advanced parts of the internals of adaptive timestepping | ||
847 | and are mostly used to make it more efficient on specific problems. For detailed | ||
848 | explanations of the timestepping algorithms, see the | ||
849 | [timestepping descriptions](https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/#timestepping) | ||
850 | |||
851 | * `internalnorm`: The norm function `internalnorm(u,t)` with which error estimates | ||
852 | are calculated. Required are two dispatches: one dispatch for the state variable | ||
853 | and the other on the elements of the state variable (scalar norm). | ||
854 | Defaults are package-dependent. | ||
855 | * `controller`: Possible examples are [`IController`](https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/#OrdinaryDiffEq.IController), | ||
856 | [`PIController`](https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/#OrdinaryDiffEq.PIController), | ||
857 | [`PIDController`](https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/#OrdinaryDiffEq.PIDController), | ||
858 | [`PredictiveController`](https://docs.sciml.ai/DiffEqDocs/stable/extras/timestepping/#OrdinaryDiffEq.PredictiveController). | ||
859 | Default is algorithm-dependent. | ||
860 | * `gamma`: The risk-factor γ in the q equation for adaptive timestepping | ||
861 | of the controllers using it. | ||
862 | Default is algorithm-dependent. | ||
863 | * `beta1`: The Lund stabilization α parameter. | ||
864 | Default is algorithm-dependent. | ||
865 | * `beta2`: The Lund stabilization β parameter. | ||
866 | Default is algorithm-dependent. | ||
867 | * `qmax`: Defines the maximum value possible for the adaptive q. | ||
868 | Default is algorithm-dependent. | ||
869 | * `qmin`: Defines the minimum value possible for the adaptive q. | ||
870 | Default is algorithm-dependent. | ||
871 | * `qsteady_min`: Defines the minimum for the range around 1 where the timestep | ||
872 | is held constant. Default is algorithm-dependent. | ||
873 | * `qsteady_max`: Defines the maximum for the range around 1 where the timestep | ||
874 | is held constant. Default is algorithm-dependent. | ||
875 | * `qoldinit`: The initial `qold` in stabilization stepping. | ||
876 | Default is algorithm-dependent. | ||
877 | * `failfactor`: The amount to decrease the timestep by if the Newton iterations | ||
878 | of an implicit method fail. Default is 2. | ||
879 | |||
880 | ### Memory Optimizations | ||
881 | |||
882 | * `calck`: Turns on and off the internal ability for intermediate | ||
883 | interpolations (also known as intermediate density). Not the same as `dense`, which is post-solution interpolation. | ||
884 | This defaults to `dense || !isempty(saveat) || "no custom callback is given"`. | ||
885 | This can be used to turn off interpolations | ||
886 | (to save memory) if one isn't using interpolations when a custom callback is | ||
887 | used. Another case where this may be used is to turn on interpolations for | ||
888 | usage in the integrator interface even when interpolations are used nowhere else. | ||
889 | Note that this is only required if the algorithm doesn't have | ||
890 | a free or lazy interpolation (`DP8()`). If `calck = false`, `saveat` cannot be used. | ||
891 | The rare keyword `calck` can be useful in event handling. | ||
892 | * `alias_u0`: Allows the solver to alias the initial condition array that is contained | ||
893 | in the problem struct. Defaults to false. | ||
894 | |||
895 | ### Miscellaneous | ||
896 | |||
897 | * `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5. | ||
898 | * `callback`: Specifies a callback. Defaults to a callback function which | ||
899 | performs the saving routine. For more information, see the | ||
900 | [Event Handling and Callback Functions manual page](https://docs.sciml.ai/DiffEqCallbacks/stable/). | ||
901 | * `isoutofdomain`: Specifies a function `isoutofdomain(u,p,t)` where, when it | ||
902 | returns true, it will reject the timestep. Disabled by default. | ||
903 | * `unstable_check`: Specifies a function `unstable_check(dt,u,p,t)` where, when | ||
904 | it returns true, it will cause the solver to exit and throw a warning. Defaults | ||
905 | to `any(isnan,u)`, i.e. checking if any value is a NaN. | ||
906 | * `verbose`: Toggles whether warnings are thrown when the solver exits early. | ||
907 | Defaults to true. | ||
908 | * `merge_callbacks`: Toggles whether to merge `prob.callback` with the `solve` keyword | ||
909 | argument `callback`. Defaults to `true`. | ||
910 | * `wrap`: Toggles whether to wrap the solution if `prob.problem_type` has a preferred | ||
911 | alternate wrapper type for the solution. Disabling it is useful when speed, but not the shape of the solution, | ||
912 | is important. Defaults to `Val(true)`; `Val(false)` disables wrapping the solution. | ||
913 | |||
914 | ### Progress Monitoring | ||
915 | |||
916 | These arguments control the usage of the progressbar in ProgressLogging.jl compatible environments. | ||
917 | |||
918 | * `progress`: Turns on/off the progress bar. Default is false. | ||
919 | * `progress_steps`: Numbers of steps between updates of the progress bar. | ||
920 | Default is 1000. | ||
921 | * `progress_name`: Controls the name of the progressbar. Default is the name | ||
922 | of the problem type. | ||
923 | * `progress_message`: Controls the message with the progressbar. Defaults to | ||
924 | showing `dt`, `t`, and the maximum of `u`. | ||
925 | * `progress_id`: Controls the ID of the progress log message to distinguish simultaneous simulations. | ||
926 | |||
927 | ### Error Calculations | ||
928 | |||
929 | If you are using the test problems (ex: `ODETestProblem`), then the following | ||
930 | options control the errors which are calculated: | ||
931 | |||
932 | * `timeseries_errors`: Turns on and off the calculation of errors at the steps | ||
933 | which were taken, such as the `l2` error. Default is true. | ||
934 | * `dense_errors`: Turns on and off the calculation of errors at the steps which | ||
935 | require dense output and calculate the error at 100 evenly-spaced points | ||
936 | throughout `tspan`. An example is the `L2` error. Default is false. | ||
937 | |||
938 | ### Sensitivity Algorithms (`sensealg`) | ||
939 | |||
940 | `sensealg` is used for choosing the way the automatic differentiation is performed. | ||
941 | For more information, see the documentation for SciMLSensitivity: | ||
942 | https://docs.sciml.ai/SciMLSensitivity/stable/ | ||
943 | |||
944 | ## Examples | ||
945 | |||
946 | The following lines are examples of how one could use the configuration of | ||
947 | `solve()`. For these examples a 3-dimensional ODE problem is assumed, however | ||
948 | the extension to other types is straightforward. | ||
949 | |||
950 | 1. `solve(prob, AlgorithmName())` : The "default" setting, with a user-specified | ||
951 | algorithm (given by `AlgorithmName()`). All parameters get their default values. | ||
952 | This means that the solution is saved at the steps the algorithm stops at internally | ||
953 | and dense output is enabled if the chosen algorithm allows for it. | ||
954 | |||
955 | All other integration parameters (e.g. stepsize) are chosen automatically. | ||
956 | 2. `solve(prob, saveat = 0.01, abstol = 1e-9, reltol = 1e-9)` : Standard setting | ||
957 | for accurate output at specified (and equidistant) time intervals, used for | ||
958 | e.g. Fourier Transform. The solution is given every 0.01 time units, | ||
959 | starting from `tspan[1]`. The solver used is `Tsit5()` since no keyword | ||
960 | `alg_hints` is given. | ||
961 | |||
962 | 3. `solve(prob, maxiters = 1e7, progress = true, save_idxs = [1])` : Using longer | ||
963 | maximum number of solver iterations can be useful when a given `tspan` is very | ||
964 | long. This example only saves the first of the variables of the system, either | ||
965 | to save size or because the user does not care about the others. Finally, with | ||
966 | `progress = true` you are enabling the progress bar. | ||
967 | """ | ||
968 | function solve(prob::AbstractDEProblem, args...; sensealg = nothing, | ||
969 | u0 = nothing, p = nothing, wrap = Val(true), kwargs...) | ||
970 | if sensealg === nothing && haskey(prob.kwargs, :sensealg) | ||
971 | sensealg = prob.kwargs[:sensealg] | ||
972 | end | ||
973 | |||
974 | u0 = u0 !== nothing ? u0 : prob.u0 | ||
975 | p = p !== nothing ? p : prob.p | ||
976 | |||
977 | if wrap isa Val{true} | ||
978 | wrap_sol(solve_up(prob, sensealg, u0, p, args...; kwargs...)) | ||
979 | else | ||
980 | solve_up(prob, sensealg, u0, p, args...; kwargs...) | ||
981 | end | ||
982 | end | ||
983 | |||
984 | """ | ||
985 | ```julia | ||
986 | solve(prob::NonlinearProblem, alg::Union{AbstractNonlinearAlgorithm,Nothing}; kwargs...) | ||
987 | ``` | ||
988 | |||
989 | ## Arguments | ||
990 | |||
991 | The only positional argument is `alg` which is optional. By default, `alg = nothing`. | ||
992 | If `alg = nothing`, then `solve` dispatches to the NonlinearSolve.jl automated | ||
993 | algorithm selection (if `using NonlinearSolve` was done, otherwise it will | ||
994 | error with a `MethodError`). | ||
995 | |||
996 | ## Keyword Arguments | ||
997 | |||
998 | The NonlinearSolve.jl universe has a large set of common arguments available | ||
999 | for the `solve` function. These arguments apply to `solve` on any problem type and | ||
1000 | are limited only by the capabilities of the specific implementations. | ||
1001 | |||
1002 | Many of the defaults depend on the algorithm or the package the algorithm derives | ||
1003 | from. Not all of the interface is provided by every algorithm. | ||
1004 | For more detailed information on the defaults and the available options | ||
1005 | for specific algorithms / packages, see the manual pages for the solvers of specific | ||
1006 | problems. | ||
1007 | |||
1008 | ### Error Control | ||
1009 | |||
1010 | * `abstol`: Absolute tolerance. | ||
1011 | * `reltol`: Relative tolerance. | ||
1012 | |||
1013 | ### Miscellaneous | ||
1014 | |||
1015 | * `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5. | ||
1016 | * `verbose`: Toggles whether warnings are thrown when the solver exits early. | ||
1017 | Defaults to true. | ||
1018 | |||
1019 | ### Sensitivity Algorithms (`sensealg`) | ||
1020 | |||
1021 | `sensealg` is used for choosing the way the automatic differentiation is performed. | ||
1022 | For more information, see the documentation for SciMLSensitivity: | ||
1023 | https://docs.sciml.ai/SciMLSensitivity/stable/ | ||
1024 | """ | ||
1025 | 284 (99 %) |
568 (197 %)
samples spent in solve
284 (50 %) (incl.) when called from solve line 1025 284 (50 %) (incl.) when called from eval line 385
284 (100 %)
samples spent calling
#solve#41
function solve(prob::NonlinearProblem, args...; sensealg = nothing,
|
|
1026 | u0 = nothing, p = nothing, wrap = Val(true), kwargs...) | ||
1027 | if sensealg === nothing && haskey(prob.kwargs, :sensealg) | ||
1028 | sensealg = prob.kwargs[:sensealg] | ||
1029 | end | ||
1030 | |||
1031 | u0 = u0 !== nothing ? u0 : prob.u0 | ||
1032 | p = p !== nothing ? p : prob.p | ||
1033 | |||
1034 | if wrap isa Val{true} | ||
1035 | 284 (99 %) |
284 (100 %)
samples spent calling
solve_up
wrap_sol(solve_up(prob, sensealg, u0, p, args...; kwargs...))
|
|
1036 | else | ||
1037 | solve_up(prob, sensealg, u0, p, args...; kwargs...) | ||
1038 | end | ||
1039 | end | ||
1040 | |||
1041 | 284 (99 %) |
284 (100 %)
samples spent calling
#solve_up#42
function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
|
|
1042 | args...; kwargs...) | ||
1043 | alg = extract_alg(args, kwargs, prob.kwargs) | ||
1044 | if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling | ||
1045 | _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, | ||
1046 | p = p, kwargs...) | ||
1047 | 284 (99 %) |
284 (100 %)
samples spent calling
solve_call
solve_call(_prob, args...; kwargs...)
|
|
1048 | else | ||
1049 | _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) | ||
1050 | _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) | ||
1051 | check_prob_alg_pairing(_prob, alg) # use alg for improved inference | ||
1052 | if length(args) > 1 | ||
1053 | solve_call(_prob, _alg, Base.tail(args)...; kwargs...) | ||
1054 | else | ||
1055 | solve_call(_prob, _alg; kwargs...) | ||
1056 | end | ||
1057 | end | ||
1058 | end | ||
1059 | |||
1060 | function solve_call(prob::SteadyStateProblem, | ||
1061 | alg::SciMLBase.AbstractNonlinearAlgorithm, args...; | ||
1062 | kwargs...) | ||
1063 | solve_call(NonlinearProblem(prob), | ||
1064 | alg, args...; | ||
1065 | kwargs...) | ||
1066 | end | ||
1067 | |||
1068 | function solve(prob::EnsembleProblem, args...; kwargs...) | ||
1069 | alg = extract_alg(args, kwargs, kwargs) | ||
1070 | if length(args) > 1 | ||
1071 | __solve(prob, alg, Base.tail(args)...; kwargs...) | ||
1072 | else | ||
1073 | __solve(prob, alg; kwargs...) | ||
1074 | end | ||
1075 | end | ||
1076 | function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) | ||
1077 | SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) | ||
1078 | end | ||
1079 | function solve(prob::AbstractNoiseProblem, args...; kwargs...) | ||
1080 | __solve(prob, args...; kwargs...) | ||
1081 | end | ||
1082 | |||
1083 | function solve(prob::AbstractJumpProblem, args...; kwargs...) | ||
1084 | __solve(prob, args...; kwargs...) | ||
1085 | end | ||
1086 | |||
1087 | function checkkwargs(kwargshandle; kwargs...) | ||
1088 | if any(x -> x ∉ allowedkeywords, keys(kwargs)) | ||
1089 | if kwargshandle == KeywordArgError | ||
1090 | throw(CommonKwargError(kwargs)) | ||
1091 | elseif kwargshandle == KeywordArgWarn | ||
1092 | @warn KWARGWARN_MESSAGE | ||
1093 | unrecognized = setdiff(keys(kwargs), allowedkeywords) | ||
1094 | print("Unrecognized keyword arguments: ") | ||
1095 | printstyled(unrecognized; bold = true, color = :red) | ||
1096 | print("\n\n") | ||
1097 | else | ||
1098 | @assert kwargshandle == KeywordArgSilent | ||
1099 | end | ||
1100 | end | ||
1101 | end | ||
1102 | |||
1103 | function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) | ||
1104 | prob | ||
1105 | end | ||
1106 | |||
1107 | function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...) | ||
1108 | u0 = get_concrete_u0(prob, isadapt, Inf, kwargs) | ||
1109 | u0 = promote_u0(u0, prob.p, nothing) | ||
1110 | p = get_concrete_p(prob, kwargs) | ||
1111 | remake(prob; u0 = u0, p = p) | ||
1112 | end | ||
1113 | |||
1114 | function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) | ||
1115 | u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) | ||
1116 | u0 = promote_u0(u0, prob.p, nothing) | ||
1117 | p = get_concrete_p(prob, kwargs) | ||
1118 | remake(prob; u0 = u0, p = p) | ||
1119 | end | ||
1120 | |||
1121 | function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) | ||
1122 | u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) | ||
1123 | u0 = promote_u0(u0, prob.p, nothing) | ||
1124 | p = get_concrete_p(prob, kwargs) | ||
1125 | remake(prob; u0 = u0, p = p) | ||
1126 | end | ||
1127 | |||
1128 | function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...) | ||
1129 | prob | ||
1130 | end | ||
1131 | |||
1132 | function solve(prob::PDEProblem, alg::AbstractDEAlgorithm, args...; | ||
1133 | kwargs...) | ||
1134 | solve(prob.prob, alg, args...; kwargs...) | ||
1135 | end | ||
1136 | |||
1137 | function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...; | ||
1138 | kwargs...) | ||
1139 | init(prob.prob, alg, args...; kwargs...) | ||
1140 | end | ||
1141 | |||
1142 | function get_concrete_problem(prob, isadapt; kwargs...) | ||
1143 | p = get_concrete_p(prob, kwargs) | ||
1144 | tspan = get_concrete_tspan(prob, isadapt, kwargs, p) | ||
1145 | u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs) | ||
1146 | u0_promote = promote_u0(u0, p, tspan[1]) | ||
1147 | tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs) | ||
1148 | f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, | ||
1149 | tspan_promote[1]) | ||
1150 | if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && | ||
1151 | prob.tspan == tspan && typeof(prob.tspan) === typeof(tspan_promote) && | ||
1152 | p === prob.p && f_promote === prob.f | ||
1153 | return prob | ||
1154 | else | ||
1155 | return remake(prob; f = f_promote, u0 = u0_promote, p = p, tspan = tspan_promote) | ||
1156 | end | ||
1157 | end | ||
1158 | |||
1159 | function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...) | ||
1160 | p = get_concrete_p(prob, kwargs) | ||
1161 | tspan = get_concrete_tspan(prob, isadapt, kwargs, p) | ||
1162 | u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs) | ||
1163 | du0 = get_concrete_du0(prob, isadapt, tspan[1], kwargs) | ||
1164 | |||
1165 | u0_promote = promote_u0(u0, p, tspan[1]) | ||
1166 | du0_promote = promote_u0(du0, p, tspan[1]) | ||
1167 | tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs) | ||
1168 | |||
1169 | f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, | ||
1170 | tspan_promote[1]) | ||
1171 | if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && | ||
1172 | isconcretedu0(prob, tspan[1], kwargs) && typeof(du0_promote) === typeof(prob.du0) && | ||
1173 | prob.tspan == tspan && typeof(prob.tspan) === typeof(tspan_promote) && | ||
1174 | p === prob.p && f_promote === prob.f | ||
1175 | return prob | ||
1176 | else | ||
1177 | return remake(prob; f = f_promote, du0 = du0_promote, u0 = u0_promote, p = p, | ||
1178 | tspan = tspan_promote) | ||
1179 | end | ||
1180 | end | ||
1181 | |||
1182 | function get_concrete_problem(prob::DDEProblem, isadapt; kwargs...) | ||
1183 | p = get_concrete_p(prob, kwargs) | ||
1184 | tspan = get_concrete_tspan(prob, isadapt, kwargs, p) | ||
1185 | u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs) | ||
1186 | |||
1187 | if prob.constant_lags isa Function | ||
1188 | constant_lags = prob.constant_lags(p) | ||
1189 | else | ||
1190 | constant_lags = prob.constant_lags | ||
1191 | end | ||
1192 | |||
1193 | u0 = promote_u0(u0, p, tspan[1]) | ||
1194 | tspan = promote_tspan(u0, p, tspan, prob, kwargs) | ||
1195 | |||
1196 | remake(prob; u0 = u0, tspan = tspan, p = p, constant_lags = constant_lags) | ||
1197 | end | ||
1198 | |||
1199 | # Most are extensions | ||
1200 | promote_tspan(u0, p, tspan, prob, kwargs) = _promote_tspan(tspan, kwargs) | ||
1201 | function _promote_tspan(tspan, kwargs) | ||
1202 | if (dt = get(kwargs, :dt, nothing)) !== nothing | ||
1203 | tspan1, tspan2, _ = promote(tspan..., dt) | ||
1204 | return (tspan1, tspan2) | ||
1205 | else | ||
1206 | return tspan | ||
1207 | end | ||
1208 | end | ||
1209 | |||
1210 | function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize} | ||
1211 | # Ensure our jacobian will be of the same type as u0 | ||
1212 | uElType = u0 === nothing ? Float64 : eltype(u0) | ||
1213 | if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray | ||
1214 | f = @set f.jac_prototype = similar(f.jac_prototype, uElType) | ||
1215 | end | ||
1216 | |||
1217 | @static if VERSION >= v"1.8-" | ||
1218 | f = if f isa ODEFunction && isinplace(f) && !(f.f isa AbstractSciMLOperator) && | ||
1219 | # Some reinitialization code still uses NLSolvers stuff which doesn't | ||
1220 | # properly tag, so opt-out if potentially a mass matrix DAE | ||
1221 | f.mass_matrix isa UniformScaling && | ||
1222 | # Jacobians don't wrap, so just ignore those cases | ||
1223 | f.jac === nothing && | ||
1224 | ((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any && | ||
1225 | RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) && | ||
1226 | one(t) === oneunit(t) && | ||
1227 | Tricks.static_hasmethod(ArrayInterface.promote_eltype, | ||
1228 | Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && | ||
1229 | Tricks.static_hasmethod(promote_rule, | ||
1230 | Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && | ||
1231 | Tricks.static_hasmethod(promote_rule, | ||
1232 | Tuple{Type{eltype(u0)}, Type{typeof(t)}})) || | ||
1233 | (specialize === SciMLBase.FunctionWrapperSpecialize && | ||
1234 | !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) | ||
1235 | return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t))) | ||
1236 | else | ||
1237 | return f | ||
1238 | end | ||
1239 | else | ||
1240 | return f | ||
1241 | end | ||
1242 | end | ||
1243 | |||
1244 | function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t) where {specialize} | ||
1245 | typeof(f.cache) === typeof(u0) && isinplace(f) ? f : remake(f, cache = zero(u0)) | ||
1246 | end | ||
1247 | prepare_alg(alg, u0, p, f) = alg | ||
1248 | |||
1249 | function get_concrete_tspan(prob, isadapt, kwargs, p) | ||
1250 | if prob.tspan isa Function | ||
1251 | tspan = prob.tspan(p) | ||
1252 | elseif haskey(kwargs, :tspan) | ||
1253 | tspan = kwargs[:tspan] | ||
1254 | elseif prob.tspan === (nothing, nothing) | ||
1255 | throw(NoTspanError()) | ||
1256 | else | ||
1257 | tspan = prob.tspan | ||
1258 | end | ||
1259 | |||
1260 | isadapt && eltype(tspan) <: Integer && (tspan = float.(tspan)) | ||
1261 | |||
1262 | any(isnan, tspan) && throw(NaNTspanError()) | ||
1263 | |||
1264 | tspan | ||
1265 | end | ||
1266 | |||
1267 | function isconcreteu0(prob, t0, kwargs) | ||
1268 | !eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0) | ||
1269 | end | ||
1270 | |||
1271 | function isconcretedu0(prob, t0, kwargs) | ||
1272 | !eval_u0(prob.u0) && prob.du0 !== nothing && !isdistribution(prob.du0) | ||
1273 | end | ||
1274 | |||
1275 | function get_concrete_u0(prob, isadapt, t0, kwargs) | ||
1276 | if eval_u0(prob.u0) | ||
1277 | u0 = prob.u0(prob.p, t0) | ||
1278 | elseif haskey(kwargs, :u0) | ||
1279 | u0 = kwargs[:u0] | ||
1280 | else | ||
1281 | u0 = prob.u0 | ||
1282 | end | ||
1283 | |||
1284 | isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) | ||
1285 | |||
1286 | _u0 = handle_distribution_u0(u0) | ||
1287 | |||
1288 | if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) | ||
1289 | throw(IncompatibleInitialConditionError()) | ||
1290 | end | ||
1291 | |||
1292 | nu0 = length(something(_u0, ())) | ||
1293 | if isdefined(prob.f, :mass_matrix) && prob.f.mass_matrix !== nothing && | ||
1294 | prob.f.mass_matrix isa AbstractArray && | ||
1295 | size(prob.f.mass_matrix, 1) !== nu0 | ||
1296 | throw(IncompatibleMassMatrixError(size(prob.f.mass_matrix, 1), nu0)) | ||
1297 | end | ||
1298 | |||
1299 | if _u0 isa Tuple | ||
1300 | throw(TupleStateError()) | ||
1301 | end | ||
1302 | |||
1303 | _u0 | ||
1304 | end | ||
1305 | |||
1306 | function get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) | ||
1307 | if haskey(kwargs, :u0) | ||
1308 | u0 = kwargs[:u0] | ||
1309 | else | ||
1310 | u0 = prob.u0 | ||
1311 | end | ||
1312 | |||
1313 | isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) | ||
1314 | |||
1315 | _u0 = handle_distribution_u0(u0) | ||
1316 | |||
1317 | if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) | ||
1318 | throw(IncompatibleInitialConditionError()) | ||
1319 | end | ||
1320 | |||
1321 | if _u0 isa Tuple | ||
1322 | throw(TupleStateError()) | ||
1323 | end | ||
1324 | |||
1325 | return _u0 | ||
1326 | end | ||
1327 | |||
1328 | function get_concrete_du0(prob, isadapt, t0, kwargs) | ||
1329 | if eval_u0(prob.du0) | ||
1330 | du0 = prob.du0(prob.p, t0) | ||
1331 | elseif haskey(kwargs, :du0) | ||
1332 | du0 = kwargs[:du0] | ||
1333 | else | ||
1334 | du0 = prob.du0 | ||
1335 | end | ||
1336 | |||
1337 | isadapt && eltype(du0) <: Integer && (du0 = float.(du0)) | ||
1338 | |||
1339 | _du0 = handle_distribution_u0(du0) | ||
1340 | |||
1341 | if isinplace(prob) && (_du0 isa Number || _du0 isa SArray) | ||
1342 | throw(IncompatibleInitialConditionError()) | ||
1343 | end | ||
1344 | |||
1345 | _du0 | ||
1346 | end | ||
1347 | |||
1348 | function get_concrete_p(prob, kwargs) | ||
1349 | if haskey(kwargs, :p) | ||
1350 | p = kwargs[:p] | ||
1351 | else | ||
1352 | p = prob.p | ||
1353 | end | ||
1354 | end | ||
1355 | |||
1356 | handle_distribution_u0(_u0) = _u0 | ||
1357 | |||
1358 | eval_u0(u0::Function) = true | ||
1359 | eval_u0(u0) = false | ||
1360 | |||
1361 | function __solve(prob::AbstractDEProblem, args...; default_set = false, second_time = false, | ||
1362 | kwargs...) | ||
1363 | if second_time | ||
1364 | throw(NoDefaultAlgorithmError()) | ||
1365 | elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm}) | ||
1366 | throw(NonSolverError()) | ||
1367 | else | ||
1368 | __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) | ||
1369 | end | ||
1370 | end | ||
1371 | |||
1372 | function __init(prob::AbstractDEProblem, args...; default_set = false, second_time = false, | ||
1373 | kwargs...) | ||
1374 | if second_time | ||
1375 | throw(NoDefaultAlgorithmError()) | ||
1376 | elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm}) | ||
1377 | throw(NonSolverError()) | ||
1378 | else | ||
1379 | __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) | ||
1380 | end | ||
1381 | end | ||
1382 | |||
1383 | function check_prob_alg_pairing(prob, alg) | ||
1384 | if prob isa ODEProblem && !(alg isa AbstractODEAlgorithm) || | ||
1385 | prob isa SDEProblem && !(alg isa AbstractSDEAlgorithm) || | ||
1386 | prob isa SDDEProblem && !(alg isa AbstractSDEAlgorithm) || | ||
1387 | prob isa DDEProblem && !(alg isa AbstractDDEAlgorithm) || | ||
1388 | prob isa DAEProblem && !(alg isa AbstractDAEAlgorithm) || | ||
1389 | prob isa SteadyStateProblem && !(alg isa AbstractSteadyStateAlgorithm) | ||
1390 | throw(ProblemSolverPairingError(prob, alg)) | ||
1391 | end | ||
1392 | |||
1393 | if isdefined(prob, :u0) && eltype(prob.u0) <: ForwardDiff.Dual && | ||
1394 | !SciMLBase.isautodifferentiable(alg) | ||
1395 | throw(DirectAutodiffError()) | ||
1396 | end | ||
1397 | |||
1398 | if prob isa SDEProblem && prob.noise_rate_prototype !== nothing && | ||
1399 | prob.noise !== nothing && | ||
1400 | size(prob.noise_rate_prototype, 2) != length(prob.noise.W[1]) | ||
1401 | throw(NoiseSizeIncompatabilityError(size(prob.noise_rate_prototype, 2), | ||
1402 | length(prob.noise.W[1]))) | ||
1403 | end | ||
1404 | |||
1405 | # Complex number support comes before arbitrary number support for a more direct | ||
1406 | # error message. | ||
1407 | if !SciMLBase.allowscomplex(alg) | ||
1408 | if isdefined(prob, :u0) && | ||
1409 | RecursiveArrayTools.recursive_unitless_eltype(prob.u0) <: Complex | ||
1410 | throw(ComplexSupportError(alg)) | ||
1411 | end | ||
1412 | end | ||
1413 | |||
1414 | if isdefined(prob, :tspan) && eltype(prob.tspan) <: Complex | ||
1415 | throw(ComplexTspanError()) | ||
1416 | end | ||
1417 | |||
1418 | # Check for concrete element type so that the non-concrete case throws a better error | ||
1419 | if !SciMLBase.allows_arbitrary_number_types(alg) | ||
1420 | if isdefined(prob, :u0) | ||
1421 | uType = RecursiveArrayTools.recursive_unitless_eltype(prob.u0) | ||
1422 | if Base.isconcretetype(uType) && | ||
1423 | !(uType <: Union{Float32, Float64, ComplexF32, ComplexF64}) | ||
1424 | throw(GenericNumberTypeError(alg, | ||
1425 | isdefined(prob, :u0) ? typeof(prob.u0) : | ||
1426 | nothing, | ||
1427 | isdefined(prob, :tspan) ? typeof(prob.tspan) : | ||
1428 | nothing)) | ||
1429 | end | ||
1430 | end | ||
1431 | |||
1432 | if isdefined(prob, :tspan) | ||
1433 | tType = eltype(prob.tspan) | ||
1434 | if Base.isconcretetype(tType) && | ||
1435 | !(tType <: Union{Float32, Float64, ComplexF32, ComplexF64}) | ||
1436 | throw(GenericNumberTypeError(alg, | ||
1437 | isdefined(prob, :u0) ? typeof(prob.u0) : | ||
1438 | nothing, | ||
1439 | isdefined(prob, :tspan) ? typeof(prob.tspan) : | ||
1440 | nothing)) | ||
1441 | end | ||
1442 | end | ||
1443 | end | ||
1444 | end | ||
1445 | |||
1446 | @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) | ||
1447 | if isempty(solve_args) || isnothing(first(solve_args)) | ||
1448 | if haskey(solve_kwargs, :alg) | ||
1449 | solve_kwargs[:alg] | ||
1450 | elseif haskey(prob_kwargs, :alg) | ||
1451 | prob_kwargs[:alg] | ||
1452 | else | ||
1453 | nothing | ||
1454 | end | ||
1455 | elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && | ||
1456 | !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) | ||
1457 | first(solve_args) | ||
1458 | else | ||
1459 | nothing | ||
1460 | end | ||
1461 | end | ||
1462 | |||
1463 | ################### Differentiation | ||
1464 | |||
1465 | """ | ||
1466 | Ignores all adjoint definitions (i.e. `sensealg`) and proceeds to do standard | ||
1467 | AD through the `solve` functions. Generally only used internally for implementing | ||
1468 | discrete sensitivity algorithms. | ||
1469 | """ | ||
1470 | struct SensitivityADPassThrough <: AbstractDEAlgorithm end | ||
1471 | |||
1472 | ### | ||
1473 | ### Legacy Dispatches to be Non-Breaking | ||
1474 | ### | ||
1475 | |||
1476 | @deprecate concrete_solve(prob::AbstractDEProblem, | ||
1477 | alg::Union{AbstractDEAlgorithm, Nothing}, | ||
1478 | u0 = prob.u0, p = prob.p, args...; kwargs...) solve(prob, alg, | ||
1479 | args...; | ||
1480 | u0 = u0, | ||
1481 | p = p, | ||
1482 | kwargs...) | ||
1483 | |||
1484 | function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, | ||
1485 | kwargs...) | ||
1486 | alg = extract_alg(args, kwargs, prob.kwargs) | ||
1487 | if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling | ||
1488 | _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, | ||
1489 | p = p, kwargs...) | ||
1490 | else | ||
1491 | _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) | ||
1492 | end | ||
1493 | |||
1494 | if has_kwargs(_prob) | ||
1495 | if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) | ||
1496 | kwargs_temp = NamedTuple{ | ||
1497 | Base.diff_names(Base._nt_names(values(kwargs)), | ||
1498 | (:callback,))}(values(kwargs)) | ||
1499 | callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(_prob.kwargs[:callback], | ||
1500 | values(kwargs).callback),)) | ||
1501 | kwargs = merge(kwargs_temp, callbacks) | ||
1502 | end | ||
1503 | kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) | ||
1504 | end | ||
1505 | |||
1506 | if length(args) > 1 | ||
1507 | _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator, | ||
1508 | Base.tail(args)...; kwargs...) | ||
1509 | else | ||
1510 | _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...) | ||
1511 | end | ||
1512 | end | ||
1513 | |||
1514 | function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, | ||
1515 | kwargs...) | ||
1516 | alg = extract_alg(args, kwargs, prob.kwargs) | ||
1517 | if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling | ||
1518 | _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, | ||
1519 | p = p, kwargs...) | ||
1520 | else | ||
1521 | _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) | ||
1522 | end | ||
1523 | |||
1524 | if has_kwargs(_prob) | ||
1525 | if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) | ||
1526 | kwargs_temp = NamedTuple{ | ||
1527 | Base.diff_names(Base._nt_names(values(kwargs)), | ||
1528 | (:callback,))}(values(kwargs)) | ||
1529 | callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(_prob.kwargs[:callback], | ||
1530 | values(kwargs).callback),)) | ||
1531 | kwargs = merge(kwargs_temp, callbacks) | ||
1532 | end | ||
1533 | kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) | ||
1534 | end | ||
1535 | |||
1536 | if length(args) > 1 | ||
1537 | _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator, | ||
1538 | Base.tail(args)...; kwargs...) | ||
1539 | else | ||
1540 | _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) | ||
1541 | end | ||
1542 | end | ||
1543 | |||
1544 | #### | ||
1545 | # Catch undefined AD overload cases | ||
1546 | |||
1547 | const ADJOINT_NOT_FOUND_MESSAGE = """ | ||
1548 | Compatibility with reverse-mode automatic differentiation requires SciMLSensitivity.jl. | ||
1549 | Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` | ||
1550 | for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. | ||
1551 | """ | ||
1552 | |||
1553 | struct AdjointNotFoundError <: Exception end | ||
1554 | |||
1555 | function Base.showerror(io::IO, e::AdjointNotFoundError) | ||
1556 | print(io, ADJOINT_NOT_FOUND_MESSAGE) | ||
1557 | println(io, TruncatedStacktraces.VERBOSE_MSG) | ||
1558 | end | ||
1559 | |||
1560 | function _concrete_solve_adjoint(args...; kwargs...) | ||
1561 | throw(AdjointNotFoundError()) | ||
1562 | end | ||
1563 | |||
1564 | const FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE = """ | ||
1565 | Compatibility with forward-mode automatic differentiation requires SciMLSensitivity.jl. | ||
1566 | Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` | ||
1567 | for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. | ||
1568 | """ | ||
1569 | |||
1570 | struct ForwardSensitivityNotFoundError <: Exception end | ||
1571 | |||
1572 | function Base.showerror(io::IO, e::ForwardSensitivityNotFoundError) | ||
1573 | print(io, FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE) | ||
1574 | println(io, TruncatedStacktraces.VERBOSE_MSG) | ||
1575 | end | ||
1576 | |||
1577 | function _concrete_solve_forward(args...; kwargs...) | ||
1578 | throw(ForwardSensitivityNotFoundError()) | ||
1579 | end |