Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jul 10, 2024
1 parent 2cd8395 commit 6a69670
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/spatial_reaction_systems/lattice_jump_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function JumpProcesses.JumpProblem(lrs::LatticeReactionSystem, dprob, aggregator
# Creates and returns a spatial JumpProblem (masked lattices are not supported by these).
spatial_system = has_masked_lattice(lrs) ? get_lattice_graph(lrs) : lattice(lrs)
println(JumpProblem(non_spat_dprob, aggregator, sma_jumps;
hopping_constants, spatial_system, name, kwargs...).prob.u0)
hopping_constants, spatial_system, name, kwargs...).prob.u0)
return JumpProblem(non_spat_dprob, aggregator, sma_jumps;
hopping_constants, spatial_system, name, kwargs...)
end
Expand Down Expand Up @@ -143,4 +143,4 @@ end
# p = (non_spat_dprob.p isa DiffEqBase.NullParameters || non_spat_dprob.p === nothing) ? Num[] : non_spat_dprob.p
# majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs = eqs, rateconsttype = invttype)
# return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
# end
# end
43 changes: 25 additions & 18 deletions src/spatial_reaction_systems/lattice_sim_struct_interfacing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,47 +53,54 @@ end

function lat_setu!(oprob::ODEProblem, sp_idx::Int64, sp_tot::Int64, u, num_verts)
if length(u) == 1
foreach(idx -> (oprob.u0[sp_idx + (idx-1)*sp_tot] = u[1]), 1:num_verts)
foreach(idx -> (oprob.u0[sp_idx + (idx - 1) * sp_tot] = u[1]), 1:num_verts)
else
foreach(idx -> (oprob.u0[sp_idx + (idx-1)*sp_tot] = u[idx]), 1:num_verts)
foreach(idx -> (oprob.u0[sp_idx + (idx - 1) * sp_tot] = u[idx]), 1:num_verts)
end
end
function lat_setu!(jprob::JumpProblem, sp_idx::Int64, sp_tot::Int64, u, num_verts)
if length(u) == 1
foreach(idx -> (jprob.prob.u0[sp_idx,idx] = u[1]), 1:num_verts)
foreach(idx -> (jprob.prob.u0[sp_idx, idx] = u[1]), 1:num_verts)
else
foreach(idx -> (jprob.prob.u0[sp_idx,idx] = u[idx]), 1:num_verts)
foreach(idx -> (jprob.prob.u0[sp_idx, idx] = u[idx]), 1:num_verts)
end
end
function lat_setu!(oint::SciMLBase.AbstractODEIntegrator, sp_idx::Int64, sp_tot::Int64, u, num_verts)
function lat_setu!(oint::SciMLBase.AbstractODEIntegrator, sp_idx::Int64, sp_tot::Int64,
u, num_verts)
if length(u) == 1
foreach(idx -> (oint.u[sp_idx + (idx-1)*sp_tot] = u[1]), 1:num_verts)
foreach(idx -> (oint.u[sp_idx + (idx - 1) * sp_tot] = u[1]), 1:num_verts)
else
foreach(idx -> (oint.u[sp_idx + (idx-1)*sp_tot] = u[idx]), 1:num_verts)
foreach(idx -> (oint.u[sp_idx + (idx - 1) * sp_tot] = u[idx]), 1:num_verts)
end
end
function lat_setu!(jint::JumpProcesses.SSAIntegrator, sp_idx::Int64, sp_tot::Int64, u, num_verts)
function lat_setu!(
jint::JumpProcesses.SSAIntegrator, sp_idx::Int64, sp_tot::Int64, u, num_verts)
if length(u) == 1
foreach(idx -> (jint.u[sp_idx,idx] = u[1]), 1:num_verts)
foreach(idx -> (jint.u[sp_idx, idx] = u[1]), 1:num_verts)
else
foreach(idx -> (jint.u[sp_idx,idx] = u[idx]), 1:num_verts)
foreach(idx -> (jint.u[sp_idx, idx] = u[idx]), 1:num_verts)
end
end

function check_lattice_format(lattice::CartesianGridRej, u)
(u isa AbstractArray) || error("The input u should be an AbstractArray. It is a $(typeof(u)).")
(size(u) == lattice.dims) || error("The input u should have size $(lattice.dims), but has size $(size(u)).")
(u isa AbstractArray) ||
error("The input u should be an AbstractArray. It is a $(typeof(u)).")
(size(u) == lattice.dims) ||
error("The input u should have size $(lattice.dims), but has size $(size(u)).")
end
function check_lattice_format(lattice::AbstractSparseArray, u)
(u isa AbstractArray) || error("The input u should be an AbstractArray. It is a $(typeof(u)).")
(size(u) == size(lattice)) || error("The input u should have size $(size(lattice)), but has size $(size(u)).")
(u isa AbstractArray) ||
error("The input u should be an AbstractArray. It is a $(typeof(u)).")
(size(u) == size(lattice)) ||
error("The input u should have size $(size(lattice)), but has size $(size(u)).")
end
function check_lattice_format(lattice::DiGraph, u)
(u isa AbstractArray) || error("The input u should be an AbstractVector. It is a $(typeof(u)).")
(size(u) == size(lattice)) || error("The input u should have length $(nv(lattice)), but has length $(length(u)).")
(u isa AbstractArray) ||
error("The input u should be an AbstractVector. It is a $(typeof(u)).")
(length(u) == nv(lattice)) ||
error("The input u should have length $(nv(lattice)), but has length $(length(u)).")
end


"""
lat_getu(sim_struct, sp, lrs::LatticeReactionSystem)
Expand Down Expand Up @@ -235,7 +242,7 @@ end

# Function which handles the input in the case where `t` is a range of values (i.e. return `sp`s
# value at all designated time points.
function lat_getu(sol::ODESolution, lattice, t::AbstractVector{T}, sp_idx::Int64,
function lat_getu(sol::ODESolution, lattice, t::AbstractVector{T}, sp_idx::Int64,
sp_tot::Int64) where {T <: Number}
# Checks that an appropriate `t` is provided (however, DiffEq does permit out of range `t`s).
if (minimum(t) < sol.t[1]) || (maximum(t) > sol.t[end])
Expand Down
5 changes: 3 additions & 2 deletions test/spatial_modelling/lattice_solution_interfacing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ end

### Problem & Integrator `lat_getu` & `lat_setu!` Tests ###

# Checks `getu` for ODE and Jump problem and integrators.
# Checks `lat_getu` for ODE and Jump problem and integrators.
# Checks `lat_setu!` for ODE and Jump problem and integrators.
# Checks for all types of lattices.
# Checks for symbol and symbolic variables input.
let
Expand All @@ -88,7 +89,7 @@ let
for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph],[val0_cartesian, val0_masked, val0_graph])
# Prepares various problems and integrators. Uses `deepcopy` to ensure there is no cross-talk
# between the different u vectors as they get updated.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, lattice_masked)
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, lattice)
u0 = [:X => val0, :Y => 0.5]
ps = [:A => 1.0, :B => 2.0, :dX => 0.1]
oprob = ODEProblem(lrs, deepcopy(u0), (0.0, 1.0), ps)
Expand Down

0 comments on commit 6a69670

Please sign in to comment.