From 492ae1dfc0edb9b0bfbe39c29abc1d1541ea4f2b Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 28 Jun 2023 22:21:45 +0200 Subject: [PATCH 1/3] Allow hardcoding of numerical values --- python/sdist/amici/sbml_import.py | 75 ++++++++++++++++++++++++++----- python/tests/test_sbml_import.py | 33 +++++++++++++- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 25b54f5c93..8ebbe4fdd3 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -13,7 +13,18 @@ import warnings import xml.etree.ElementTree as ET from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import libsbml as sbml import numpy as np @@ -281,6 +292,7 @@ def sbml2amici( cache_simplify: bool = False, log_as_log10: bool = True, generate_sensitivity_code: bool = True, + hardcode_symbols: Sequence[str] = None, ) -> None: """ Generate and compile AMICI C++ files for the model provided to the @@ -385,6 +397,12 @@ def sbml2amici( :param generate_sensitivity_code: If ``False``, the code required for sensitivity computation will not be generated + + :param hardcode_symbols: + List of SBML entitiy IDs that are to be hardcoded in the generated model. + Their values cannot be changed anymore after model import. + Currently only parameters that are not targets of rules or + initial assignments are supported. """ set_log_level(logger, verbose) @@ -401,6 +419,7 @@ def sbml2amici( simplify=simplify, cache_simplify=cache_simplify, log_as_log10=log_as_log10, + hardcode_symbols=hardcode_symbols, ) exporter = DEExporter( @@ -437,6 +456,7 @@ def _build_ode_model( simplify: Optional[Callable] = _default_simplify, cache_simplify: bool = False, log_as_log10: bool = True, + hardcode_symbols: Sequence[str] = None, ) -> DEModel: """Generate an ODEModel from this SBML model. @@ -444,6 +464,13 @@ def _build_ode_model( """ constant_parameters = list(constant_parameters) if constant_parameters else [] + hardcode_symbols = set(hardcode_symbols) if hardcode_symbols else {} + if invalid := (set(constant_parameters) & set(hardcode_symbols)): + raise ValueError( + "The following parameters were selected as both constant " + f"and hard-coded which is not allowed: {invalid}" + ) + if sigmas is None: sigmas = {} @@ -460,7 +487,9 @@ def _build_ode_model( self.sbml_parser_settings.setParseLog( sbml.L3P_PARSE_LOG_AS_LOG10 if log_as_log10 else sbml.L3P_PARSE_LOG_AS_LN ) - self._process_sbml(constant_parameters) + self._process_sbml( + constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols + ) if ( self.symbols.get(SymbolId.EVENT, False) @@ -496,18 +525,26 @@ def _build_ode_model( return ode_model @log_execution_time("importing SBML", logger) - def _process_sbml(self, constant_parameters: List[str] = None) -> None: + def _process_sbml( + self, + constant_parameters: List[str] = None, + hardcode_symbols: Sequence[str] = None, + ) -> None: """ Read parameters, species, reactions, and so on from SBML model :param constant_parameters: SBML Ids identifying constant parameters + :param hardcode_parameters: + Parameter IDs to be replaced by their values in the generated model. """ if not self._discard_annotations: self._process_annotations() self.check_support() - self._gather_locals() - self._process_parameters(constant_parameters) + self._gather_locals(hardcode_symbols=hardcode_symbols) + self._process_parameters( + constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols + ) self._process_compartments() self._process_species() self._process_reactions() @@ -639,7 +676,7 @@ def check_event_support(self) -> None: ) @log_execution_time("gathering local SBML symbols", logger) - def _gather_locals(self) -> None: + def _gather_locals(self, hardcode_symbols: Sequence[str] = None) -> None: """ Populate self.local_symbols with all model entities. @@ -647,10 +684,10 @@ def _gather_locals(self) -> None: shadowing model entities as well as to avoid possibly costly symbolic substitutions """ - self._gather_base_locals() + self._gather_base_locals(hardcode_symbols=hardcode_symbols) self._gather_dependent_locals() - def _gather_base_locals(self): + def _gather_base_locals(self, hardcode_symbols: Sequence[str] = None) -> None: """ Populate self.local_symbols with pure symbol definitions that do not depend on any other symbol. @@ -677,8 +714,19 @@ def _gather_base_locals(self): ): if not c.isSetId(): continue - - self.add_local_symbol(c.getId(), _get_identifier_symbol(c)) + if c.getId() in hardcode_symbols: + if self.sbml.getRuleByVariable(c.getId()) is not None: + raise ValueError( + f"Cannot hardcode symbol `{c.getId()}` that is a rule target." + ) + if self.sbml.getInitialAssignment(c.getId()): + raise NotImplementedError( + f"Cannot hardcode symbol `{c.getId()}` " + "that is an initial assignment target." + ) + self.add_local_symbol(c.getId(), sp.Float(c.getValue())) + else: + self.add_local_symbol(c.getId(), _get_identifier_symbol(c)) for x_ref in _get_list_of_species_references(self.sbml): if not x_ref.isSetId(): @@ -940,7 +988,11 @@ def _process_annotations(self) -> None: self.sbml.removeParameter(parameter_id) @log_execution_time("processing SBML parameters", logger) - def _process_parameters(self, constant_parameters: List[str] = None) -> None: + def _process_parameters( + self, + constant_parameters: List[str] = None, + hardcode_symbols: Sequence[str] = None, + ) -> None: """ Get parameter information from SBML model. @@ -983,6 +1035,7 @@ def _process_parameters(self, constant_parameters: List[str] = None) -> None: if parameter.getId() not in constant_parameters and self._get_element_initial_assignment(parameter.getId()) is None and not self.is_assignment_rule_target(parameter) + and parameter.getId() not in hardcode_symbols ] loop_settings = { diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index d0ce9cae5c..09a19c6d4f 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -38,9 +38,13 @@ def simple_sbml_model(): model.addSpecies(s1) p1 = model.createParameter() p1.setId("p1") - p1.setValue(0.0) + p1.setValue(2.0) model.addParameter(p1) + r = model.createRateRule() + r.setVariable("S1") + r.setFormula("p1") + return document, model @@ -662,3 +666,30 @@ def test_code_gen_uses_lhs_symbol_ids(): ) dwdx = Path(tmpdir, "dwdx.cpp").read_text() assert "dobservable_x1_dx1 = " in dwdx + + +def test_hardcode_parameters(simple_sbml_model): + """Test model generation works for model without observables""" + sbml_doc, sbml_model = simple_sbml_model + sbml_importer = SbmlImporter(sbml_source=sbml_model, from_file=False) + + ode_model = sbml_importer._build_ode_model() + assert str(ode_model.parameters()) == "[p1]" + assert ode_model.differential_states()[0].get_dt().name == "p1" + + ode_model = sbml_importer._build_ode_model( + constant_parameters=[], + hardcode_symbols=["p1"], + ) + assert str(ode_model.parameters()) == "[]" + assert ( + ode_model.differential_states()[0].get_dt() + == sbml_model.getParameter("p1").getValue() + ) + + with pytest.raises(ValueError): + sbml_importer._build_ode_model( + # mutually exclusive + constant_parameters=["p1"], + hardcode_symbols=["p1"], + ) From d7ecff38320e7acb858be01cb5256989c8c32851 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 29 Jun 2023 08:54:49 +0200 Subject: [PATCH 2/3] .. --- python/tests/test_sbml_import.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 09a19c6d4f..41ccdd925c 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -41,10 +41,6 @@ def simple_sbml_model(): p1.setValue(2.0) model.addParameter(p1) - r = model.createRateRule() - r.setVariable("S1") - r.setFormula("p1") - return document, model @@ -672,6 +668,10 @@ def test_hardcode_parameters(simple_sbml_model): """Test model generation works for model without observables""" sbml_doc, sbml_model = simple_sbml_model sbml_importer = SbmlImporter(sbml_source=sbml_model, from_file=False) + r = sbml_model.createRateRule() + r.setVariable("S1") + r.setFormula("p1") + assert sbml_model.getParameter("p1").getValue() != 0 ode_model = sbml_importer._build_ode_model() assert str(ode_model.parameters()) == "[p1]" From 078edf44555017784d8bc1034bacef959d033e68 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 3 Jul 2023 07:38:04 +0200 Subject: [PATCH 3/3] constant --- python/sdist/amici/sbml_import.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 8ebbe4fdd3..0f6a4eb2be 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -715,9 +715,10 @@ def _gather_base_locals(self, hardcode_symbols: Sequence[str] = None) -> None: if not c.isSetId(): continue if c.getId() in hardcode_symbols: - if self.sbml.getRuleByVariable(c.getId()) is not None: + if c.getConstant() is not True: + # disallow anything that can be changed by rules/reaction/events raise ValueError( - f"Cannot hardcode symbol `{c.getId()}` that is a rule target." + f"Cannot hardcode non-constant symbol `{c.getId()}`." ) if self.sbml.getInitialAssignment(c.getId()): raise NotImplementedError(