Skip to content

Commit

Permalink
fixup! test: rewrite hybrid system tests to not use MTK
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 27, 2024
1 parent af0d5ed commit 3e6d723
Showing 1 changed file with 87 additions and 29 deletions.
116 changes: 87 additions & 29 deletions test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization,
OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase,
SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface,
DiffeqCallbacks, Test
DiffEqCallbacks, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

# Sets rnd number.
Expand Down Expand Up @@ -536,7 +536,8 @@ end
end
SymbolicIndexingInterface.symbolic_container(s::NumSymbolCache) = s.sc
function SymbolicIndexingInterface.is_observed(s::NumSymbolCache, x)
return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) && !is_parameter(s, x) && !is_independent_variable(s, x)
return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) &&
!is_parameter(s, x) && !is_independent_variable(s, x)
end
function SymbolicIndexingInterface.observed(s::NumSymbolCache, x)
res = ModelingToolkit.build_function(x,
Expand Down Expand Up @@ -584,7 +585,8 @@ end
end
end
end
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(p::Vector{Float64}, args...)
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
::NumSymbolCache, p::Vector{Float64}, args...)
for (idx, buf) in args
if idx == 1
p[1:2] .= buf
Expand All @@ -604,7 +606,7 @@ end
dea2 = DiffEqArray(Vector{Float64}[], Float64[])
return ParameterTimeseriesCollection((dea1, dea2), deepcopy(ps))
end
function SciMLBase.get_saveable_values(p::Vector{Float64}, tsidx)
function SciMLBase.get_saveable_values(::NumSymbolCache, p::Vector{Float64}, tsidx)
if tsidx == 1
return p[1:2]
else
Expand All @@ -614,38 +616,96 @@ end

@variables x(t) ud1(t) ud2(t) xd1(t) xd2(t)
@parameters kp
sc = SymbolCache([x], Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5), t; timeseries_parameters = Dict(ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2), ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2)))
sc = SymbolCache([x],
Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5),
t;
timeseries_parameters = Dict(
ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2),
ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2)))
sys = NumSymbolCache(sc)

function f!(du, u, p, t)
du .= u .* t .+ p[5] * sum(u)
end
fn = ODEFunction(f!; sys = sys)
prob = ODEProblem(fn, [1.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0, 5.0])
cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true, save_positions = (false, false)) do integ
cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true,
save_positions = (false, false)) do integ
integ.p[1:2] .+= exp(-integ.t)
SciMLBase.save_discretes!(integ, 1)
end
function affect2!(integ)
integ.p[3:4] .+= only(integ.u)
SciMLBase.save_discretes!(integ, 2)
end
cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false), initialize = (c, u, t, integ) -> affect2!(integ))
cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false),
initialize = (c, u, t, integ) -> affect2!(integ))
sol = solve(deepcopy(prob), Tsit5(); callback = CallbackSet(cb1, cb2))

ud1val = getindex.(sol.discretes.collection[1].u, 1)
xd1val = getindex.(sol.discretes.collection[1].u, 2)
ud2val = getindex.(sol.discretes.collection[2].u, 1)
xd2val = getindex.(sol.discretes.collection[2].u, 2)

for (sym, timeseries_index, val, buffer, isobs, check_inference) in [
(ud1, 1, ud1val, zeros(length(ud1val)), false, true)
([ud1, xd1], 1, vcat.(ud1val, xd1val), map(_ -> zeros(2), ud1val), false, true)
((ud2, xd2), 2, tuple.(ud2val, xd2val), map(_ -> zeros(2), ud2val), false, true)
(ud2 + xd2, 2, ud2val .+ xd2val, zeros(length(ud2val)), true, true)
([ud2 + xd2, ud2 * xd2], 2, vcat.(ud2val .+ xd2val, ud2val .* xd2val), map(_ -> zeros(2), ud2val), true, true)
((ud1 + xd1, ud1 * xd1), 1, tuple.(ud1val .+ xd1val, ud1val .* xd1val), map(_ -> zeros(2), ud1val), true, true)
]
for (sym, timeseries_index, val, buffer, isobs, check_inference) in [(ud1,
1,
ud1val,
zeros(length(ud1val)),
false,
true)
([ud1, xd1],
1,
vcat.(ud1val,
xd1val),
map(
_ -> zeros(2),
ud1val),
false,
true)
((ud2, xd2),
2,
tuple.(ud2val,
xd2val),
map(
_ -> zeros(2),
ud2val),
false,
true)
(ud2 + xd2,
2,
ud2val .+
xd2val,
zeros(length(ud2val)),
true,
true)
(
[ud2 + xd2,
ud2 * xd2],
2,
vcat.(
ud2val .+
xd2val,
ud2val .*
xd2val),
map(
_ -> zeros(2),
ud2val),
true,
true)
(
(ud1 + xd1,
ud1 * xd1),
1,
tuple.(
ud1val .+
xd1val,
ud1val .*
xd1val),
map(
_ -> zeros(2),
ud1val),
true,
true)]
getter = getp(sys, sym)
if check_inference
@inferred getter(sol)
Expand Down Expand Up @@ -678,7 +738,8 @@ end
end
end

for subidx in [1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5]
for subidx in [
1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5]
if check_inference
@inferred getter(sol, subidx)
if !isa(val[subidx], Number)
Expand Down Expand Up @@ -706,9 +767,7 @@ end
(ud2, xd1, xd2),
ud1 + ud2,
[ud1 + ud2, ud1 * xd1],
(ud1 + ud2, ud1 * xd1),

]
(ud1 + ud2, ud1 * xd1)]
getter = getp(sys, sym)
@test_throws Exception getter(sol)
@test_throws Exception getter([], sol)
Expand All @@ -732,7 +791,7 @@ end
((kp, x), true, tuple.(kpval, xval), false),
(2ud2, true, 2 .* ud2val, true),
([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false),
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false),
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false)
]
getter = getu(sys, sym)
if check_inference
Expand All @@ -741,8 +800,8 @@ end
@test getter(sol) == val
reference = val_is_timeseries ? val : xval
for subidx in [
1, CartesianIndex(2), : ,rand(Bool, length(reference)),
rand(eachindex(reference), 4), 2:6
1, CartesianIndex(2), :, rand(Bool, length(reference)),
rand(eachindex(reference), 4), 2:6
]
if check_inference
@inferred getter(sol, subidx)
Expand All @@ -767,7 +826,7 @@ end
((x, ud1), (_xval, _ud1val), true),
(x + ud2, _xval + _ud2val, true),
([2x, 3xd1], [2_xval, 3_xd1val], true),
((2x, 3xd2), (2_xval, 3_xd2val), true),
((2x, 3xd2), (2_xval, 3_xd2val), true)
]
getter = getu(sys, sym)
@test_throws Exception getter(sol)
Expand All @@ -781,8 +840,8 @@ end
@test getter(integ) == val
end

xinterp = sol(0.1:0.1:0.3, idxs=x)
xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs=x)
xinterp = sol(0.1:0.1:0.3, idxs = x)
xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs = x)
ud1interp = ud1val[2:4]
ud2interp = ud2val[2:4]

Expand All @@ -796,13 +855,12 @@ end
(x, c2[2], xinterp2[1]),
(x, c2[2:4], xinterp2),
([x, ud2], c2[2], [xinterp2[1], ud2interp[1]]),
([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp)),
([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp))
]
res = sol(t, idxs=sym)
res = sol(t, idxs = sym)
if res isa DiffEqArray
res = res.u
end
@test res == val
end
end
end

0 comments on commit 3e6d723

Please sign in to comment.