From 37c563a39f5524b86cafb022810e0547ee426f56 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 7 Oct 2023 03:41:50 +0330 Subject: [PATCH 01/12] simplify `:alg` extraction --- src/solve.jl | 111 ++++++++++++++++++++------------------------------- 1 file changed, 43 insertions(+), 68 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index b530ba38f..f78793c83 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -501,36 +501,19 @@ function init(prob::AbstractJumpProblem, args...; kwargs...) end function init_up(prob::DEProblem, sensealg, u0, p, args...; kwargs...) - if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing) - alg = kwargs[:alg] + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) # Default algorithm handling + _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); kwargs...) + init_call(_prob, args...; kwargs...) + else _prob = get_concrete_problem(prob, isadaptive(alg); kwargs...) _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) check_prob_alg_pairing(_prob, alg) # alg for improved inference - - if length(args) <= 1 - init_call(_prob, _alg; kwargs...) - else + if length(args) > 1 init_call(_prob, _alg, Base.tail(args)...; kwargs...) - end - elseif haskey(prob.kwargs, :alg) && (isempty(args) || args[1] === nothing) - alg = prob.kwargs[:alg] - _prob = get_concrete_problem(prob, isadaptive(alg); kwargs...) - _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) - check_prob_alg_pairing(_prob, alg) # alg for improved inference - if length(args) <= 1 - init_call(_prob, _alg; kwargs...) else - init_call(_prob, _alg, Base.tail(args)...; kwargs...) + init_call(_prob, _alg; kwargs...) end - elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm - alg = args[1] - _prob = get_concrete_problem(prob, isadaptive(alg); kwargs...) - check_prob_alg_pairing(_prob, alg) - _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) - init_call(_prob, _alg, Base.tail(args)...; kwargs...) - else # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); kwargs...) - init_call(_prob, args...; kwargs...) end end @@ -1005,36 +988,20 @@ end function solve_up(prob::Union{DEProblem, NonlinearProblem}, sensealg, u0, p, args...; kwargs...) - if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing) - alg = kwargs[:alg] + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) # Default algorithm handling + _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + p = p, kwargs...) + solve_call(_prob, args...; kwargs...) + else _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) check_prob_alg_pairing(_prob, alg) # use alg for improved inference - if length(args) <= 1 - solve_call(_prob, _alg; kwargs...) - else + if length(args) > 1 solve_call(_prob, _alg, Base.tail(args)...; kwargs...) - end - elseif haskey(prob.kwargs, :alg) && (isempty(args) || args[1] === nothing) - alg = prob.kwargs[:alg] - _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) - _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) - check_prob_alg_pairing(_prob, alg) # use alg for improved inference - if length(args) <= 1 - solve_call(_prob, _alg; kwargs...) else - solve_call(_prob, _alg, Base.tail(args)...; kwargs...) + solve_call(_prob, _alg; kwargs...) end - elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm - alg = args[1] - _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) - _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) - check_prob_alg_pairing(_prob, alg) # use alg for improved inference - solve_call(_prob, _alg, Base.tail(args)...; kwargs...) - else # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, - p = p, kwargs...) - solve_call(_prob, args...; kwargs...) end end @@ -1424,16 +1391,12 @@ end function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, kwargs...) - alg, _prob = if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing) - alg = kwargs[:alg] - alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) - elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm - alg = args[1] - alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) - else # Default algorithm handling - alg = isempty(args) ? nothing : args[1] - alg, get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p, - kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) # Default algorithm handling + _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) end if has_kwargs(_prob) @@ -1458,16 +1421,12 @@ end function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, kwargs...) - alg, _prob = if haskey(kwargs, :alg) && (isempty(args) || args[1] === nothing) - alg = kwargs[:alg] - alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) - elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm - alg = args[1] - alg, get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) - else # Default algorithm handling - alg = isempty(args) ? nothing : args[1] - alg, get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p, - kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) # Default algorithm handling + _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) end if has_kwargs(_prob) @@ -1490,6 +1449,22 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end end +@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) + if isempty(solve_args) || isnothing(solve_args[1]) + if haskey(solve_kwargs, :alg) + solve_kwargs[:alg] + elseif haskey(prob_kwargs, :alg) + prob_kwargs[:alg] + else + nothing + end + elseif solve_args[1] isa DEAlgorithm + solve_args[1] + else + nothing + end +end + #### # Catch undefined AD overload cases From e2183d9afbeeb302154b77bf949bdad8a345697f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 7 Oct 2023 03:52:01 +0330 Subject: [PATCH 02/12] pass `u0` & `p` --- src/solve.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index f78793c83..76454839c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -503,10 +503,11 @@ end function init_up(prob::DEProblem, sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); kwargs...) + _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + p = p, kwargs...) init_call(_prob, args...; kwargs...) else - _prob = get_concrete_problem(prob, isadaptive(alg); kwargs...) + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) _alg = prepare_alg(alg, _prob.u0, _prob.p, _prob) check_prob_alg_pairing(_prob, alg) # alg for improved inference if length(args) > 1 From 70f5d22f0d9b7d84daacbf920b313c39957054c4 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 7 Oct 2023 17:52:28 +0330 Subject: [PATCH 03/12] replace deprecated types --- ext/DiffEqBaseReverseDiffExt.jl | 12 ++++---- ext/DiffEqBaseTrackerExt.jl | 6 ++-- src/DiffEqBase.jl | 6 ++-- src/solve.jl | 50 ++++++++++++++++----------------- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 96e2d4301..cc0371c5e 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -76,7 +76,7 @@ end end # `ReverseDiff.TrackedArray` -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::ReverseDiff.TrackedArray, @@ -84,7 +84,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, p::ReverseDiff.TrackedArray, @@ -92,7 +92,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::ReverseDiff.TrackedArray, p, @@ -101,7 +101,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, end # `AbstractArray{<:ReverseDiff.TrackedReal}` -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, @@ -112,7 +112,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, @@ -121,7 +121,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, diff --git a/ext/DiffEqBaseTrackerExt.jl b/ext/DiffEqBaseTrackerExt.jl index 362537f60..618d2d017 100644 --- a/ext/DiffEqBaseTrackerExt.jl +++ b/ext/DiffEqBaseTrackerExt.jl @@ -67,7 +67,7 @@ end end @inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u) -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::Tracker.TrackedArray, @@ -75,7 +75,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::Tracker.TrackedArray, p, args...; @@ -83,7 +83,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, +function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, p::Tracker.TrackedArray, args...; diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 0053aae55..0ece946cb 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -48,19 +48,19 @@ PrecompileTools.@recompile_invalidations begin import PreallocationTools import FunctionWrappersWrappers - + using SciMLBase using SciMLOperators: AbstractSciMLOperator, AbstractSciMLScalarOperator - using SciMLBase: @def, DEIntegrator, DEProblem, + using SciMLBase: @def, DEIntegrator, AbstractDEProblem, AbstractDiffEqInterpolation, DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback, AbstractDiscreteCallback, AbstractLinearProblem, AbstractNonlinearProblem, AbstractOptimizationProblem, AbstractSteadyStateProblem, AbstractJumpProblem, AbstractNoiseProblem, AbstractEnsembleProblem, AbstractDynamicalODEProblem, - DEAlgorithm, StandardODEProblem, AbstractIntegralProblem, + AbstractDEAlgorithm, StandardODEProblem, AbstractIntegralProblem, AbstractSensitivityAlgorithm, AbstractODEAlgorithm, AbstractSDEAlgorithm, AbstractDDEAlgorithm, AbstractDAEAlgorithm, AbstractSDDEAlgorithm, AbstractRODEAlgorithm, DAEInitializationAlgorithm, diff --git a/src/solve.jl b/src/solve.jl index 76454839c..9db213086 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -7,7 +7,7 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem, AbstractIntegralProblem, AbstractSteadyStateProblem, AbstractJumpProblem} -has_kwargs(_prob::DEProblem) = has_kwargs(typeof(_prob)) +has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob)) Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) has_kwargs(::Type{T}) where {T} = __has_kwargs(T) @@ -196,7 +196,7 @@ end const NON_SOLVER_MESSAGE = """ The arguments to solve are incorrect. The second argument must be a solver choice, `solve(prob,alg)` - where `alg` is a `<: DEAlgorithm`, e.g. `Tsit5()`. + where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. Please double check the arguments being sent to the solver. @@ -484,7 +484,7 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin end end -function init(prob::Union{DEProblem, NonlinearProblem}, args...; sensealg = nothing, +function init(prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] @@ -500,7 +500,7 @@ function init(prob::AbstractJumpProblem, args...; kwargs...) init_call(prob, args...; kwargs...) end -function init_up(prob::DEProblem, sensealg, u0, p, args...; kwargs...) +function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) # Default algorithm handling _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, @@ -566,7 +566,7 @@ mutable struct NullODEIntegrator{IIP, ProbType, T, SolType, F, P} <: f::F p::P end -function build_null_integrator(prob::DEProblem, args...; +function build_null_integrator(prob::AbstractDEProblem, args...; kwargs...) sol = solve(prob, args...; kwargs...) return NullODEIntegrator{isinplace(prob), typeof(prob), eltype(prob.tspan), typeof(sol), @@ -592,7 +592,7 @@ function step!(integ::NullODEIntegrator, dt = nothing, stop_at_tdt = false) return nothing end -function build_null_solution(prob::DEProblem, args...; +function build_null_solution(prob::AbstractDEProblem, args...; saveat = (), save_everystep = true, save_on = true, @@ -635,7 +635,7 @@ end """ ```julia -solve(prob::DEProblem, alg::Union{DEAlgorithm,Nothing}; kwargs...) +solve(prob::AbstractDEProblem, alg::Union{AbstractDEAlgorithm,Nothing}; kwargs...) ``` ## Arguments @@ -914,7 +914,7 @@ the extension to other types is straightforward. to save size or because the user does not care about the others. Finally, with `progress = true` you are enabling the progress bar. """ -function solve(prob::DEProblem, args...; sensealg = nothing, +function solve(prob::AbstractDEProblem, args...; sensealg = nothing, u0 = nothing, p = nothing, wrap = Val(true), kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] @@ -987,8 +987,8 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing, end end -function solve_up(prob::Union{DEProblem, NonlinearProblem}, sensealg, u0, p, args...; - kwargs...) +function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, + args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) # Default algorithm handling _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, @@ -1070,12 +1070,12 @@ function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...) prob end -function solve(prob::PDEProblem, alg::DiffEqBase.DEAlgorithm, args...; +function solve(prob::PDEProblem, alg::AbstractDEAlgorithm, args...; kwargs...) solve(prob.prob, alg, args...; kwargs...) end -function init(prob::PDEProblem, alg::DiffEqBase.DEAlgorithm, args...; +function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...; kwargs...) init(prob.prob, alg, args...; kwargs...) end @@ -1266,27 +1266,27 @@ handle_distribution_u0(_u0) = _u0 eval_u0(u0::Function) = true eval_u0(u0) = false -function __solve(prob::DEProblem, args...; default_set = false, second_time = false, +function __solve(prob::AbstractDEProblem, args...; default_set = false, second_time = false, kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, DEAlgorithm}) + elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else - __solve(prob::DEProblem, nothing, args...; default_set = false, second_time = true, - kwargs...) + __solve(prob::AbstractDEProblem, nothing, args...; default_set = false, + second_time = true, kwargs...) end end -function __init(prob::DEProblem, args...; default_set = false, second_time = false, +function __init(prob::AbstractDEProblem, args...; default_set = false, second_time = false, kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, DEAlgorithm}) + elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else - __init(prob::DEProblem, nothing, args...; default_set = false, second_time = true, - kwargs...) + __init(prob::AbstractDEProblem, nothing, args...; default_set = false, + second_time = true, kwargs...) end end @@ -1360,7 +1360,7 @@ Ignores all adjoint definitions (i.e. `sensealg`) and proceeds to do standard AD through the `solve` functions. Generally only used internally for implementing discrete sensitivity algorithms. """ -struct SensitivityADPassThrough <: SciMLBase.DEAlgorithm end +struct SensitivityADPassThrough <: AbstractDEAlgorithm end function ChainRulesCore.frule(::typeof(solve_up), prob, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, @@ -1370,7 +1370,7 @@ function ChainRulesCore.frule(::typeof(solve_up), prob, kwargs...) end -function ChainRulesCore.rrule(::typeof(solve_up), prob::SciMLBase.DEProblem, +function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, u0, p, args...; kwargs...) @@ -1382,8 +1382,8 @@ end ### Legacy Dispatches to be Non-Breaking ### -@deprecate concrete_solve(prob::SciMLBase.DEProblem, - alg::Union{SciMLBase.DEAlgorithm, Nothing}, +@deprecate concrete_solve(prob::AbstractDEProblem, + alg::Union{AbstractDEAlgorithm, Nothing}, u0 = prob.u0, p = prob.p, args...; kwargs...) solve(prob, alg, args...; u0 = u0, @@ -1459,7 +1459,7 @@ end else nothing end - elseif solve_args[1] isa DEAlgorithm + elseif solve_args[1] isa AbstractDEAlgorithm solve_args[1] else nothing From dc0e5c1a05ea63e4627e424fa492c8f391fcded4 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 7 Oct 2023 17:56:39 +0330 Subject: [PATCH 04/12] fix unnecessary type check --- src/solve.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 9db213086..76201a01c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1273,8 +1273,7 @@ function __solve(prob::AbstractDEProblem, args...; default_set = false, second_t elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else - __solve(prob::AbstractDEProblem, nothing, args...; default_set = false, - second_time = true, kwargs...) + __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) end end @@ -1285,8 +1284,7 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else - __init(prob::AbstractDEProblem, nothing, args...; default_set = false, - second_time = true, kwargs...) + __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) end end From bbe7e8581d94bb235133c3d63a5e89be6bb8b7d4 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 01:09:09 +0330 Subject: [PATCH 05/12] use `AbstractSciMLAlgorithm` --- src/solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index 76201a01c..98ce33924 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1457,7 +1457,7 @@ end else nothing end - elseif solve_args[1] isa AbstractDEAlgorithm + elseif solve_args[1] isa SciMLBase.AbstractSciMLAlgorithm solve_args[1] else nothing From a390882248d3dfec8f06da19cab9bbb6941f5810 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 01:09:37 +0330 Subject: [PATCH 06/12] use `first` instead of `[1]` --- src/solve.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 98ce33924..046613aaa 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -573,7 +573,7 @@ function build_null_integrator(prob::AbstractDEProblem, args...; typeof(prob.f), typeof(prob.p), }(Float64[], Float64[], - first(prob.tspan), + prob.tspan[1], prob, sol, prob.f, @@ -1015,7 +1015,7 @@ function solve_call(prob::SteadyStateProblem, end function solve(prob::EnsembleProblem, args...; kwargs...) - if isempty(args) || length(args) == 1 && typeof(args[1]) <: EnsembleAlgorithm + if isempty(args) || length(args) == 1 && typeof(first(args)) <: EnsembleAlgorithm __solve(prob, nothing, args...; kwargs...) else __solve(prob, args...; kwargs...) @@ -1270,7 +1270,7 @@ function __solve(prob::AbstractDEProblem, args...; default_set = false, second_t kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm}) + elseif length(args) > 0 && !(typeof(first(args)) <: Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) @@ -1281,7 +1281,7 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm}) + elseif length(args) > 0 && !(typeof(first(args)) <: Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) @@ -1449,7 +1449,7 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) - if isempty(solve_args) || isnothing(solve_args[1]) + if isempty(solve_args) || isnothing(first(solve_args)) if haskey(solve_kwargs, :alg) solve_kwargs[:alg] elseif haskey(prob_kwargs, :alg) @@ -1457,8 +1457,8 @@ end else nothing end - elseif solve_args[1] isa SciMLBase.AbstractSciMLAlgorithm - solve_args[1] + elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm + first(solve_args) else nothing end From bf1eeb78b35e13a0a39afdfd31b4441f75b88e61 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 01:20:17 +0330 Subject: [PATCH 07/12] update `solve(prob::EnsembleProblem` --- src/solve.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 046613aaa..7fb9feb54 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1015,10 +1015,11 @@ function solve_call(prob::SteadyStateProblem, end function solve(prob::EnsembleProblem, args...; kwargs...) - if isempty(args) || length(args) == 1 && typeof(first(args)) <: EnsembleAlgorithm - __solve(prob, nothing, args...; kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if length(args) > 1 + __solve(prob, alg, Base.tail(args)...; kwargs...) else - __solve(prob, args...; kwargs...) + __solve(prob, alg; kwargs...) end end function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) From 93045ea62cee4dc8d21a8a5f907d5e6f7748b7c6 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 01:21:30 +0330 Subject: [PATCH 08/12] move `extract_alg` to a better place --- src/solve.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 7fb9feb54..4c71fc554 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1352,6 +1352,22 @@ function check_prob_alg_pairing(prob, alg) end end +@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) + if isempty(solve_args) || isnothing(first(solve_args)) + if haskey(solve_kwargs, :alg) + solve_kwargs[:alg] + elseif haskey(prob_kwargs, :alg) + prob_kwargs[:alg] + else + nothing + end + elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm + first(solve_args) + else + nothing + end +end + ################### Differentiation """ @@ -1449,22 +1465,6 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end end -@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) - if isempty(solve_args) || isnothing(first(solve_args)) - if haskey(solve_kwargs, :alg) - solve_kwargs[:alg] - elseif haskey(prob_kwargs, :alg) - prob_kwargs[:alg] - else - nothing - end - elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm - first(solve_args) - else - nothing - end -end - #### # Catch undefined AD overload cases From 820758cd0cb53154d40155299e45792366758aa6 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 01:32:07 +0330 Subject: [PATCH 09/12] fix: `alg` shouldn't be repeated --- src/solve.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 4c71fc554..525fa4e3b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1427,11 +1427,11 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end - if isempty(args) - _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...) - else + if length(args) > 1 _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator, Base.tail(args)...; kwargs...) + else + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...) end end @@ -1457,11 +1457,11 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end - if isempty(args) - _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) - else + if length(args) > 1 _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator, Base.tail(args)...; kwargs...) + else + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) end end From 1574722d7cf36b533fc70d99e37007a6548a83f4 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 03:19:06 +0330 Subject: [PATCH 10/12] fix: `isadaptive` error --- src/solve.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 525fa4e3b..8ea6fa391 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -502,7 +502,7 @@ end function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) - if isnothing(alg) # Default algorithm handling + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p, kwargs...) init_call(_prob, args...; kwargs...) @@ -990,7 +990,7 @@ end function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) - if isnothing(alg) # Default algorithm handling + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p, kwargs...) solve_call(_prob, args...; kwargs...) @@ -1408,7 +1408,7 @@ end function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) - if isnothing(alg) # Default algorithm handling + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p, kwargs...) else @@ -1438,7 +1438,7 @@ end function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) - if isnothing(alg) # Default algorithm handling + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, p = p, kwargs...) else From c1b95fce8e1bb3535054950a13ffe106b4dc7a7b Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 03:28:43 +0330 Subject: [PATCH 11/12] replace `typeof(...) <:` by `isa` --- src/solve.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 8ea6fa391..6ec3c8d66 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -477,7 +477,7 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin if _prob isa Union{ODEProblem, DAEProblem} && isnothing(_prob.u0) build_null_integrator(_prob, args...; kwargs...) elseif hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && - typeof(_prob.f.f) <: EvalFunc + _prob.f.f isa EvalFunc Base.invokelatest(__init, _prob, args...; kwargs...)#::T else __init(_prob, args...; kwargs...)#::T @@ -503,7 +503,7 @@ end function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, p = p, kwargs...) init_call(_prob, args...; kwargs...) else @@ -549,7 +549,7 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi end if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && - typeof(_prob.f.f) <: EvalFunc + _prob.f.f isa EvalFunc Base.invokelatest(__solve, _prob, args...; kwargs...)#::T else __solve(_prob, args...; kwargs...)#::T @@ -991,7 +991,7 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0 args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, p = p, kwargs...) solve_call(_prob, args...; kwargs...) else @@ -1271,7 +1271,7 @@ function __solve(prob::AbstractDEProblem, args...; default_set = false, second_t kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(typeof(first(args)) <: Union{Nothing, AbstractDEAlgorithm}) + elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) @@ -1282,7 +1282,7 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(typeof(first(args)) <: Union{Nothing, AbstractDEAlgorithm}) + elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm}) throw(NonSolverError()) else __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) @@ -1409,7 +1409,7 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, p = p, kwargs...) else _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) @@ -1439,7 +1439,7 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, p = p, kwargs...) else _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) From 319fd6250f0d1acc0256912ae4982898e577b14b Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 19 Oct 2023 11:58:56 +0330 Subject: [PATCH 12/12] fix : "type EnsembleProblem has no field kwargs" --- src/solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index 6ec3c8d66..2c96f72a6 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1015,7 +1015,7 @@ function solve_call(prob::SteadyStateProblem, end function solve(prob::EnsembleProblem, args...; kwargs...) - alg = extract_alg(args, kwargs, prob.kwargs) + alg = extract_alg(args, kwargs, kwargs) if length(args) > 1 __solve(prob, alg, Base.tail(args)...; kwargs...) else