Skip to content

Commit

Permalink
Merge pull request #750 from AayushSabharwal/as/plot-idxs-func
Browse files Browse the repository at this point in the history
fix: handle more plotting cases, add more comprehensive tests
  • Loading branch information
ChrisRackauckas authored Jul 26, 2024
2 parents 38c8850 + eccead4 commit 8c4a285
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 49 deletions.
2 changes: 0 additions & 2 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector, continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getu(sol, idxs)
Expand Down
43 changes: 26 additions & 17 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,28 +179,29 @@ end
end

idxs = idxs === nothing ? (1:length(sol.u[1])) : idxs
disc_idxs = []
cont_idxs = []
for idx in (idxs isa Union{Tuple, AbstractArray} ? idxs : [idxs])
tsidxs = get_all_timeseries_indexes(sol, idx)
if ContinuousTimeseries() in tsidxs
push!(cont_idxs, idx)
else
push!(disc_idxs, (idx, only(tsidxs)))
end
end
idxs = identity.(cont_idxs)
if !(idxs isa Union{Tuple, AbstractArray})
vars = interpret_vars([idxs], sol)
else
vars = interpret_vars(idxs, sol)
end
disc_vars = Tuple[]
cont_vars = Tuple[]
for var in vars
tsidxs = union(get_all_timeseries_indexes(sol, var[2]), get_all_timeseries_indexes(sol, var[3]))
if ContinuousTimeseries() in tsidxs
push!(cont_vars, var)
else
push!(disc_vars, (var..., only(tsidxs)))
end
end
idxs = identity.(cont_vars)
vars = identity.(cont_vars)
tdir = sign(sol.t[end] - sol.t[1])
xflip --> tdir < 0
seriestype --> :path

@series begin
if isempty(idxs)
if idxs isa Union{AbstractArray, Tuple} && isempty(idxs)
label --> nothing
([], [])
else
Expand Down Expand Up @@ -281,7 +282,7 @@ end
(plot_vecs...,)
end
end
for (idx, tsidx) in disc_idxs
for (func, xvar, yvar, tsidx) in disc_vars
partition = sol.discretes[tsidx]
ts = current_time(partition)
if tspan !== nothing
Expand All @@ -296,18 +297,26 @@ end
end
ts = ts[tstart:tend]

vals = getp(sol, idx)(sol, tstart:tend)
if symbolic_type(xvar) == NotSymbolic() && xvar == 0
xvar = only(independent_variable_symbols(sol))
end
xvals = sol(ts; idxs = xvar).u
# xvals = getu(sol, xvar)(sol, tstart:tend)
yvals = getp(sol, yvar)(sol, tstart:tend)
tmpvals = map(func, xvals, yvals)
xvals = getindex.(tmpvals, 1)
yvals = getindex.(tmpvals, 2)
# Scatterplot of points
@series begin
seriestype := :line
linestyle --> :dash
markershape --> :o
markersize --> repeat([2, 0], length(ts)-1)
markeralpha --> repeat([1, 0], length(ts)-1)
label --> string(hasname(idx) ? getname(idx) : idx)
label --> string(hasname(yvar) ? getname(yvar) : yvar)

x = vec([ts[1:end-1]'; ts[2:end]'])
y = repeat(vals, inner=2)[1:end-1]
x = vec([xvals[1:end-1]'; xvals[2:end]'])
y = repeat(yvals, inner=2)[1:end-1]
x, y
end
end
Expand Down
49 changes: 48 additions & 1 deletion test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization,
OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase,
SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface,
DiffEqCallbacks, Test
DiffEqCallbacks, Test, Plots
using ModelingToolkit: t_nounits as t, D_nounits as D

