Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SBML import: Allow hardcoding of numerical values #2134

Merged
merged 5 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
import warnings
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)

import libsbml as sbml
import numpy as np
Expand Down Expand Up @@ -281,6 +292,7 @@ def sbml2amici(
cache_simplify: bool = False,
log_as_log10: bool = True,
generate_sensitivity_code: bool = True,
hardcode_symbols: Sequence[str] = None,
) -> None:
"""
Generate and compile AMICI C++ files for the model provided to the
Expand Down Expand Up @@ -385,6 +397,12 @@ def sbml2amici(
:param generate_sensitivity_code:
If ``False``, the code required for sensitivity computation will
not be generated

:param hardcode_symbols:
List of SBML entitiy IDs that are to be hardcoded in the generated model.
Their values cannot be changed anymore after model import.
Currently only parameters that are not targets of rules or
initial assignments are supported.
"""
set_log_level(logger, verbose)

Expand All @@ -401,6 +419,7 @@ def sbml2amici(
simplify=simplify,
cache_simplify=cache_simplify,
log_as_log10=log_as_log10,
hardcode_symbols=hardcode_symbols,
)

exporter = DEExporter(
Expand Down Expand Up @@ -437,13 +456,21 @@ def _build_ode_model(
simplify: Optional[Callable] = _default_simplify,
cache_simplify: bool = False,
log_as_log10: bool = True,
hardcode_symbols: Sequence[str] = None,
) -> DEModel:
"""Generate an ODEModel from this SBML model.

See :py:func:`sbml2amici` for parameters.
"""
constant_parameters = list(constant_parameters) if constant_parameters else []

hardcode_symbols = set(hardcode_symbols) if hardcode_symbols else {}
if invalid := (set(constant_parameters) & set(hardcode_symbols)):
raise ValueError(
"The following parameters were selected as both constant "
f"and hard-coded which is not allowed: {invalid}"
)

if sigmas is None:
sigmas = {}

Expand All @@ -460,7 +487,9 @@ def _build_ode_model(
self.sbml_parser_settings.setParseLog(
sbml.L3P_PARSE_LOG_AS_LOG10 if log_as_log10 else sbml.L3P_PARSE_LOG_AS_LN
)
self._process_sbml(constant_parameters)
self._process_sbml(
constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols
)

if (
self.symbols.get(SymbolId.EVENT, False)
Expand Down Expand Up @@ -496,18 +525,26 @@ def _build_ode_model(
return ode_model

@log_execution_time("importing SBML", logger)
def _process_sbml(self, constant_parameters: List[str] = None) -> None:
def _process_sbml(
self,
constant_parameters: List[str] = None,
hardcode_symbols: Sequence[str] = None,
) -> None:
"""
Read parameters, species, reactions, and so on from SBML model

:param constant_parameters:
SBML Ids identifying constant parameters
:param hardcode_parameters:
Parameter IDs to be replaced by their values in the generated model.
"""
if not self._discard_annotations:
self._process_annotations()
self.check_support()
self._gather_locals()
self._process_parameters(constant_parameters)
self._gather_locals(hardcode_symbols=hardcode_symbols)
self._process_parameters(
constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols
)
self._process_compartments()
self._process_species()
self._process_reactions()
Expand Down Expand Up @@ -639,18 +676,18 @@ def check_event_support(self) -> None:
)

@log_execution_time("gathering local SBML symbols", logger)
def _gather_locals(self) -> None:
def _gather_locals(self, hardcode_symbols: Sequence[str] = None) -> None:
"""
Populate self.local_symbols with all model entities.

This is later used during sympifications to avoid sympy builtins
shadowing model entities as well as to avoid possibly costly
symbolic substitutions
"""
self._gather_base_locals()
self._gather_base_locals(hardcode_symbols=hardcode_symbols)
self._gather_dependent_locals()

def _gather_base_locals(self):
def _gather_base_locals(self, hardcode_symbols: Sequence[str] = None) -> None:
"""
Populate self.local_symbols with pure symbol definitions that do not
depend on any other symbol.
Expand All @@ -677,8 +714,19 @@ def _gather_base_locals(self):
):
if not c.isSetId():
continue

self.add_local_symbol(c.getId(), _get_identifier_symbol(c))
if c.getId() in hardcode_symbols:
if self.sbml.getRuleByVariable(c.getId()) is not None:
raise ValueError(
f"Cannot hardcode symbol `{c.getId()}` that is a rule target."
)
if self.sbml.getInitialAssignment(c.getId()):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add event assignements. What about assignment rules?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Event assignments were missing, true. Assignment rules were handled by getRuleByVariable. Now changed to checking for the constant attribute which covers both.

raise NotImplementedError(
f"Cannot hardcode symbol `{c.getId()}` "
"that is an initial assignment target."
)
self.add_local_symbol(c.getId(), sp.Float(c.getValue()))
else:
self.add_local_symbol(c.getId(), _get_identifier_symbol(c))

for x_ref in _get_list_of_species_references(self.sbml):
if not x_ref.isSetId():
Expand Down Expand Up @@ -940,7 +988,11 @@ def _process_annotations(self) -> None:
self.sbml.removeParameter(parameter_id)

@log_execution_time("processing SBML parameters", logger)
def _process_parameters(self, constant_parameters: List[str] = None) -> None:
def _process_parameters(
self,
constant_parameters: List[str] = None,
hardcode_symbols: Sequence[str] = None,
) -> None:
"""
Get parameter information from SBML model.

Expand Down Expand Up @@ -983,6 +1035,7 @@ def _process_parameters(self, constant_parameters: List[str] = None) -> None:
if parameter.getId() not in constant_parameters
and self._get_element_initial_assignment(parameter.getId()) is None
and not self.is_assignment_rule_target(parameter)
and parameter.getId() not in hardcode_symbols
]

loop_settings = {
Expand Down
33 changes: 32 additions & 1 deletion python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def simple_sbml_model():
model.addSpecies(s1)
p1 = model.createParameter()
p1.setId("p1")
p1.setValue(0.0)
p1.setValue(2.0)
model.addParameter(p1)

return document, model
Expand Down Expand Up @@ -662,3 +662,34 @@ def test_code_gen_uses_lhs_symbol_ids():
)
dwdx = Path(tmpdir, "dwdx.cpp").read_text()
assert "dobservable_x1_dx1 = " in dwdx


def test_hardcode_parameters(simple_sbml_model):
"""Test model generation works for model without observables"""
sbml_doc, sbml_model = simple_sbml_model
sbml_importer = SbmlImporter(sbml_source=sbml_model, from_file=False)
r = sbml_model.createRateRule()
r.setVariable("S1")
r.setFormula("p1")
assert sbml_model.getParameter("p1").getValue() != 0

ode_model = sbml_importer._build_ode_model()
assert str(ode_model.parameters()) == "[p1]"
assert ode_model.differential_states()[0].get_dt().name == "p1"

ode_model = sbml_importer._build_ode_model(
constant_parameters=[],
hardcode_symbols=["p1"],
)
assert str(ode_model.parameters()) == "[]"
assert (
ode_model.differential_states()[0].get_dt()
== sbml_model.getParameter("p1").getValue()
)

with pytest.raises(ValueError):
sbml_importer._build_ode_model(
# mutually exclusive
constant_parameters=["p1"],
hardcode_symbols=["p1"],
)