Skip to content

Commit

Permalink
Merge branch 'develop' into feature_1192_hardcode
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl authored Jul 3, 2023
2 parents 078edf4 + b8edbcd commit 00bc9f6
Show file tree
Hide file tree
Showing 17 changed files with 250 additions and 89 deletions.
7 changes: 7 additions & 0 deletions include/amici/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ enum class NonlinearSolverIteration {
newton = 2
};

/** Steady-state computation mode in steadyStateProblem */
enum class SteadyStateComputationMode {
newtonOnly,
integrationOnly,
integrateIfNewtonFails
};

/** Sensitivity computation mode in steadyStateProblem */
enum class SteadyStateSensitivityMode {
newtonOnly,
Expand Down
25 changes: 20 additions & 5 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,20 @@ class Model : public AbstractModel, public ModelDimensions {
*/
void setUnscaledInitialStateSensitivities(std::vector<realtype> const& sx0);

/**
* @brief Set the mode how steady state is computed in the steadystate
* simulation.
* @param mode Steadystate computation mode
*/
void setSteadyStateComputationMode(SteadyStateComputationMode mode);

/**
* @brief Gets the mode how steady state is computed in the steadystate
* simulation.
* @return Mode
*/
SteadyStateComputationMode getSteadyStateComputationMode() const;

/**
* @brief Set the mode how sensitivities are computed in the steadystate
* simulation.
Expand Down Expand Up @@ -1977,12 +1991,13 @@ class Model : public AbstractModel, public ModelDimensions {
/** maximal number of events to track */
int nmaxevent_{10};

/**
* flag indicating whether steadystate sensitivities are to be computed
* via FSA when steadyStateSimulation is used
*/
/** method for steady-state computation */
SteadyStateComputationMode steadystate_computation_mode_{
SteadyStateComputationMode::integrateIfNewtonFails};

/** method for steadystate sensitivities computation */
SteadyStateSensitivityMode steadystate_sensitivity_mode_{
SteadyStateSensitivityMode::newtonOnly};
SteadyStateSensitivityMode::integrateIfNewtonFails};

/**
* Indicates whether the result of every call to `Model::f*` should be
Expand Down
4 changes: 2 additions & 2 deletions include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,14 @@ class ReturnData : public ModelDimensions {
/** initial state (shape `nx`) */
std::vector<realtype> x0;

/** preequilibration steady state found by Newton solver (shape `nx`) */
/** preequilibration steady state (shape `nx`) */
std::vector<realtype> x_ss;

/** initial sensitivities (shape `nplist` x `nx`, row-major) */
std::vector<realtype> sx0;

/**
* preequilibration sensitivities found by Newton solver
* preequilibration sensitivities
* (shape `nplist` x `nx`, row-major)
*/
std::vector<realtype> sx_ss;
Expand Down
2 changes: 2 additions & 0 deletions include/amici/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ void serialize(Archive& ar, amici::Model& m, unsigned int const /*version*/) {
ar& m.pythonGenerated;
ar& m.min_sigma_;
ar& m.sigma_res_;
ar& m.steadystate_computation_mode_;
ar& m.steadystate_sensitivity_mode_;
}

/**
Expand Down
3 changes: 2 additions & 1 deletion python/examples/example_errors.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@
"amici_solver = amici_model.getSolver()\n",
"amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)\n",
"amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)\n",
"amici_model.setSteadyStateSensitivityMode(amici.SteadyStateSensitivityMode.newtonOnly)\n",
"\n",
"np.random.seed(2020)\n",
"problem_parameters = dict(\n",
Expand Down Expand Up @@ -960,7 +961,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down
38 changes: 24 additions & 14 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
from .cxxcodeprinter import AmiciCxxCodePrinter, get_switch_statement
from .de_model import *
from .import_utils import (
amici_time_symbol,
ObservableTransformation,
SBMLException,
amici_time_symbol,
generate_flux_symbol,
smart_subs_dict,
strip_pysb,
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def transform_dxdt_to_concentration(species_id, dxdt):
for llh in si.symbols[SymbolId.LLHY].values()
)

self._process_sbml_rate_of(symbols)# substitute SBML-rateOf constructs
self._process_sbml_rate_of(symbols) # substitute SBML-rateOf constructs

def _process_sbml_rate_of(self, symbols) -> None:
"""Substitute any SBML-rateOf constructs in the model equations"""
Expand Down Expand Up @@ -1144,7 +1144,12 @@ def get_rate(symbol: sp.Symbol):
{rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
)

for component in chain(self.observables(), self.expressions(), self.events(), self._algebraic_equations):
for component in chain(
self.observables(),
self.expressions(),
self.events(),
self._algebraic_equations,
):
if rate_ofs := component.get_val().find(rate_of_func):
if isinstance(component, Event):
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
Expand Down Expand Up @@ -1181,7 +1186,6 @@ def get_rate(symbol: sp.Symbol):
# {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
# )


def add_component(
self, component: ModelQuantity, insert_first: Optional[bool] = False
) -> None:
Expand Down Expand Up @@ -2084,8 +2088,9 @@ def _compute_equation(self, name: str) -> None:

# need to check if equations are zero since we are using
# symbols
if not smart_is_zero_matrix(self.eq("stau")[ie]) \
and not smart_is_zero_matrix(self.eq("xdot")):
if not smart_is_zero_matrix(
self.eq("stau")[ie]
) and not smart_is_zero_matrix(self.eq("xdot")):
tmp_eq += smart_multiply(
self.sym("xdot_old") - self.sym("xdot"),
self.sym("stau").T,
Expand All @@ -2108,7 +2113,9 @@ def _compute_equation(self, name: str) -> None:
)

# additional part of chain rule state variables
tmp_dxdp += smart_multiply(self.sym("xdot_old"), self.sym("stau").T)
tmp_dxdp += smart_multiply(
self.sym("xdot_old"), self.sym("stau").T
)

# finish chain rule for the state variables
tmp_eq += smart_multiply(self.eq("ddeltaxdx")[ie], tmp_dxdp)
Expand Down Expand Up @@ -2839,11 +2846,14 @@ def _generate_c_code(self) -> None:
# only generate for those that have nontrivial implementation,
# check for both basic variables (not in functions) and function
# computed values
if ((
name in self.functions
and not self.functions[name].body
and name not in nobody_functions
) or name not in self.functions) and len(self.model.sym(name)) == 0:
if (
(
name in self.functions
and not self.functions[name].body
and name not in nobody_functions
)
or name not in self.functions
) and len(self.model.sym(name)) == 0:
continue
self._write_index_files(name)

Expand Down Expand Up @@ -3064,8 +3074,8 @@ def _write_function_file(self, function: str) -> None:
iszero = len(self.model.sym(sym)) == 0

if iszero and not (
(sym == "y" and "Jy" in function)
or (sym == "w" and "xdot" in function and len(self.model.sym(sym)))
(sym == "y" and "Jy" in function)
or (sym == "w" and "xdot" in function and len(self.model.sym(sym)))
):
continue

Expand Down
12 changes: 3 additions & 9 deletions python/sdist/amici/petab_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def get_states_in_condition_table(
raise NotImplementedError()

species_check_funs = {
MODEL_TYPE_SBML: lambda x: _element_is_sbml_state(
petab_problem.sbml_model, x
),
MODEL_TYPE_SBML: lambda x: _element_is_sbml_state(petab_problem.sbml_model, x),
MODEL_TYPE_PYSB: lambda x: _element_is_pysb_pattern(
petab_problem.model.model, x
),
Expand All @@ -38,9 +36,7 @@ def get_states_in_condition_table(
resolve_mapping(petab_problem.mapping_df, col): (None, None)
if condition is None
else (
petab_problem.condition_df.loc[
condition[SIMULATION_CONDITION_ID], col
],
petab_problem.condition_df.loc[condition[SIMULATION_CONDITION_ID], col],
petab_problem.condition_df.loc[
condition[PREEQUILIBRATION_CONDITION_ID], col
]
Expand All @@ -64,9 +60,7 @@ def get_states_in_condition_table(
pysb.bng.generate_equations(petab_problem.model.model)

try:
spm = pysb.pattern.SpeciesPatternMatcher(
model=petab_problem.model.model
)
spm = pysb.pattern.SpeciesPatternMatcher(model=petab_problem.model.model)
except NotImplementedError as e:
raise NotImplementedError(
"Requires https://github.com/pysb/pysb/pull/570. "
Expand Down
36 changes: 20 additions & 16 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,13 +861,13 @@ def _process_species_initial(self):

# don't assign this since they need to stay in order
sorted_species = toposort_symbols(self.symbols[SymbolId.SPECIES], "init")
for species, rateof_dummies in zip(self.symbols[SymbolId.SPECIES].values(), all_rateof_dummies):
for species, rateof_dummies in zip(
self.symbols[SymbolId.SPECIES].values(), all_rateof_dummies
):
species["init"] = _dummy_to_rateof(
smart_subs_dict(species["init"], sorted_species, "init"),
rateof_dummies
smart_subs_dict(species["init"], sorted_species, "init"), rateof_dummies
)


@log_execution_time("processing SBML rate rules", logger)
def _process_rate_rules(self):
"""
Expand Down Expand Up @@ -1058,14 +1058,14 @@ def _process_parameters(
# so far, this concerns parameters with initial assignments containing rateOf(.)
# (those have been skipped above)
for par in self.sbml.getListOfParameters():
if (ia := self._get_element_initial_assignment(par.getId())) is not None \
and ia.find(sp.core.function.UndefinedFunction("rateOf")):
if (
ia := self._get_element_initial_assignment(par.getId())
) is not None and ia.find(sp.core.function.UndefinedFunction("rateOf")):
self.symbols[SymbolId.EXPRESSION][_get_identifier_symbol(par)] = {
"name": par.getName() if par.isSetName() else par.getId(),
"value": ia,
}


@log_execution_time("processing SBML reactions", logger)
def _process_reactions(self):
"""
Expand Down Expand Up @@ -2616,11 +2616,14 @@ def _get_list_of_species_references(
ListOfSpeciesReferences
"""
return [
reference
for reaction in sbml_model.getListOfReactions()
for reference in
itt.chain(reaction.getListOfReactants(), reaction.getListOfProducts(), reaction.getListOfModifiers())
]
reference
for reaction in sbml_model.getListOfReactions()
for reference in itt.chain(
reaction.getListOfReactants(),
reaction.getListOfProducts(),
reaction.getListOfModifiers(),
)
]


def replace_logx(math_str: Union[str, float, None]) -> Union[str, float, None]:
Expand Down Expand Up @@ -2755,11 +2758,12 @@ def _rateof_to_dummy(sym_math):
[...substitute...]
sym_math = _dummy_to_rateof(sym_math, rateof_to_dummy)
"""
if rate_ofs := sym_math.find(
sp.core.function.UndefinedFunction("rateOf")
):
if rate_ofs := sym_math.find(sp.core.function.UndefinedFunction("rateOf")):
# replace by dummies to avoid species substitution
rateof_dummies = {rate_of: sp.Dummy(f"Dummy_RateOf_{rate_of.args[0].name}") for rate_of in rate_ofs}
rateof_dummies = {
rate_of: sp.Dummy(f"Dummy_RateOf_{rate_of.args[0].name}")
for rate_of in rate_ofs
}

return sym_math.subs(rateof_dummies), rateof_dummies
return sym_math, {}
Expand Down
20 changes: 13 additions & 7 deletions python/sdist/amici/swig_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Convenience wrappers for the swig interface"""
import logging
import sys

import warnings
from contextlib import contextmanager, suppress
from typing import Any, Dict, List, Optional, Sequence, Union
Expand Down Expand Up @@ -106,8 +105,7 @@ def runAmiciSimulation(
"""
if (
model.ne > 0
and solver.getSensitivityMethod()
== amici_swig.SensitivityMethod.adjoint
and solver.getSensitivityMethod() == amici_swig.SensitivityMethod.adjoint
and solver.getSensitivityOrder() == amici_swig.SensitivityOrder.first
):
warnings.warn(
Expand Down Expand Up @@ -168,8 +166,7 @@ def runAmiciSimulations(
"""
if (
model.ne > 0
and solver.getSensitivityMethod()
== amici_swig.SensitivityMethod.adjoint
and solver.getSensitivityMethod() == amici_swig.SensitivityMethod.adjoint
and solver.getSensitivityOrder() == amici_swig.SensitivityOrder.first
):
warnings.warn(
Expand All @@ -181,7 +178,11 @@ def runAmiciSimulations(
with _capture_cstdout():
edata_ptr_vector = amici_swig.ExpDataPtrVector(edata_list)
rdata_ptr_list = amici_swig.runAmiciSimulations(
_get_ptr(solver), edata_ptr_vector, _get_ptr(model), failfast, num_threads
_get_ptr(solver),
edata_ptr_vector,
_get_ptr(model),
failfast,
num_threads,
)
for rdata in rdata_ptr_list:
_log_simulation(rdata)
Expand Down Expand Up @@ -240,6 +241,7 @@ def writeSolverSettingsToHDF5(
"ReinitializationStateIdxs",
"ReinitializeFixedParameterInitialStates",
"StateIsNonNegative",
"SteadyStateComputationMode",
"SteadyStateSensitivityMode",
("t0", "setT0"),
"Timepoints",
Expand Down Expand Up @@ -318,6 +320,10 @@ def _ids_and_names_to_rdata(rdata: amici_swig.ReturnData, model: amici_swig.Mode
):
for name_or_id in ("Ids", "Names"):
names_or_ids = getattr(model, f"get{entity_type}{name_or_id}")()
setattr(rdata, f"{entity_type.lower()}_{name_or_id.lower()}", names_or_ids)
setattr(
rdata,
f"{entity_type.lower()}_{name_or_id.lower()}",
names_or_ids,
)
rdata.state_ids_solver = model.getStateIdsSolver()
rdata.state_names_solver = model.getStateNamesSolver()
Loading

0 comments on commit 00bc9f6

Please sign in to comment.