Skip to content

Commit

Permalink
add solution interfacing
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jul 9, 2024
1 parent c70e716 commit 37a14d0
Show file tree
Hide file tree
Showing 7 changed files with 402 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/Catalyst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,16 @@ include("spatial_reaction_systems/spatial_reactions.jl")
export TransportReaction, TransportReactions, @transport_reaction
export isedgeparameter

# Lattice reaction systems
# Lattice reaction systems.
include("spatial_reaction_systems/lattice_reaction_systems.jl")
export LatticeReactionSystem
export spatial_species, vertex_parameters, edge_parameters
export CartesianGrid, CartesianGridReJ # (Implemented in JumpProcesses)
export has_cartesian_lattice, has_masked_lattice, has_grid_lattice, has_graph_lattice,
grid_dims, grid_size
export make_edge_p_values, make_directed_edge_values
include("spatial_reaction_systems/lattice_solution_interfacing.jl")
export get_lrs_vals

# Specific spatial problem types.
include("spatial_reaction_systems/spatial_ODE_systems.jl")
Expand Down
14 changes: 13 additions & 1 deletion src/spatial_reaction_systems/lattice_jump_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ end

### Extra ###

# Temporary. Awaiting implementation in SII, or proper implementation withinCatalyst (with more general functionality).
# Temporary. Awaiting implementation in SII, or proper implementation within Catalyst (with
# more general functionality).
function int_map(map_in, sys)
return [ModelingToolkit.variable_index(sys, pair[1]) => pair[2] for pair in map_in]
end
Expand All @@ -141,3 +142,14 @@ end
# majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs = eqs, rateconsttype = invttype)
# return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
# end


### Problem & Integrator Rebuilding ###

# Currently not implemented.
function rebuild_lat_internals!(dprob::DiscreteProblem)
error("Modification and/or rebuilding of `DiscreteProblem`s is currently not supported. Please create a new problem instead.")
end
function rebuild_lat_internals!(jprob::JumpProblem)
error("Modification and/or rebuilding of `JumpProblem`s is currently not supported. Please create a new problem instead.")
end
121 changes: 121 additions & 0 deletions src/spatial_reaction_systems/lattice_solution_interfacing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
### Rudimentary Interfacing Function ###
# A single function, `get_lrs_vals`, which contain all interfacing functionality. However,
# long-term it should be replaced with a sleeker interface. Ideally as MTK-wider support for
# lattice problems and solutions are introduced.

"""
get_lrs_vals(sol, sp, lrs::LatticeReactionSystem; t = nothing)
A function for retrieving the solution of a `LatticeReactionSystem`-based simulation on various
desired forms. Generally, for `LatticeReactionSystem`s, the values in `sol` is ordered in a
way which is not directly interpretable by the user. Furthermore, the normal Catalyst interface
for solutions (e.g. `sol[:X]`) does not work for these solutions. Hence this function is used instead.
The output is a vector, which in each position contain sp's value (either at a time step of time,
depending on the input `t`). Its shape depends on the lattice (using a similar form as heterogeneous
initial conditions). I.e. for a NxM cartesian grid, the values are NxM matrices. For a masked grid,
the values are sparse matrices. For a graph lattice, the values are vectors (where the value in
the n'th position corresponds to sp's value in the n'th vertex).
Arguments:
- `sol`: The solution from which we wish to retrieve some values.
- `sp`: The species which values we wish to retrieve. Can be either a symbol (e.g. `:X`) or a symbolic
variable (e.g. `X`).
- `lrs`: The `LatticeReactionSystem` which was simulated to generate the solution.
- `t = nothing`: If `nothing`, we simply returns the solution across all saved timesteps. If `t`
instead is a vector (or range of values), returns the solutions interpolated at these timepoints.
Notes:
- The `get_lrs_vals` is not optimised for performance. However, it should still be quite performant,
but there might be some limitations if called a very large number of times.
- Long-term it is likely that this function gets replaced with a sleeker interface.
Example:
```julia
using Catalyst, OrdinaryDiffEq
# Prepare `LatticeReactionSystem`s.
rs = @reaction_network begin
(k1,k2), X1 <--> X2
end
tr = @transport_reaction D X1
lrs = LatticeReactionSystem(rs, [tr], CartesianGrid((2,2)))
# Create problems.
u0 = [:X1 => 1, :X2 => 2]
tspan = (0.0, 10.0)
ps = [:k1 => 1, :k2 => 2.0, :D => 0.1]
oprob = ODEProblem(lrs1, u0, tspan, ps)
osol = solve(oprob1, Tsit5())
get_lrs_vals(osol, :X1, lrs) # Returns the value of X1 at each timestep.
get_lrs_vals(osol, :X1, lrs; t = 0.0:10.0) # Returns the value of X1 at times 0.0, 1.0, ..., 10.0
```
"""
function get_lrs_vals(sol, sp, lrs::LatticeReactionSystem; t = nothing)
# Figures out which species we wish to fetch information about.
(sp isa Symbol) && (sp = Catalyst._symbol_to_var(lrs, sp))
sp_idx = findfirst(isequal(sp), species(lrs))
sp_tot = length(species(lrs))

