diff --git a/Project.toml b/Project.toml index 645bbfa88..38151cef3 100644 --- a/Project.toml +++ b/Project.toml @@ -76,7 +76,7 @@ PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" RecipesBase = "1.3.4" -RecursiveArrayTools = "3.8.0" +RecursiveArrayTools = "3.14.0" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" @@ -84,7 +84,7 @@ SciMLStructures = "1.1" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.15" +SymbolicIndexingInterface = "0.3.20" Tables = "1.11" Zygote = "0.6.67" julia = "1.10" diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 6ad75a3be..46bc70f92 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -172,6 +172,9 @@ end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} + if eltype(sol.u) <: Number + idxs = only(idxs) + end sol.interp(t, idxs, deriv, sol.prob.p, continuity) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, @@ -183,6 +186,9 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} + if eltype(sol.u) <: Number + idxs = only(idxs) + end A = sol.interp(t, idxs, deriv, sol.prob.p, continuity) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing return DiffEqArray(A.u, A.t, p, sol) @@ -203,7 +209,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol) - [is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs] + first(interp_sol[idxs]) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, @@ -224,8 +230,9 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + indexed_sol = interp_sol[idxs] return DiffEqArray( - [[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol) + [indexed_sol[i] for i in 1:length(t)], t, p, sol) end function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f72284ee9..609132a79 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -436,6 +436,18 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, plot_vecs = [] labels = String[] varsyms = variable_symbols(sol) + batch_symbolic_vars = [] + for x in vars + for j in 2:length(x) + if (x[j] isa Integer && x[j] == 0) || isequal(x[j], getindepsym_defaultt(sol)) + else + push!(batch_symbolic_vars, x[j]) + end + end + end + batch_symbolic_vars = identity.(batch_symbolic_vars) + indexed_solution = sol(plott; idxs = batch_symbolic_vars) + idxx = 0 for x in vars tmp = [] strs = String[] @@ -444,7 +456,8 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, push!(tmp, plott) push!(strs, "t") else - push!(tmp, sol(plott; idxs = x[j])) + idxx += 1 + push!(tmp, indexed_solution[idxx, :]) if !isempty(varsyms) && x[j] isa Integer push!(strs, String(getname(varsyms[x[j]]))) elseif hasname(x[j]) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 78d47822c..61a6e3b17 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -137,7 +137,8 @@ eqs = [D(x) ~ Hold(ud) xd ~ Sample(t, dt)(x)] @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [p3 => 2p1]) prob = ODEProblem(sys, [x => 1.0], (0.0, 5.0), - [p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, xd(k - 1) => 4.0, xd(k - 2) => 5.0]) + [p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, + xd(k - 1) => 4.0, xd(k - 2) => 5.0, yd(k - 1) => 0.0]) # parameter dependencies prob2 = @inferred ODEProblem remake(prob; p = [p1 => 2.0]) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 893fdf437..98d07c5ba 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -93,9 +93,9 @@ end @test length(sol[(lorenz1.x, lorenz2.x)]) == length(sol) @test all(length.(sol[(lorenz1.x, lorenz2.x)]) .== 2) -@test sol[[lorenz1.x, lorenz2.x], :] isa Matrix{Float64} -@test size(sol[[lorenz1.x, lorenz2.x], :]) == (2, length(sol)) -@test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :]) +@test sol[[lorenz1.x, lorenz2.x], :] isa Vector{Vector{Float64}} +@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol) +@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2 @variables q(t)[1:2] = [1.0, 2.0] eqs = [D(q[1]) ~ 2q[1]