From 318371e1d3b60e95c35e723ce01e7ec82a5c4496 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 17 Jun 2024 17:36:43 +0530 Subject: [PATCH 1/5] refactor: use `getu`/`setu` for indexing everything --- src/integrator_interface.jl | 75 ++++++++--------------------- src/problems/problem_interface.jl | 67 +++++++++++--------------- src/solutions/solution_interface.jl | 27 +++++------ 3 files changed, 60 insertions(+), 109 deletions(-) diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index eec89165b..a9576de91 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -465,46 +465,20 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol) end end -Base.@propagate_inbounds function _getindex(A::DEIntegrator, - ::NotSymbolic, - I::Union{Int, AbstractArray{Int}, - CartesianIndex, Colon, BitArray, - AbstractArray{Bool}}...) - A.u[I...] -end - -Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym) - if is_variable(A, sym) - return A[variable_index(A, sym)] - elseif is_parameter(A, sym) - error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing") - elseif is_independent_variable(A, sym) - return A.t - elseif is_observed(A, sym) - return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t) - else - error("Tried to index integrator with a Symbol that was not found in the system.") +Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym) + if is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.") end + return getu(A, sym)(A) end -Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym) - return A[collect(sym)] -end - -Base.@propagate_inbounds function _getindex( - A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}) - return getindex.((A,), sym) -end - -Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym) - symtype = symbolic_type(sym) - elsymtype = symbolic_type(eltype(sym)) - - if symtype != NotSymbolic() - return _getindex(A, symtype, sym) - else - return _getindex(A, elsymtype, sym) +Base.@propagate_inbounds function Base.getindex( + A::DEIntegrator, sym::Union{AbstractArray, Tuple}) + if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) || + is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.") end + return getu(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( @@ -522,25 +496,18 @@ function observed(A::DEIntegrator, sym) end function Base.setindex!(A::DEIntegrator, val, sym) - has_sys(A.f) || - error("Invalid indexing of integrator: Integrator does not support indexing without a system") - if symbolic_type(sym) == ScalarSymbolic() - if is_variable(A, sym) - set_state!(A, val, variable_index(A, sym)) - elseif is_parameter(A, sym) - error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.") - else - error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.") - end - return A - elseif symbolic_type(sym) == ArraySymbolic() - setindex!.((A,), val, collect(sym)) - return A - else - sym isa AbstractArray || error("Invalid indexing of integrator") - setindex!.((A,), val, sym) - return A + if is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.") + end + setu(A, sym)(A, val) +end + +function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple}) + if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) || + is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.") end + setu(A, sym)(A, val) end ### Integrator traits diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index 437d33fcf..ece143ab7 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -38,51 +38,38 @@ Base.@propagate_inbounds function Base.getindex( return getindex(prob, all_variable_symbols(prob)) end -Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym) - if symbolic_type(sym) == ScalarSymbolic() - if is_variable(prob, sym) - return state_values(prob, variable_index(prob, sym)) - elseif is_parameter(prob, sym) - error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.") - elseif is_independent_variable(prob, sym) - return current_time(prob) - elseif is_observed(prob, sym) - obs = SymbolicIndexingInterface.observed(prob, sym) - if is_time_dependent(prob) - return obs(state_values(prob), parameter_values(prob), current_time(prob)) - else - return obs(state_values(prob), parameter_values(prob)) - end - else - error("Invalid indexing of problem: $sym is not a state, parameter, or independent variable") - end - elseif symbolic_type(sym) == ArraySymbolic() - return map(s -> prob[s], collect(sym)) - else - sym isa AbstractArray || error("Invalid indexing of problem") - return map(s -> prob[s], sym) +Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym) + if is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.") end + return getu(A, sym)(A) +end + +Base.@propagate_inbounds function Base.getindex( + A::AbstractSciMLProblem, sym::Union{AbstractArray, Tuple}) + if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) || + is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.") + end + return getu(A, sym)(A) end function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...) ___internal_setindex!(prob::AbstractSciMLProblem, args...; kwargs...) end -function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym) - if symbolic_type(sym) == ScalarSymbolic() - if is_variable(prob, sym) - set_state!(prob, val, variable_index(prob, sym)) - elseif is_parameter(prob, sym) - error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.") - else - error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.") - end - return prob - elseif symbolic_type(sym) == ArraySymbolic() - setindex!.((prob,), val, collect(sym)) - return prob - else - sym isa AbstractArray || error("Invalid indexing of problem") - setindex!.((prob,), val, sym) - return prob + +function ___internal_setindex!(A::AbstractSciMLProblem, val, sym) + if is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.") + end + return setu(A, sym)(A, val) +end + +function ___internal_setindex!( + A::AbstractSciMLProblem, val, sym::Union{AbstractArray, Tuple}) + if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) || + is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.") end + return setu(A, sym)(A, val) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 22f1e0448..22fdb62a1 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -63,22 +63,19 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, : end Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) - if symbolic_type(sym) == ScalarSymbolic() - if is_variable(A, sym) - return A[variable_index(A, sym)] - elseif is_parameter(A, sym) - error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.") - elseif is_observed(A, sym) - return SymbolicIndexingInterface.observed(A, sym)(A.u, parameter_values(A)) - else - error("Tried to index solution with a Symbol that was not found in the system.") - end - elseif symbolic_type(sym) == ArraySymbolic() - return A[collect(sym)] - else - sym isa AbstractArray || error("Invalid indexing of solution") - return getindex.((A,), sym) + if is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.") + end + return getu(A, sym)(A) +end + +Base.@propagate_inbounds function Base.getindex( + A::AbstractNoTimeSolution, sym::Union{AbstractArray, Tuple}) + if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) || + is_parameter(A, sym) + error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.") end + return getu(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( From d61dd25a620c445ebb1636e7a421c56f2ae69eb8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 17 Jun 2024 17:37:04 +0530 Subject: [PATCH 2/5] fix: make `parameter_values` inferred for `AbstractOptimizationCache` --- src/solutions/optimization_solutions.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/solutions/optimization_solutions.jl b/src/solutions/optimization_solutions.jl index 6f7afe420..b7daeca99 100644 --- a/src/solutions/optimization_solutions.jl +++ b/src/solutions/optimization_solutions.jl @@ -174,7 +174,13 @@ function reinit!(cache::SciMLBase.AbstractOptimizationCache; p = missing, return cache end -SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache) = x.p +function SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache) + if has_reinit(x) + x.reinit_cache.p + else + x.p + end +end SymbolicIndexingInterface.symbolic_container(x::AbstractOptimizationCache) = x.f get_p(sol::OptimizationSolution) = sol.cache.p From 14c963904e76eb44383a6b262d73b3b86adebddb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 17 Jun 2024 17:38:11 +0530 Subject: [PATCH 3/5] test: comprehensively test symbolic indexing --- Project.toml | 4 +- test/downstream/Project.toml | 4 + test/downstream/comprehensive_indexing.jl | 530 ++++++++++++++++++++++ test/downstream/solution_interface.jl | 42 ++ test/downstream/symbol_indexing.jl | 389 ---------------- test/runtests.jl | 2 +- 6 files changed, 580 insertions(+), 391 deletions(-) create mode 100644 test/downstream/comprehensive_indexing.jl delete mode 100644 test/downstream/symbol_indexing.jl diff --git a/Project.toml b/Project.toml index 7e664674e..2d909506f 100644 --- a/Project.toml +++ b/Project.toml @@ -83,6 +83,7 @@ Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" SciMLStructures = "1.1" +StableRNGs = "1.0" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" @@ -106,6 +107,7 @@ PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -113,4 +115,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"] +test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"] diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 8f932fb22..97c10c226 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -13,6 +13,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" @@ -34,6 +37,7 @@ RecursiveArrayTools = "3" SciMLBase = "2" SciMLSensitivity = "7.11" SciMLStructures = "1.1" +SteadyStateDiffEq = "2.2" Sundials = "4.11" SymbolicIndexingInterface = "0.3" SymbolicUtils = "<1.6, 2" diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl new file mode 100644 index 000000000..050ef7dde --- /dev/null +++ b/test/downstream/comprehensive_indexing.jl @@ -0,0 +1,530 @@ +using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization, + OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase, + SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, Test +using ModelingToolkit: t_nounits as t, D_nounits as D + +# Sets rnd number. +using StableRNGs +rng = StableRNG(12345) +seed = rand(rng, 1:100) + +### Basic Tests ### + +# Prepares a model systems. +begin + # Prepare system components. + @parameters kp kd k1 k2 + @variables begin + X(t), [bounds = (-10.0, 10.0)] + Y(t), [bounds = (-10.0, 10.0)] + XY(t) + end + alg_eqs = [0 ~ kp - kd * X - k1 * X + k2 * Y + 0 ~ 1 + k1 * X - k2 * Y - Y] + diff_eqs = [D(X) ~ kp - kd * X - k1 * X + k2 * Y + D(Y) ~ 1 + k1 * X - k2 * Y - Y] + noise_eqs = [ + sqrt(kp + X), + sqrt(k1 + Y) + ] + jumps = [ + ConstantRateJump(kp, [X ~ X + 1]), + ConstantRateJump(kd * X, [X ~ X - 1]), + ConstantRateJump(k1 * X, [X ~ X - 1, Y ~ Y + 1]), + ConstantRateJump(k2 * Y, [X ~ X + 1, Y ~ Y - 1]), + ConstantRateJump(1, [Y ~ Y + 1]), + ConstantRateJump(Y, [Y ~ Y - 1]) + ] + observed = [XY ~ X + Y] + loss = kd * (k1 - X)^2 + k2 * (kp * Y - X^2)^2 + + # Create systems (without structural_simplify, since that might modify systems to affect intended tests). + osys = complete(ODESystem(diff_eqs, t; observed, name = :osys)) + ssys = complete(SDESystem( + diff_eqs, noise_eqs, t, [X, Y], [kp, kd, k1, k2]; observed, name = :ssys)) + jsys = complete(JumpSystem(jumps, t, [X, Y], [kp, kd, k1, k2]; observed, name = :jsys)) + nsys = complete(NonlinearSystem(alg_eqs; observed, name = :nsys)) + optsys = complete(OptimizationSystem( + loss, [X, Y], [kp, kd, k1, k2]; observed, name = :optsys)) +end + +# Prepares problems, integrators, and solutions. +begin + # Sets problem inputs (to be used for all problem creations). + u0_vals = [X => 4.0, Y => 5.0] + tspan = (0.0, 10.0) + p_vals = [kp => 1.0, kd => 0.1, k1 => 0.25, k2 => 0.5] + + # Creates problems. + oprob = ODEProblem(osys, u0_vals, tspan, p_vals) + sprob = SDEProblem(ssys, u0_vals, tspan, p_vals) + dprob = DiscreteProblem(jsys, u0_vals, tspan, p_vals) + jprob = JumpProblem(jsys, deepcopy(dprob), Direct(); rng) + nprob = NonlinearProblem(nsys, u0_vals, p_vals) + ssprob = SteadyStateProblem(osys, u0_vals, p_vals) + optprob = OptimizationProblem(optsys, u0_vals, p_vals, grad = true, hess = true) + problems = [oprob, sprob, dprob, jprob, nprob, ssprob, optprob] + systems = [osys, ssys, jsys, jsys, nsys, osys, optsys] + + # Creates an `EnsembleProblem` for each problem. + eoprob = EnsembleProblem(oprob) + esprob = EnsembleProblem(sprob) + edprob = EnsembleProblem(dprob) + ejprob = EnsembleProblem(jprob) + enprob = EnsembleProblem(nprob) + essprob = EnsembleProblem(ssprob) + eoptprob = EnsembleProblem(optprob) + eproblems = [eoprob, esprob, edprob, ejprob, enprob, essprob, optprob] + esystems = [osys, ssys, jsys, jsys, nsys, osys, optsys] + + # Creates integrators. + oint = init(oprob, Tsit5(); save_everystep = false) + sint = init(sprob, ImplicitEM(); save_everystep = false) + jint = init(jprob, SSAStepper()) + nint = init(nprob, NewtonRaphson(); save_everystep = false) + @test_broken ssint = init(ssprob, DynamicSS(Tsit5()); save_everystep = false) # https://github.com/SciML/SteadyStateDiffEq.jl/issues/79 + integrators = [oint, sint, jint, nint] + integsystems = [osys, ssys, jsys, nsys] + + # Creates solutions. + osol = solve(oprob, Tsit5()) + ssol = solve(sprob, ImplicitEM(); seed) + jsol = solve(jprob, SSAStepper(); seed) + nsol = solve(nprob, NewtonRaphson()) + sssol = solve(ssprob, DynamicSS(Tsit5())) + optsol = solve(optprob, GradientDescent()) + sols = [osol, ssol, jsol, nsol, sssol, optsol] +end + +non_timeseries_objects = [problems; eproblems; integrators; [nsol]; [sssol]; [optsol]] +non_timeseries_systems = [systems; esystems; integsystems; nsys; osys; optsys] +timeseries_objects = [osol, ssol, jsol] +timeseries_systems = [osys, ssys, jsys] + +@testset "Non-timeseries indexing $(SciMLBase.parameterless_type(valp))" for (valp, indp) in zip( + deepcopy(non_timeseries_objects), non_timeseries_systems) + if valp isa SciMLBase.NonlinearSolution && valp.prob isa SteadyStateProblem + # Steady state problem indexing is broken, since the system is time-dependent but + # the solution isn't + @test_broken false + continue + end + u = state_values(valp) + uidxs = variable_index.((indp,), [X, Y]) + @testset "State indexing" begin + for (sym, val, newval) in [(X, u[uidxs[1]], 4.0) + (indp.X, u[uidxs[1]], 4.0) + (:X, u[uidxs[1]], 4.0) + (uidxs[1], u[uidxs[1]], 4.0) + ([X, Y], u[uidxs], 4ones(2)) + ([indp.X, indp.Y], u[uidxs], 4ones(2)) + ([:X, :Y], u[uidxs], 4ones(2)) + (uidxs, u[uidxs], 4ones(2)) + ((X, Y), Tuple(u[uidxs]), (4.0, 4.0)) + ((indp.X, indp.Y), Tuple(u[uidxs]), (4.0, 4.0)) + ((:X, :Y), Tuple(u[uidxs]), (4.0, 4.0)) + (Tuple(uidxs), Tuple(u[uidxs]), (4.0, 4.0))] + get = getu(indp, sym) + set! = setu(indp, sym) + @inferred get(valp) + @test get(valp) == val + if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray} + @test_broken valp[sym] + else + @test valp[sym] == val + end + + if !(valp isa SciMLBase.AbstractNoTimeSolution) + @inferred set!(valp, newval) + @test get(valp) == newval + set!(valp, val) + @test get(valp) == val + + if !(valp isa JumpProblem) || !(sym isa Union{Tuple, AbstractArray}) + valp[sym] = newval + @test valp[sym] == newval + valp[sym] = val + @test valp[sym] == val + end + end + end + end + + @testset "Observed" begin + # Observed functions don't infer + for (sym, val) in [(XY, sum(u)) + (indp.XY, sum(u)) + (:XY, sum(u)) + ([X, indp.Y, :XY, X * Y], [u[uidxs]..., sum(u), prod(u)]) + ((X, indp.Y, :XY, X * Y), (u[uidxs]..., sum(u), prod(u))) + (X * Y, prod(u))] + get = getu(indp, sym) + @test get(valp) == val + end + end + + getter = getu(indp, []) + @test getter(valp) == [] + + p = getindex.((Dict(p_vals),), [kp, kd, k1, k2]) + newp = p .* 10 + pidxs = parameter_index.((indp,), [kp, kd, k1, k2]) + @testset "Parameter indexing" begin + for (sym, oldval, newval) in [(kp, p[1], newp[1]) + (indp.kp, p[1], newp[1]) + (:kp, p[1], newp[1]) + (pidxs[1], p[1], newp[1]) + ([kp, kd], p[1:2], newp[1:2]) + ([indp.kp, indp.kd], p[1:2], newp[1:2]) + ([:kp, :kd], p[1:2], newp[1:2]) + (pidxs[1:2], p[1:2], newp[1:2]) + ((kp, kd), Tuple(p[1:2]), Tuple(newp[1:2])) + ((indp.kp, indp.kd), Tuple(p[1:2]), Tuple(newp[1:2])) + ((:kp, :kd), Tuple(p[1:2]), Tuple(newp[1:2])) + (Tuple(pidxs[1:2]), Tuple(p[1:2]), Tuple(newp[1:2]))] + get = getp(indp, sym) + set! = setp(indp, sym) + + @inferred get(valp) + @test get(valp) == valp.ps[sym] + @test get(valp) == oldval + + if !(valp isa SciMLBase.AbstractNoTimeSolution) + @inferred set!(valp, newval) + @test get(valp) == newval + set!(valp, oldval) + @test get(valp) == oldval + + valp.ps[sym] = newval + @test get(valp) == newval + valp.ps[sym] = oldval + @test get(valp) == oldval + end + end + getter = getp(indp, []) + @test getter(valp) == [] + end +end + +@testset "Timeseries indexing $(SciMLBase.parameterless_type(valp))" for (valp, indp) in zip( + timeseries_objects, timeseries_systems) + @info SciMLBase.parameterless_type(valp) typeof(indp) + u = state_values(valp) + uidxs = variable_index.((indp,), [X, Y]) + xvals = getindex.(valp.u, uidxs[1]) + yvals = getindex.(valp.u, uidxs[2]) + xyvals = xvals .+ yvals + tvals = valp.t + @testset "State indexing and observed" begin + for (sym, val, check_inference, check_getindex) in [(X, xvals, true, true) + (indp.X, xvals, true, true) + (:X, xvals, true, true) + (uidxs[1], xvals, true, false) + ([X, Y], vcat.(xvals, yvals), + true, true) + ([indp.X, indp.Y], + vcat.(xvals, yvals), + true, true) + ([:X, :Y], + vcat.(xvals, yvals), + true, true) + (uidxs, vcat.(xvals, yvals), + true, false) + ((Y, X), + tuple.(yvals, xvals), + true, true) + ((indp.Y, indp.X), + tuple.(yvals, xvals), + true, true) + ((:Y, :X), + tuple.(yvals, xvals), + true, true) + (Tuple(reverse(uidxs)), + tuple.(yvals, xvals), + true, false) + (t, tvals, true, true) + (:t, tvals, true, true) + ([X, t], vcat.(xvals, tvals), + false, true) + ((Y, t), + tuple.(yvals, tvals), + true, true) + ([], + [[] + for _ in 1:length(tvals)], + false, + false) + (XY, xyvals, true, true) + (indp.XY, xyvals, true, true) + (:XY, xyvals, true, true) + ([X, indp.Y, :XY, X * Y], + vcat.(xvals, yvals, xyvals, + xvals .* yvals), + false, + true) + ((X, indp.Y, :XY, X * Y), + tuple.( + xvals, yvals, xyvals, + xvals .* yvals), + false, + true) + (X * Y, xvals .* yvals, + false, true)] + get = getu(indp, sym) + if check_inference + @inferred get(valp) + end + @test get(valp) == val + if check_getindex + @test valp[sym] == val + end + # TODO: Test more subindexes when they're supported + for i in [rand(eachindex(val)), CartesianIndex(1)] + if check_inference + @inferred get(valp, i) + end + @test get(valp, i) == val[i] + if check_getindex + @test valp[sym, i] == val[i] + end + end + end + end + + p = getindex.((Dict(p_vals),), [kp, kd, k1, k2]) + pidxs = parameter_index.((indp,), [kp, kd, k1, k2]) + + @testset "Parameter indexing" begin + for (sym, oldval) in [(kp, p[1]) + (indp.kp, p[1]) + (:kp, p[1]) + (pidxs[1], p[1]) + ([kp, kd], p[1:2]) + ([indp.kp, indp.kd], p[1:2]) + ([:kp, :kd], p[1:2]) + (pidxs[1:2], p[1:2]) + ((kp, kd), Tuple(p[1:2])) + ((indp.kp, indp.kd), Tuple(p[1:2])) + ((:kp, :kd), Tuple(p[1:2])) + (Tuple(pidxs[1:2]), Tuple(p[1:2]))] + get = getp(indp, sym) + + @inferred get(valp) + @test get(valp) == valp.ps[sym] + @test get(valp) == oldval + end + getter = getp(indp, []) + @test getter(valp) == [] + end + + @testset "Interpolation" begin + sol = valp + interpolated_sol = sol(0.0:1.0:10.0) + @test interpolated_sol[XY] isa Vector + @test interpolated_sol[XY, :] isa Vector + @test interpolated_sol[XY, 2] isa Float64 + @test length(interpolated_sol[XY, 1:5]) == 5 + @test interpolated_sol[XY] ≈ interpolated_sol[X] .+ interpolated_sol[Y] + @test collect(interpolated_sol[t]) isa Vector + @test collect(interpolated_sol[t, :]) isa Vector + @test interpolated_sol[t, 2] isa Float64 + @test length(interpolated_sol[t, 1:5]) == 5 + + sol3 = sol(0.0:1.0:10.0, idxs = [X, Y]) + @test sol3.u isa Vector + @test first(sol3.u) isa Vector + @test length(sol3.u) == 11 + @test length(sol3.t) == 11 + @test collect(sol3[t]) ≈ sol3.t + @test collect(sol3[t, 1:5]) ≈ sol3.t[1:5] + @test sol(0.0:1.0:10.0, idxs = [Y, 1]) isa RecursiveArrayTools.DiffEqArray + + sol4 = sol(0.1, idxs = [X, Y]) + @test sol4 isa Vector + @test length(sol4) == 2 + @test first(sol4) isa Real + @test sol(0.1, idxs = [Y, 1]) isa Vector{<:Real} + + sol5 = sol(0.0:1.0:10.0, idxs = X) + @test sol5.u isa Vector + @test first(sol5.u) isa Real + @test length(sol5.u) == 11 + @test length(sol5.t) == 11 + @test collect(sol5[t]) ≈ sol3.t + @test collect(sol5[t, 1:5]) ≈ sol3.t[1:5] + @test_throws Any sol(0.0:1.0:10.0, idxs = 1.2) + + sol6 = sol(0.1, idxs = X) + @test sol6 isa Real + @test_throws Any sol(0.1, idxs = 1.2) + end +end + +@testset "ODE with array symbolics" begin + sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0 + ps = @parameters p[1:3] = [1, 2, 3] + eqs = [collect(D.(x) .~ x) + D(y) ~ norm(x) * y - x[1]] + @named sys = ODESystem(eqs, t, [sts...;], ps) + sys = complete(sys) + prob = ODEProblem(sys, [], (0, 1.0)) + sol = solve(prob, Tsit5()) + # interpolation of array variables + @test sol(1.0, idxs = x) == [sol(1.0, idxs = x[i]) for i in 1:3] + + x_idx = variable_index.((sys,), [x[1], x[2], x[3]]) + y_idx = variable_index(sys, y) + x_val = getindex.(sol.u, (x_idx,)) + y_val = getindex.(sol.u, y_idx) + obs_val = getindex.(x_val, 1) .+ y_val + + @testset "Solution indexing" begin + # don't check inference for weird cases of nested arrays/tuples + for (sym, val, check_inference) in [ + (x, x_val, true), + (sys.x, x_val, true), + (:x, x_val, true), + (x_idx, x_val, true), + (x[1] + sys.y, obs_val, true), + ([x[1], x[2]], getindex.(x_val, ([1, 2],)), true), + ([sys.x[1], sys.x[2]], getindex.(x_val, ([1, 2],)), true), + ([x[1], x_idx[2]], getindex.(x_val, ([1, 2],)), false), + ([x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false), + ([sys.x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false), + ([:x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false), + ([x, y], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([sys.x, sys.y], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([:x, :y], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x_idx, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x, x], [[i, i] for i in x_val], true), + ([sys.x, sys.x], [[i, i] for i in x_val], true), + ([:x, :x], [[i, i] for i in x_val], true), + ([x, x_idx], [[i, i] for i in x_val], false), + ((x, y), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((sys.x, sys.y), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((:x, :y), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((x, y_idx), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((x, x), [(i, i) for i in x_val], true), + ((sys.x, sys.x), [(i, i) for i in x_val], true), + ((:x, :x), [(i, i) for i in x_val], true), + ((x, x_idx), [(i, i) for i in x_val], true), + ((x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), + ((sys.x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), + ((:x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), + ((x, (x[1] + y, y)), + [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), + ([x, [x[1] + y, y]], + [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false), + ((x, [x[1] + y, y], (x[1] + y, y_idx)), + [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), + ([x, [x[1] + y, y], (x[1] + y, y_idx)], + [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false) + ] + if check_inference + @inferred getu(prob, sym)(sol) + end + @test getu(prob, sym)(sol) == val + end + end + + x_newval = [3.0, 6.0, 9.0] + y_newval = 4.0 + x_probval = prob[x] + y_probval = prob[y] + + @testset "Problem indexing" begin + for (sym, oldval, newval, check_inference) in [ + (x, x_probval, x_newval, true), + (sys.x, x_probval, x_newval, true), + (:x, x_probval, x_newval, true), + (x_idx, x_probval, x_newval, true), + ((x, y), (x_probval, y_probval), (x_newval, y_newval), true), + ((sys.x, sys.y), (x_probval, y_probval), (x_newval, y_newval), true), + ((:x, :y), (x_probval, y_probval), (x_newval, y_newval), true), + ((x_idx, y_idx), (x_probval, y_probval), (x_newval, y_newval), true), + ([x, y], [x_probval, y_probval], [x_newval, y_newval], false), + ([sys.x, sys.y], [x_probval, y_probval], [x_newval, y_newval], false), + ([:x, :y], [x_probval, y_probval], [x_newval, y_newval], false), + ([x_idx, y_idx], [x_probval, y_probval], [x_newval, y_newval], false), + ((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true), + ([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false), + ((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true), + ([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false), + ([x[1:2], [y_idx, x[3]]], [x_probval[1:2], [y_probval, x_probval[3]]], + [x_newval[1:2], [y_newval, x_newval[3]]], false), + ([x[1:2], (y_idx, x[3])], [x_probval[1:2], (y_probval, x_probval[3])], + [x_newval[1:2], (y_newval, x_newval[3])], false), + ((x[1:2], [y_idx, x[3]]), (x_probval[1:2], [y_probval, x_probval[3]]), + (x_newval[1:2], [y_newval, x_newval[3]]), false), + ((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), + (x_newval[1:2], (y_newval, x_newval[3])), false) + ] + getter = getu(prob, sym) + setter! = setu(prob, sym) + if check_inference + @inferred getter(prob) + end + @test getter(prob) == oldval + if check_inference + @inferred setter!(prob, newval) + else + setter!(prob, newval) + end + @test getter(prob) == newval + setter!(prob, oldval) + @test getter(prob) == oldval + end + end + + @testset "Parameter indexing" begin + pval = [1.0, 2.0, 3.0] + pval_new = [4.0, 5.0, 6.0] + + # don't check inference for nested tuples/arrays + for (sym, oldval, newval, check_inference) in [ + (p[1], pval[1], pval_new[1], true), + (p, pval, pval_new, true), + (sys.p, pval, pval_new, true), + (:p, pval, pval_new, true), + ((p[1], p[2]), Tuple(pval[1:2]), Tuple(pval_new[1:2]), true), + ([p[1], p[2]], pval[1:2], pval_new[1:2], true), + ((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true), + ([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false), + ((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), + (pval_new[1], (pval_new[2],), [pval_new[3]]), false), + ([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], + [pval_new[1], (pval_new[2],), [pval_new[3]]], false) + ] + getter = getp(prob, sym) + setter! = setp(prob, sym) + if check_inference + @inferred getter(prob) + end + @test getter(prob) == oldval + if check_inference + @inferred setter!(prob, newval) + else + setter!(prob, newval) + end + @test getter(prob) == newval + setter!(prob, oldval) + @test getter(prob) == oldval + end + end +end + +# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2697 +@testset "Interpolation of derivative of observed variables" begin + @variables x(t) y(t) z(t) w(t)[1:2] + @named sys = ODESystem( + [D(x) ~ 1, y ~ x^2, z ~ 2y^2 + 3x, w[1] ~ x + y + z, w[2] ~ z * x * y], t) + sys = structural_simplify(sys) + prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0)) + sol = solve(prob, Tsit5()) + @test_throws ErrorException sol(1.0, Val{1}, idxs = y) + @test_throws ErrorException sol(1.0, Val{1}, idxs = [y, z]) + @test_throws ErrorException sol(1.0, Val{1}, idxs = w) + @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w]) + @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y]) +end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 8bc804f98..c99f86504 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -105,3 +105,45 @@ sol = solve(prob, Rodas4()) @test_throws ArgumentError sol[x] @test in(sol[lorenz1.x], [getindex.(sol.u, 1) for i in 1:length(unknowns(sol.prob.f.sys))]) @test_throws ArgumentError sol[:x] + +### Non-symbolic indexing tests +@test sol[:, 1] isa AbstractVector +@test sol[:, 1:2] isa AbstractDiffEqArray +@test sol[:, [1, 2]] isa AbstractDiffEqArray + +sol1 = sol(0.0:1.0:10.0) +@test sol1.u isa Vector +@test first(sol1.u) isa Vector +@test length(sol1.u) == 11 +@test length(sol1.t) == 11 + +sol2 = sol(0.1) +@test sol2 isa Vector +@test length(sol2) == length(unknowns(sys)) +@test first(sol2) isa Real + +sol3 = sol(0.0:1.0:10.0, idxs = [lorenz1.x, lorenz2.x]) + +sol7 = sol(0.0:1.0:10.0, idxs = [2, 1]) +@test sol7.u isa Vector +@test first(sol7.u) isa Vector +@test length(sol7.u) == 11 +@test length(sol7.t) == 11 +@test collect(sol7[t]) ≈ sol3.t +@test collect(sol7[t, 1:5]) ≈ sol3.t[1:5] + +sol8 = sol(0.1, idxs = [2, 1]) +@test sol8 isa Vector +@test length(sol8) == 2 +@test first(sol8) isa Real + +sol9 = sol(0.0:1.0:10.0, idxs = 2) +@test sol9.u isa Vector +@test first(sol9.u) isa Real +@test length(sol9.u) == 11 +@test length(sol9.t) == 11 +@test collect(sol9[t]) ≈ sol3.t +@test collect(sol9[t, 1:5]) ≈ sol3.t[1:5] + +sol10 = sol(0.1, idxs = 2) +@test sol10 isa Real diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl deleted file mode 100644 index 4952d729c..000000000 --- a/test/downstream/symbol_indexing.jl +++ /dev/null @@ -1,389 +0,0 @@ -using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, - Test -using Optimization, OptimizationOptimJL -using ModelingToolkit: t_nounits as t, D_nounits as D - -@parameters σ ρ β -@variables x(t) y(t) z(t) - -eqs = [D(x) ~ σ * (y - x), - D(y) ~ x * (ρ - z) - y, - D(z) ~ x * y - β * z] - -@named lorenz1 = ODESystem(eqs, t) -@named lorenz2 = ODESystem(eqs, t) - -@parameters γ -@variables a(t) α(t) -connections = [0 ~ lorenz1.x + lorenz2.y + a * γ, - α ~ 2lorenz1.x + a * γ] -@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2]) - -u0 = [lorenz1.x => 1.0, - lorenz1.y => 0.0, - lorenz1.z => 0.0, - lorenz2.x => 0.0, - lorenz2.y => 1.0, - lorenz2.z => 0.0, - a => 2.0] - -p = [lorenz1.σ => 10.0, - lorenz1.ρ => 28.0, - lorenz1.β => 8 / 3, - lorenz2.σ => 10.0, - lorenz2.ρ => 28.0, - lorenz2.β => 8 / 3, - γ => 2.0] - -tspan = (0.0, 100.0) -prob = ODEProblem(sys, u0, tspan, p) -integ = init(prob, Rodas4()) -sol = solve(prob, Rodas4()) - -@testset "indexing should error" begin - for obj in [prob, integ, sol] - for sym in ['a', :b] - @test_throws Any obj[sym] - @test_throws Any obj[sym, 1] - @test_throws Any obj[sym, 1:5] - @test_throws Any obj[sym, [1, 2, 3]] - end - end -end - -@testset "observed shouldn't error" begin - for obj in [prob, integ, sol] - obj[:a] - SymbolicIndexingInterface.observed(obj, :a) - end -end - -@test sol[a] isa AbstractVector -@test sol[:a] == sol[a] -@test sol[a, 1] isa Real -@test sol[:a, 1] == sol[a, 1] == prob[a] == prob[:a] == integ[a] == integ[:a] == -1.0 -@test sol[a, 1:5] isa AbstractVector -@test sol[:a, 1:5] == sol[a, 1:5] -@test sol[a, [1, 2, 3]] isa AbstractVector -@test sol[:a, [1, 2, 3]] == sol[a, [1, 2, 3]] - -@test sol[:, 1] isa AbstractVector -@test sol[:, 1:2] isa AbstractDiffEqArray -@test sol[:, [1, 2]] isa AbstractDiffEqArray - -@test sol[lorenz1.x] isa Vector -@test sol[lorenz1.x, 2] isa Float64 -@test sol[lorenz1.x, :] isa Vector -@test sol[t] isa Vector -@test sol[t, 2] isa Float64 -@test sol[t, :] isa Vector -@test length(sol[lorenz1.x, 1:5]) == 5 -@test sol[α] isa Vector -@test sol[α, 3] isa Float64 -@test length(sol[α, 5:10]) == 6 -@test getp(prob, γ)(sol) isa Real -@test sol.ps[γ] isa Real -@test getp(prob, γ)(sol) == getp(prob, :γ)(sol) == sol.ps[γ] == sol.ps[:γ] == 2.0 -@test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple -@test sol.ps[(lorenz1.σ, lorenz1.ρ)] isa Tuple - -@test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}} -@test length(sol[[lorenz1.x, lorenz2.x]]) == length(sol) -@test all(length.(sol[[lorenz1.x, lorenz2.x]]) .== 2) -@test sol[(lorenz1.x, lorenz2.x)] isa Vector{Tuple{Float64, Float64}} -@test length(sol[(lorenz1.x, lorenz2.x)]) == length(sol) -@test all(length.(sol[(lorenz1.x, lorenz2.x)]) .== 2) - -@test sol[[lorenz1.x, lorenz2.x], :] isa Vector{Vector{Float64}} -@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol) -@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2 - -@variables q(t)[1:2] = [1.0, 2.0] -eqs = [D(q[1]) ~ 2q[1] - D(q[2]) ~ 2.0] -@named sys2 = ODESystem(eqs, t, [q...], []) -sys2_simplified = structural_simplify(sys2) -prob2 = ODEProblem(sys2_simplified, [], (0.0, 5.0)) -sol2 = solve(prob2, Tsit5()) - -@test sol2[q] isa Vector{Vector{Float64}} -@test sol2[(q...,)] isa Vector{NTuple{length(q), Float64}} -@test length(sol2[q]) == length(sol2) -@test all(length.(sol2[q]) .== 2) -@test sol2[collect(q)] == sol2[q] - -# Check if indexing using variable names from interpolated solution works -interpolated_sol = sol(0.0:1.0:10.0) -@test interpolated_sol[α] isa Vector -@test interpolated_sol[α, :] isa Vector -@test interpolated_sol[α, 2] isa Float64 -@test length(interpolated_sol[α, 1:5]) == 5 -@test interpolated_sol[α] ≈ 2interpolated_sol[lorenz1.x] .+ interpolated_sol[a] .* 2.0 -@test collect(interpolated_sol[t]) isa Vector -@test collect(interpolated_sol[t, :]) isa Vector -@test interpolated_sol[t, 2] isa Float64 -@test length(interpolated_sol[t, 1:5]) == 5 - -sol1 = sol(0.0:1.0:10.0) -@test sol1.u isa Vector -@test first(sol1.u) isa Vector -@test length(sol1.u) == 11 -@test length(sol1.t) == 11 - -sol2 = sol(0.1) -@test sol2 isa Vector -@test length(sol2) == length(unknowns(sys)) -@test first(sol2) isa Real - -sol3 = sol(0.0:1.0:10.0, idxs = [lorenz1.x, lorenz2.x]) -@test sol3.u isa Vector -@test first(sol3.u) isa Vector -@test length(sol3.u) == 11 -@test length(sol3.t) == 11 -@test collect(sol3[t]) ≈ sol3.t -@test collect(sol3[t, 1:5]) ≈ sol3.t[1:5] -@test sol(0.0:1.0:10.0, idxs = [lorenz1.x, 1]) isa RecursiveArrayTools.DiffEqArray - -sol4 = sol(0.1, idxs = [lorenz1.x, lorenz2.x]) -@test sol4 isa Vector -@test length(sol4) == 2 -@test first(sol4) isa Real -@test sol(0.1, idxs = [lorenz1.x, 1]) isa Vector{<:Real} - -sol5 = sol(0.0:1.0:10.0, idxs = lorenz1.x) -@test sol5.u isa Vector -@test first(sol5.u) isa Real -@test length(sol5.u) == 11 -@test length(sol5.t) == 11 -@test collect(sol5[t]) ≈ sol3.t -@test collect(sol5[t, 1:5]) ≈ sol3.t[1:5] -@test_throws Any sol(0.0:1.0:10.0, idxs = 1.2) - -sol6 = sol(0.1, idxs = lorenz1.x) -@test sol6 isa Real -@test_throws Any sol(0.1, idxs = 1.2) - -sol7 = sol(0.0:1.0:10.0, idxs = [2, 1]) -@test sol7.u isa Vector -@test first(sol7.u) isa Vector -@test length(sol7.u) == 11 -@test length(sol7.t) == 11 -@test collect(sol7[t]) ≈ sol3.t -@test collect(sol7[t, 1:5]) ≈ sol3.t[1:5] - -sol8 = sol(0.1, idxs = [2, 1]) -@test sol8 isa Vector -@test length(sol8) == 2 -@test first(sol8) isa Real - -sol9 = sol(0.0:1.0:10.0, idxs = 2) -@test sol9.u isa Vector -@test first(sol9.u) isa Real -@test length(sol9.u) == 11 -@test length(sol9.t) == 11 -@test collect(sol9[t]) ≈ sol3.t -@test collect(sol9[t, 1:5]) ≈ sol3.t[1:5] - -sol10 = sol(0.1, idxs = 2) -@test sol10 isa Real - -@test is_timeseries(sol) == Timeseries() -getx = getu(sys, lorenz1.x) -get_arr = getu(sys, [lorenz1.x, lorenz2.x]) -get_tuple = getu(sys, (lorenz1.x, lorenz2.x)) -get_obs = getu(sol, lorenz1.x + lorenz2.x) # can't use sys for observed -get_obs_arr = getu(sol, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y]) -l1x_idx = variable_index(sol, lorenz1.x) -l2x_idx = variable_index(sol, lorenz2.x) -l1y_idx = variable_index(sol, lorenz1.y) -l2y_idx = variable_index(sol, lorenz2.y) - -@test getx(sol) == sol[l1x_idx, :] -@test get_arr(sol) == vcat.(sol[l1x_idx, :], sol[l2x_idx, :]) -@test get_tuple(sol) == tuple.(sol[l1x_idx, :], sol[l2x_idx, :]) -@test get_obs(sol) == sol[l1x_idx, :] + sol[l2x_idx, :] -@test get_obs_arr(sol) == - vcat.(sol[l1x_idx, :] + sol[l2x_idx, :], sol[l1y_idx, :] + sol[l2y_idx, :]) - -#= -using Plots -plot(sol,idxs=(lorenz2.x,lorenz2.z)) -plot(sol,idxs=(α,lorenz2.z)) -plot(sol,idxs=(lorenz2.x,α)) -plot(sol,idxs=α) -plot(sol,idxs=(t,α)) -=# - -using LinearAlgebra -sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0 -ps = @parameters p[1:3] = [1, 2, 3] -eqs = [collect(D.(x) .~ x) - D(y) ~ norm(x) * y - x[1]] -@named sys = ODESystem(eqs, t, [sts...;], [ps...;]) -sys = complete(sys) -prob = ODEProblem(sys, [], (0, 1.0)) -sol = solve(prob, Tsit5()) -@test sol[x] isa Vector{<:Vector} -@test sol[@nonamespace sys.x] isa Vector{<:Vector} -@test sol.ps[p] == [1, 2, 3] -# interpolation of array variables -@test sol(1.0, idxs = x) == [sol(1.0, idxs = x[i]) for i in 1:3] - -x_idx = variable_index.((sys,), [x[1], x[2], x[3]]) -y_idx = variable_index(sys, y) -x_val = vcat.(getindex.((sol,), x_idx, :)...) -y_val = sol[y_idx, :] -obs_val = sol[x[1] + y] - -# don't check inference for weird cases of nested arrays/tuples -for (sym, val, check_inference) in [ - (x, x_val, true), - (y, y_val, true), - (y_idx, y_val, true), - (x_idx, x_val, true), - (x[1] + y, obs_val, true), - ([x[1], x[2]], sol[[x[1], x[2]]], true), - ([x[1], x_idx[2]], sol[[x[1], x[2]]], false), - ([x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false), - ([x, y], [[i, j] for (i, j) in zip(x_val, y_val)], false), - ([x, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false), - ([x, x], [[i, i] for i in x_val], true), - ([x, x_idx], [[i, i] for i in x_val], false), - ((x, y), [(i, j) for (i, j) in zip(x_val, y_val)], true), - ((x, y_idx), [(i, j) for (i, j) in zip(x_val, y_val)], true), - ((x, x), [(i, i) for i in x_val], true), - ((x, x_idx), [(i, i) for i in x_val], true), - ((x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), - ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), - ([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false), - ((x, [x[1] + y, y], (x[1] + y, y_idx)), - [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), - ([x, [x[1] + y, y], (x[1] + y, y_idx)], - [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false) -] - if check_inference - @inferred getu(prob, sym)(sol) - end - @test getu(prob, sym)(sol) == val -end - -x_newval = [3.0, 6.0, 9.0] -y_newval = 4.0 -x_probval = prob[x] -y_probval = prob[y] - -for (sym, oldval, newval, check_inference) in [ - (x, x_probval, x_newval, true), - (y, y_probval, y_newval, true), - (x_idx, x_probval, x_newval, true), - (y_idx, y_probval, y_newval, true), - ((x, y), (x_probval, y_probval), (x_newval, y_newval), true), - ([x, y], [x_probval, y_probval], [x_newval, y_newval], false), - ((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true), - ([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false), - ((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true), - ([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false), - ([x[1:2], [y_idx, x[3]]], [x_probval[1:2], [y_probval, x_probval[3]]], - [x_newval[1:2], [y_newval, x_newval[3]]], true), - ([x[1:2], (y_idx, x[3])], [x_probval[1:2], (y_probval, x_probval[3])], - [x_newval[1:2], (y_newval, x_newval[3])], false), - ((x[1:2], [y_idx, x[3]]), (x_probval[1:2], [y_probval, x_probval[3]]), - (x_newval[1:2], [y_newval, x_newval[3]]), true), - ((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), - (x_newval[1:2], (y_newval, x_newval[3])), true) -] - getter = getu(prob, sym) - setter! = setu(prob, sym) - if check_inference - @inferred getter(prob) - end - @test getter(prob) == oldval - if check_inference - @inferred setter!(prob, newval) - else - setter!(prob, newval) - end - @test getter(prob) == newval - setter!(prob, oldval) - @test getter(prob) == oldval -end - -pval = [1.0, 2.0, 3.0] -pval_new = [4.0, 5.0, 6.0] - -# don't check inference for nested tuples/arrays -for (sym, oldval, newval, check_inference) in [ - (p[1], pval[1], pval_new[1], true), - (p, pval, pval_new, true), - ((p[1], p[2]), Tuple(pval[1:2]), Tuple(pval_new[1:2]), true), - ([p[1], p[2]], pval[1:2], pval_new[1:2], true), - ((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true), - ([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false), - ((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), - (pval_new[1], (pval_new[2],), [pval_new[3]]), false), - ([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], - [pval_new[1], (pval_new[2],), [pval_new[3]]], false) -] - getter = getp(prob, sym) - setter! = setp(prob, sym) - if check_inference - @inferred getter(prob) - end - @test getter(prob) == oldval - if check_inference - @inferred setter!(prob, newval) - else - setter!(prob, newval) - end - @test getter(prob) == newval - setter!(prob, oldval) - @test getter(prob) == oldval -end - -# accessing parameters -@variables x(t) -@parameters tau - -@named fol = ODESystem([D(x) ~ (1 - x) / tau], t) -fol = complete(fol) -prob = ODEProblem(fol, [x => 0.0], (0.0, 10.0), [tau => 3.0]) -sol = solve(prob, Tsit5()) -@test getp(fol, tau)(sol) == 3 - -@testset "OptimizationSolution" begin - @variables begin - x, [bounds = (-2.0, 2.0)] - y, [bounds = (-1.0, 3.0)] - end - @parameters a=1 b=1 - loss = (a - x)^2 + b * (y - x^2)^2 - @named sys = OptimizationSystem(loss, [x, y], [a, b]) - sys = complete(sys) - u0 = [x => 1.0 - y => 2.0] - p = [a => 1.0 - b => 100.0] - prob = OptimizationProblem(sys, u0, p, grad = true, hess = true) - sol = solve(prob, GradientDescent()) - @test sol[x]≈1 atol=1e-3 - @test sol[y]≈1 atol=1e-3 - @test getp(sys, a)(sol) ≈ 1 - @test getp(sys, b)(sol) ≈ 100 - @test sol.ps[a] ≈ 1 - @test sol.ps[b] ≈ 100 -end - -# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2697 -@testset "Interpolation of derivative of observed variables" begin - @variables x(t) y(t) z(t) w(t)[1:2] - @named sys = ODESystem( - [D(x) ~ 1, y ~ x^2, z ~ 2y^2 + 3x, w[1] ~ x + y + z, w[2] ~ z * x * y], t) - sys = structural_simplify(sys) - prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0)) - sol = solve(prob, Tsit5()) - @test_throws ErrorException sol(1.0, Val{1}, idxs = y) - @test_throws ErrorException sol(1.0, Val{1}, idxs = [y, z]) - @test_throws ErrorException sol(1.0, Val{1}, idxs = w) - @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w]) - @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y]) -end diff --git a/test/runtests.jl b/test/runtests.jl index 0172f4a66..d1dddaea8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -119,7 +119,7 @@ end include("downstream/modelingtoolkit_remake.jl") end @time @safetestset "Symbol and integer based indexing of interpolated solutions" begin - include("downstream/symbol_indexing.jl") + include("downstream/comprehensive_indexing.jl") end if VERSION >= v"1.8" @time @safetestset "Symbol and integer based indexing of integrators" begin From de74ba08d843a3f898ce4d8fa8e58513582f9237 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 24 Jun 2024 13:12:46 +0530 Subject: [PATCH 4/5] ci: set `--depwarn=yes` for tests --- .github/workflows/Downstream.yml | 2 +- .github/workflows/Tests.yml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 4bffe0379..cdb56c4e9 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -62,7 +62,7 @@ jobs: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream - name: Load this and run the downstream tests - shell: julia --color=yes --project=downstream {0} + shell: julia --color=yes --project=downstream --depwarn=yes {0} run: | using Pkg try diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 3a0f1012a..142a57726 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -34,6 +34,7 @@ jobs: - "Python" uses: "SciML/.github/.github/workflows/tests.yml@v1" with: + julia-runtest-depwarn: "yes" group: "${{ matrix.group }}" julia-version: "${{ matrix.version }}" secrets: "inherit" From 10f71ab42535c15b8ffeda8a9cff3b394da9a9f1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 26 Jun 2024 12:00:24 +0530 Subject: [PATCH 5/5] test: fix solution interface test --- test/downstream/solution_interface.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index c99f86504..3e9f44bad 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -79,8 +79,7 @@ eqs = [D(x) ~ σ * (y - x), @variables a(t) α(t) connections = [0 ~ lorenz1.x + lorenz2.y + a * γ, α ~ 2lorenz1.x + a * γ] -@named sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2]) -sys_simplified = complete(structural_simplify(sys)) +@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2]) u0 = [lorenz1.x => 1.0, lorenz1.y => 0.0, @@ -99,7 +98,7 @@ p = [lorenz1.σ => 10.0, γ => 2.0] tspan = (0.0, 100.0) -prob = ODEProblem(sys_simplified, u0, tspan, p) +prob = ODEProblem(sys, u0, tspan, p) sol = solve(prob, Rodas4()) @test_throws ArgumentError sol[x]