Skip to content

Commit

Permalink
Merge pull request #721 from AayushSabharwal/as/getu-everywhere
Browse files Browse the repository at this point in the history
refactor: use `getu`/`setu` for all indexing
  • Loading branch information
ChrisRackauckas authored Jun 26, 2024
2 parents e3a0de8 + 10f71ab commit 2a3e88d
Show file tree
Hide file tree
Showing 12 changed files with 651 additions and 505 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --color=yes --project=downstream {0}
shell: julia --color=yes --project=downstream --depwarn=yes {0}
run: |
using Pkg
try
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
- "Python"
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
julia-runtest-depwarn: "yes"
group: "${{ matrix.group }}"
julia-version: "${{ matrix.version }}"
secrets: "inherit"
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
Expand All @@ -106,11 +107,12 @@ PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
75 changes: 21 additions & 54 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,46 +465,20 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
end
end

Base.@propagate_inbounds function _getindex(A::DEIntegrator,
::NotSymbolic,
I::Union{Int, AbstractArray{Int},
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
A.u[I...]
end

Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)
else
error("Tried to index integrator with a Symbol that was not found in the system.")
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]
end

Base.@propagate_inbounds function _getindex(
A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
return getindex.((A,), sym)
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
return _getindex(A, symtype, sym)
else
return _getindex(A, elsymtype, sym)
Base.@propagate_inbounds function Base.getindex(
A::DEIntegrator, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -522,25 +496,18 @@ function observed(A::DEIntegrator, sym)
end

function Base.setindex!(A::DEIntegrator, val, sym)
has_sys(A.f) ||
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
set_state!(A, val, variable_index(A, sym))
elseif is_parameter(A, sym)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
return A
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((A,), val, collect(sym))
return A
else
sym isa AbstractArray || error("Invalid indexing of integrator")
setindex!.((A,), val, sym)
return A
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
end
setu(A, sym)(A, val)
end

function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
end
setu(A, sym)(A, val)
end

### Integrator traits
Expand Down
67 changes: 27 additions & 40 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,51 +38,38 @@ Base.@propagate_inbounds function Base.getindex(
return getindex(prob, all_variable_symbols(prob))
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob, sym)
return state_values(prob, variable_index(prob, sym))
elseif is_parameter(prob, sym)
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")
elseif is_independent_variable(prob, sym)
return current_time(prob)
elseif is_observed(prob, sym)
obs = SymbolicIndexingInterface.observed(prob, sym)
if is_time_dependent(prob)
return obs(state_values(prob), parameter_values(prob), current_time(prob))
else
return obs(state_values(prob), parameter_values(prob))
end
else
error("Invalid indexing of problem: $sym is not a state, parameter, or independent variable")
end
elseif symbolic_type(sym) == ArraySymbolic()
return map(s -> prob[s], collect(sym))
else
sym isa AbstractArray || error("Invalid indexing of problem")
return map(s -> prob[s], sym)
Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
A::AbstractSciMLProblem, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
___internal_setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
end
function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob, sym)
set_state!(prob, val, variable_index(prob, sym))
elseif is_parameter(prob, sym)
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.")
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
return prob
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((prob,), val, collect(sym))
return prob
else
sym isa AbstractArray || error("Invalid indexing of problem")
setindex!.((prob,), val, sym)
return prob

function ___internal_setindex!(A::AbstractSciMLProblem, val, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
end
return setu(A, sym)(A, val)
end

function ___internal_setindex!(
A::AbstractSciMLProblem, val, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
end
return setu(A, sym)(A, val)
end
8 changes: 7 additions & 1 deletion src/solutions/optimization_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ function reinit!(cache::SciMLBase.AbstractOptimizationCache; p = missing,
return cache
end

SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache) = x.p
function SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache)
if has_reinit(x)
x.reinit_cache.p
else
x.p
end
end
SymbolicIndexingInterface.symbolic_container(x::AbstractOptimizationCache) = x.f

get_p(sol::OptimizationSolution) = sol.cache.p
Expand Down
27 changes: 12 additions & 15 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,19 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, :
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, parameter_values(A))
else
error("Tried to index solution with a Symbol that was not found in the system.")
end
elseif symbolic_type(sym) == ArraySymbolic()
return A[collect(sym)]
else
sym isa AbstractArray || error("Invalid indexing of solution")
return getindex.((A,), sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
A::AbstractNoTimeSolution, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand Down
4 changes: 4 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Expand All @@ -34,6 +37,7 @@ RecursiveArrayTools = "3"
SciMLBase = "2"
SciMLSensitivity = "7.11"
SciMLStructures = "1.1"
SteadyStateDiffEq = "2.2"
Sundials = "4.11"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "<1.6, 2"
Expand Down
Loading

0 comments on commit 2a3e88d

Please sign in to comment.