From 7f6658d39b277a33b4fa1077b121c21e1783e79e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Mar 2024 13:32:23 +0530 Subject: [PATCH] feat: implement interpolation for parameter timeseries --- src/solutions/ode_solutions.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index c9c023863..589331a53 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -189,7 +189,9 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") if is_parameter(sol, idxs) - return getp(sol, idxs)(sol) + unknown_tidx = searchsortedfirst(sol.t, t; lt = <=) - 1 + ps = parameter_values_at_state_time(sol, unknown_tidx) + return getp(sol, idxs)(ps) else return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1] end @@ -200,14 +202,19 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol) - [is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs] + unknown_tidx = searchsortedfirst(sol.t, t; lt = <=) - 1 + ps = parameter_values_at_state_time(sol, unknown_tidx) + [is_parameter(sol, idx) ? getp(sol, idx)(ps) : first(interp_sol[idx]) for idx in idxs] end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") if is_parameter(sol, idxs) - return getp(sol, idxs)(sol) + unknown_tidxs = searchsortedfirst((sol.t,), t; lt = <=) - 1 + pss = parameter_values_at_state_time.((sol,), unknown_tidxs) + getter = getp(sol, idxs) + return getter.(pss) else interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing @@ -222,7 +229,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing return DiffEqArray( - [[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol) + [[is_parameter(sol, idx) ? getp(sol, idx)(interp_sol, i) : interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol) end function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},