Skip to content

Commit

Permalink
Fix substitution of expressions into event roots (partially reverts #…
Browse files Browse the repository at this point in the history
  • Loading branch information
dilpath authored Apr 21, 2022
1 parent b951f39 commit fd0db58
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions python/amici/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,7 @@ def parse_events(self) -> None:
expr.set_val(self._process_heavisides(expr.get_val(), roots))

# remove all possible Heavisides from roots, which may arise from
# the substitution of `'w'` in `_get_unique_root`
# the substitution of `'w'` in `_collect_heaviside_roots`
for root in roots:
root.set_val(self._process_heavisides(root.get_val(), roots))

Expand Down Expand Up @@ -2066,15 +2066,6 @@ def _get_unique_root(
unique identifier for root, or ``None`` if the root is not
time-dependent
"""
# substitute 'w' expressions into root expressions now, to avoid
# rewriting '{model_name}_root.cpp' and '{model_name}_stau.cpp' headers
# to include 'w.h'
w_sorted = toposort_symbols(dict(zip(
[expr.get_id() for expr in self._expressions],
[expr.get_val() for expr in self._expressions],
)))
root_found = root_found.subs(w_sorted)

if not self._expr_is_time_dependent(root_found):
return None

Expand Down Expand Up @@ -2115,6 +2106,18 @@ def _collect_heaviside_roots(
elif arg.has(sp.Heaviside):
root_funs.extend(self._collect_heaviside_roots(arg.args))

# substitute 'w' expressions into root expressions now, to avoid
# rewriting '{model_name}_root.cpp' and '{model_name}_stau.cpp' headers
# to include 'w.h'
w_sorted = toposort_symbols(dict(zip(
[expr.get_id() for expr in self._expressions],
[expr.get_val() for expr in self._expressions],
)))
root_funs = [
r.subs(w_sorted)
for r in root_funs
]

return root_funs

def _process_heavisides(
Expand Down

0 comments on commit fd0db58

Please sign in to comment.