Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: error when interpolating derivatives of observed variables #688

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@
dense, tslocation, stats, alg_choice, retcode, resid, original)
end

error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
function error_if_observed_derivative(sys, idx, ::Type)
if symbolic_type(idx) != NotSymbolic() && is_observed(sys, idx) ||
symbolic_type(idx) == NotSymbolic() && any(x -> is_observed(sys, x), idx)
error("""

Check warning on line 153 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L149-L153

Added lines #L149 - L153 were not covered by tests
Cannot interpolate derivatives of observed variables. A possible solution could be
interpolating the symbolic expression that evaluates to the derivative of the
observed variable or using DataInterpolations.jl.
""")
end
end

function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
continuity = :left) where {deriv}
sol(t, deriv, idxs, continuity)
Expand Down Expand Up @@ -197,6 +209,7 @@
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 212 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L212

Added line #L212 was not covered by tests
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
Expand All @@ -206,15 +219,19 @@

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
if symbolic_type(idxs) == NotSymbolic() &&

Check warning on line 222 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L222

Added line #L222 was not covered by tests
any(isequal(NotSymbolic()), symbolic_type.(idxs))
error("Incorrect specification of `idxs`")
end
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 226 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L226

Added line #L226 was not covered by tests
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
first(interp_sol[idxs])
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 234 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L234

Added line #L234 was not covered by tests
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
Expand All @@ -228,6 +245,7 @@
idxs::AbstractVector, continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 248 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L248

Added line #L248 was not covered by tests
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
indexed_sol = interp_sol[idxs]
Expand Down
17 changes: 17 additions & 0 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ sol = solve(prob, Tsit5())
@test sol[x] isa Vector{<:Vector}
@test sol[@nonamespace sys.x] isa Vector{<:Vector}
@test sol.ps[p] == [1, 2, 3]
# interpolation of array variables
@test sol(1.0, idxs = x) == [sol(1.0, idxs = x[i]) for i in 1:3]

x_idx = variable_index.((sys,), [x[1], x[2], x[3]])
y_idx = variable_index(sys, y)
Expand Down Expand Up @@ -369,3 +371,18 @@ sol = solve(prob, Tsit5())
@test sol.ps[a] ≈ 1
@test sol.ps[b] ≈ 100
end

# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2697
@testset "Interpolation of derivative of observed variables" begin
@variables x(t) y(t) z(t) w(t)[1:2]
@named sys = ODESystem(
[D(x) ~ 1, y ~ x^2, z ~ 2y^2 + 3x, w[1] ~ x + y + z, w[2] ~ z * x * y], t)
sys = structural_simplify(sys)
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0))
sol = solve(prob, Tsit5())
@test_throws ErrorException sol(1.0, Val{1}, idxs = y)
@test_throws ErrorException sol(1.0, Val{1}, idxs = [y, z])
@test_throws ErrorException sol(1.0, Val{1}, idxs = w)
@test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w])
@test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y])
end
Loading