# Extracts the lattice and calls the next function. Masked grids (Array of Bools) are converted
# to sparse array using the same template size as we wish to shape the data to.
lattice = Catalyst.lattice(lrs)
if has_masked_lattice(lrs)
if grid_dims(lrs) == 3
error("The `get_lrs_vals` function is not defined for systems based on 3d sparse arrays. Please raise an issue at the Catalyst GitHub site if this is something which would be useful to you.")
end
lattice = sparse(lattice)
end
get_lrs_vals(sol, lattice, t, sp_idx, sp_tot)
end

# Function which handles the input in the case where `t` is `nothing` (i.e. return `sp`s value
# across all sample points).
function get_lrs_vals(sol, lattice, t::Nothing, sp_idx, sp_tot)
# ODE simulations contain, in each data point, all values in a single vector. Jump simulations
# instead in a matrix (NxM, where N is the number of species and M the number of vertices). We
# must consider each case separately.
if sol.prob isa ODEProblem
return [reshape_vals(vals[sp_idx:sp_tot:end], lattice) for vals in sol.u]
elseif sol.prob isa DiscreteProblem
return [reshape_vals(vals[sp_idx,:], lattice) for vals in sol.u]
else
error("Unknown type of solution provided to `get_lrs_vals`. Only ODE or Jump solutions are supported.")
end
end

# Function which handles the input in the case where `t` is a range of values (i.e. return `sp`s
# value at all designated time points.
function get_lrs_vals(sol, lattice, t::AbstractVector{T}, sp_idx, sp_tot) where {T <: Number}
if (minimum(t) < sol.t[1]) || (maximum(t) > sol.t[end])
error("The range of the t values provided for sampling, ($(minimum(t)),$(maximum(t))) is not fully within the range of the simulation time span ($(sol.t[1]),$(sol.t[end])).")
end

# ODE simulations contain, in each data point, all values in a single vector. Jump simulations
# instead in a matrix (NxM, where N is the number of species and M the number of vertices). We
# must consider each case separately.
if sol.prob isa ODEProblem
return [reshape_vals(sol(ti)[sp_idx:sp_tot:end], lattice) for ti in t]
elseif sol.prob isa DiscreteProblem
return [reshape_vals(sol(ti)[sp_idx,:], lattice) for ti in t]
else
error("Unknown type of solution provided to `get_lrs_vals`. Only ODE or Jump solutions are supported.")
end
end

# Functions which in each sample point reshapes the vector of values to the correct form (depending
# on the type of lattice used).
function reshape_vals(vals, lattice::CartesianGridRej{N, T}) where {N,T}
return reshape(vals, lattice.dims...)
end
function reshape_vals(vals, lattice::AbstractSparseArray{Bool, Int64, 1})
return SparseVector(lattice.n, lattice.nzind, vals)
end
function reshape_vals(vals, lattice::AbstractSparseArray{Bool, Int64, 2})
return SparseMatrixCSC(lattice.m, lattice.n, lattice.colptr, lattice.rowval, vals)
end
function reshape_vals(vals, lattice::DiGraph)
return vals
end

51 changes: 46 additions & 5 deletions src/spatial_reaction_systems/spatial_ODE_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R, V
jac_transport, transport_rates)
J = (jac ? f : nothing)