# Sets rnd number.
Expand Down Expand Up @@ -548,6 +548,21 @@ end
end
end
function SymbolicIndexingInterface.parameter_observed(s::NumSymbolCache, x)
if x isa Symbol
allsyms = all_symbols(s)
x = allsyms[findfirst(y -> hasname(y) && x == getname(y), allsyms)]
elseif x isa AbstractArray
allsyms = all_symbols(s)
newx = []
for i in eachindex(x)
if x[i] isa Symbol
push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
else
push!(newx, x[i])
end
end
x = newx
end
res = ModelingToolkit.build_function(x,
sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)),
independent_variable_symbols(s)[]; expression = Val(false))
Expand All @@ -567,6 +582,21 @@ end
else
x = ModelingToolkit.unwrap(x)
end
if x isa Symbol
allsyms = all_symbols(s)
x = allsyms[findfirst(y -> hasname(y) && x == getname(y), allsyms)]
elseif x isa AbstractArray
allsyms = all_symbols(s)
newx = []
for i in eachindex(x)
if x[i] isa Symbol
push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
else
push!(newx, x[i])
end
end
x = newx
end
vars = ModelingToolkit.vars(x)
return mapreduce(union, vars; init = Set()) do sym
if is_variable(s, sym)
Expand Down Expand Up @@ -856,4 +886,21 @@ end
end
@test res == val
end

@testset "Plotting" begin
plotfn(t, u) = (t, 2u)
all_idxs = [ud1, 2ud1, ud2, (plotfn, 0, ud1), (plotfn, t, ud1)]
sym_idxs = [:ud1, :ud2, (plotfn, 0, :ud1), (plotfn, 0, :ud1)]

for idx in Iterators.flatten((all_idxs, sym_idxs))
@test_nowarn plot(sol; idxs = idx)
@test_nowarn plot(sol; idxs = [idx])
end
for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
@test_nowarn plot(sol; idxs = collect(idx))
if !(idx[1] isa Tuple || idx[2] isa Tuple || length(get_all_timeseries_indexes(sol, collect(idx))) > 1)
@test_nowarn plot(sol; idxs = idx)
end
end
end
end
29 changes: 0 additions & 29 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,35 +319,6 @@ integrator[lorenz2.x] = 2.0
# integrator10 = integrator(0.1, idxs = 2)
# @test integrator10 isa Real

using Plots
for idxs in [
(lorenz2.x,lorenz2.z),
(α,lorenz2.z),
(lorenz2.x,α),
α,
(α,),
(t,α),
[lorenz2.x,lorenz2.z],
[α,lorenz2.z],
[lorenz2.x,α],
[α],
[t,α],
]
plot(sol; idxs)
if idxs isa Union{Tuple, AbstractArray}
idxs = map(idxs) do i
hasname(i) ? getname(i) : i
end
if any(==(:t), idxs)
@test_broken plot(sol; idxs)
else
plot(sol; idxs)
end
elseif hasname(idxs)
plot(sol; idxs=getname(idxs))
end
end

using LinearAlgebra
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
ps = @parameters p[1:3] = [1, 2, 3]
Expand Down
24 changes: 24 additions & 0 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,27 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2)

sol10 = sol(0.1, idxs = 2)
@test sol10 isa Real

@testset "Plot idxs" begin
@variables x(t) y(t)
@parameters p
@mtkbuild sys = ODESystem([D(x) ~ x * t, D(y) ~ y - p * x], t)
prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0), [p => 1.0])
sol = solve(prob, Tsit5())

plotfn(t, u) = (t, 2u)
all_idxs = [x, x + p * y, t, (plotfn, 0, 1), (plotfn, t, 1), (plotfn, 0, x),
(plotfn, t, x), (plotfn, t, p * y)]
sym_idxs = [:x, :t, (plotfn, :t, 1), (plotfn, 0, :x),
(plotfn, :t, :x)]
for idx in Iterators.flatten((all_idxs, sym_idxs))
@test_nowarn plot(sol; idxs = idx)
@test_nowarn plot(sol; idxs = [idx])
end
for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
@test_nowarn plot(sol; idxs = collect(idx))
if !(idx[1] isa Tuple || idx[2] isa Tuple)
@test_nowarn plot(sol; idxs = idx)
end
end
end

0 comments on commit 8c4a285

Please sign in to comment.