# Extracts the `Symbol` form for species and parameters. Creates and returns the `ODEFunction`.
syms = MT.getname.(species(lrs))
paramsyms = MT.getname.(parameters(lrs))
return ODEFunction(f; jac = J, jac_prototype, syms, paramsyms)
# Extracts the `Symbol` form for parameters (but not species). Creates and returns the `ODEFunction`.
paramsyms = [MT.getname(p) for p in parameters(lrs)]
sys = SciMLBase.SymbolCache([], paramsyms, [])
return ODEFunction(f; jac = J, jac_prototype, sys)
end

# Builds a jacobian prototype.
Expand Down Expand Up @@ -325,7 +325,48 @@ end

### Functor Updating Functionality ###

# Function for rebuilding a `LatticeReactionSystem` `ODEProblem` after it has been updated.
"""
rebuild_lat_internals!(sciml_struct)
Rebuilds the internal functions for simulating a LatticeReactionSystem. WHenever a problem or
integrator have had its parameter values updated, thus function should be called for the update to
be taken into account. For ODE simulations, `rebuild_lat_internals!` needs only to be called when
- An edge parameter have been updated.
- When a parameter with spatially homogeneous values have been given spatially heterogeneous values
(or vice versa).
Arguments:
- `sciml_struct`: The problem (e.g. an `ODEProblem`) or an integrator which we wish to rebuild.
Notes:
- Currently does not work for `DiscreteProblem`s, `JumpProblem`s, or their integrators.
- The function is not build with performance in mind, so avoid calling it multiple times in
performance-critical applications.
Example:
```julia
# Creates an initial `ODEProblem`
rs = @reaction_network begin
(k1,k2), X1 <--> X2
end
tr = @transport_reaction D X1
grid = CartesianGrid((2,2))
lrs = LatticeReactionSystem(rs, [tr], grid)
u0 = [:X1 => 2, :X2 => [5 6; 7 8]]
tspan = (0.0, 10.0)
ps = [:k1 => 1.5, :k2 => [1.0 1.5; 2.0 3.5], :D => 0.1]
oprob = ODEProblem(lrs, u0, tspan, ps)
# Updates parameter values.
oprob.ps[:ks] = [2.0 2.5; 3.0 4.5]
oprob.ps[:D] = 0.05
# Rebuilds `ODEProblem` to make changes have an effect.
rebuild_lat_internals!(oprob)
```
"""
function rebuild_lat_internals!(oprob::ODEProblem)
rebuild_lat_internals!(oprob.f.f, oprob.p, oprob.f.f.lrs)
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,6 @@ using SafeTestsets, Test
@time @safetestset "Spatial Lattice Variants" begin include("spatial_modelling/lattice_reaction_systems_lattice_types.jl") end
@time @safetestset "ODE Lattice Systems Simulations" begin include("spatial_modelling/lattice_reaction_systems_ODEs.jl") end
@time @safetestset "Jump Lattice Systems Simulations" begin include("spatial_modelling/lattice_reaction_systems_jumps.jl") end
@time @safetestset "Jump Solution Interfacing" begin include("spatial_modelling/lattice_solution_interfacing.jl") end

end # @time
24 changes: 22 additions & 2 deletions test/spatial_modelling/lattice_reaction_systems_jumps.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
### Preparations ###

# Fetch packages.
using JumpProcesses
using Random, Statistics, SparseArrays, Test
using JumpProcesses, Statistics, SparseArrays, Test

# Fetch test networks.
include("../spatial_test_networks.jl")
Expand Down Expand Up @@ -204,6 +203,27 @@ end

### JumpProblem & Integrator Interfacing ###

# Currently not supported, check that corresponding functions yields errors.
let
# Prepare `LatticeReactionSystem`.
rs = @reaction_network begin
(k1,k2), X1 <--> X2
end
tr = @transport_reaction D X1
grid = CartesianGrid((2,2))
lrs = LatticeReactionSystem(rs, [tr], grid)

# Create problems.
u0 = [:X1 => 2, :X2 => [5 6; 7 8]]
tspan = (0.0, 10.0)
ps = [:k1 => 1.5, :k2 => [1.0 1.5; 2.0 3.5], :D => 0.1]
dprob = DiscreteProblem(lrs, u0, tspan, ps)
jprob = JumpProblem(lrs, dprob, NSM())

# Checks that rebuilding errors.
@test_throws Exception rebuild_lat_internals!(dprob)
@test_throws Exception rebuild_lat_internals!(jprob)
end

### Other Tests ###

Expand Down
Loading

0 comments on commit 37a14d0

Please sign in to comment.