From 469c8cb68bd08a4580d25bee5367d40bdb68c8a1 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 6 Dec 2021 19:06:12 +0100 Subject: [PATCH] Refactor amici.ode_export, add AmiciCxxCodePrinter Move custom codeprinting to a separate code printer class and reuse the same instance for all printing. Avoids plenty of reinstantiations of the code printer and keeps all settings in one place. --- python/amici/__init__.py | 3 +- python/amici/cxxcodeprinter.py | 214 ++++++++++++++ python/amici/ode_export.py | 413 ++++++--------------------- python/sdist/amici/cxxcodeprinter.py | 1 + python/tests/test_ode_export.py | 16 +- 5 files changed, 318 insertions(+), 329 deletions(-) create mode 100644 python/amici/cxxcodeprinter.py create mode 120000 python/sdist/amici/cxxcodeprinter.py diff --git a/python/amici/__init__.py b/python/amici/__init__.py index af4b7376d0..de383e92e2 100644 --- a/python/amici/__init__.py +++ b/python/amici/__init__.py @@ -144,7 +144,8 @@ def _capture_cstdout(): from typing import Protocol class ModelModule(Protocol): - """Enable Python static type checking for AMICI-generated model modules""" + """Enable Python static type checking for AMICI-generated model + modules""" def getModel(self) -> amici.Model: pass except ImportError: diff --git a/python/amici/cxxcodeprinter.py b/python/amici/cxxcodeprinter.py new file mode 100644 index 0000000000..5411010270 --- /dev/null +++ b/python/amici/cxxcodeprinter.py @@ -0,0 +1,214 @@ +"""C++ code generation""" +import re +from typing import List, Optional, Tuple, Dict + +import sympy as sp +from sympy.printing.cxx import CXX11CodePrinter + + +class AmiciCxxCodePrinter(CXX11CodePrinter): + """C++ code printer""" + + def __init__(self): + super().__init__() + + def doprint(self, expr: sp.Expr, assign_to: Optional[str] = None) -> str: + try: + code = super().doprint(expr, assign_to) + code = re.sub(r'(^|\W)M_PI(\W|$)', r'\1amici::pi\2', code) + + return code + except TypeError as e: + raise ValueError( + f'Encountered unsupported function in expression "{expr}": ' + f'{e}!' + ) + + def _get_sym_lines_array( + self, + equations: sp.Matrix, + variable: str, + indent_level: int + ) -> List[str]: + """ + Generate C++ code for assigning symbolic terms in symbols to C++ array + `variable`. + + :param equations: + vectors of symbolic expressions + + :param variable: + name of the C++ array to assign to + + :param indent_level: + indentation level (number of leading blanks) + + :return: + C++ code as list of lines + """ + return [ + ' ' * indent_level + f'{variable}[{index}] = ' + f'{self.doprint(math)};' + for index, math in enumerate(equations) + if math not in [0, 0.0] + ] + + def _get_sym_lines_symbols( + self, symbols: sp.Matrix, + equations: sp.Matrix, + variable: str, + indent_level: int + ) -> List[str]: + """ + Generate C++ code for where array elements are directly replaced with + their corresponding macro symbol + + :param symbols: + vectors of symbols that equations are assigned to + + :param equations: + vectors of expressions + + :param variable: + name of the C++ array to assign to, only used in comments + + :param indent_level: + indentation level (number of leading blanks) + + :return: + C++ code as list of lines + """ + return [ + f'{" " * indent_level}{sym} = {self.doprint(math)};' + f' // {variable}[{index}]'.replace('\n', + '\n' + ' ' * indent_level) + for index, (sym, math) in enumerate(zip(symbols, equations)) + if math not in [0, 0.0] + ] + + def csc_matrix( + self, + matrix: sp.Matrix, + rownames: List[sp.Symbol], + colnames: List[sp.Symbol], + identifier: Optional[int] = 0, + pattern_only: Optional[bool] = False + ) -> Tuple[ + List[int], List[int], sp.Matrix, List[str], sp.Matrix + ]: + """ + Generates the sparse symbolic identifiers, symbolic identifiers, + sparse matrix, column pointers and row values for a symbolic + variable + + :param matrix: + dense matrix to be sparsified + + :param rownames: + ids of the variable of which the derivative is computed (assuming + matrix is the jacobian) + + :param colnames: + ids of the variable with respect to which the derivative is computed + (assuming matrix is the jacobian) + + :param identifier: + additional identifier that gets appended to symbol names to + ensure their uniqueness in outer loops + + :param pattern_only: + flag for computing sparsity pattern without whole matrix + + :return: + symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, + sparse_matrix + + """ + idx = 0 + + nrows, ncols = matrix.shape + + if not pattern_only: + sparse_matrix = sp.zeros(nrows, ncols) + symbol_list = [] + sparse_list = [] + symbol_col_ptrs = [] + symbol_row_vals = [] + + for col in range(ncols): + symbol_col_ptrs.append(idx) + for row in range(nrows): + if matrix[row, col] == 0: + continue + + symbol_row_vals.append(row) + idx += 1 + symbol_name = f'd{self.doprint(rownames[row])}' \ + f'_d{self.doprint(colnames[col])}' + if identifier: + symbol_name += f'_{identifier}' + symbol_list.append(symbol_name) + if pattern_only: + continue + + sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True) + sparse_list.append(matrix[row, col]) + + if idx == 0: + symbol_col_ptrs = [] # avoid bad memory access for empty matrices + else: + symbol_col_ptrs.append(idx) + + if pattern_only: + sparse_matrix = None + else: + sparse_list = sp.Matrix(sparse_list) + + return symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, \ + sparse_matrix + + +def get_switch_statement(condition: str, cases: Dict[int, List[str]], + indentation_level: Optional[int] = 0, + indentation_step: Optional[str] = ' ' * 4): + """ + Generate code for switch statement + + :param condition: + Condition for switch + + :param cases: + Cases as dict with expressions as keys and statement as + list of strings + + :param indentation_level: + indentation level + + :param indentation_step: + indentation whitespace per level + + :return: + Code for switch expression as list of strings + + """ + lines = [] + + if not cases: + return lines + + for expression, statements in cases.items(): + if statements: + lines.append((indentation_level + 1) * indentation_step + + f'case {expression}:') + for statement in statements: + lines.append((indentation_level + 2) * indentation_step + + statement) + lines.append((indentation_level + 2) * indentation_step + 'break;') + + if lines: + lines.insert(0, indentation_level * indentation_step + + f'switch({condition}) {{') + lines.append(indentation_level * indentation_step + '}') + + return lines + diff --git a/python/amici/ode_export.py b/python/amici/ode_export.py index f1bcd0a4c3..526f659d73 100644 --- a/python/amici/ode_export.py +++ b/python/amici/ode_export.py @@ -31,13 +31,11 @@ Set, Any ) from string import Template -from sympy.printing import cxxcode -from sympy.printing.cxx import _CXXCodePrinterBase from sympy.matrices.immutable import ImmutableDenseMatrix from sympy.matrices.dense import MutableDenseMatrix from sympy.logic.boolalg import BooleanAtom from itertools import chain - +from .cxxcodeprinter import AmiciCxxCodePrinter, get_switch_statement from . import ( amiciSwigPath, amiciSrcPath, amiciModulePath, __version__, __commit__, @@ -955,6 +953,9 @@ class ODEModel: :ivar _has_quadratic_nllh: whether all observables have a gaussian noise model, i.e. whether res and FIM make sense. + + :ivar _code_printer: + Code printer to generate C++ code """ def __init__(self, verbose: Optional[Union[bool, int]] = False, @@ -1041,10 +1042,16 @@ def __init__(self, verbose: Optional[Union[bool, int]] = False, self._has_quadratic_nllh: bool = True set_log_level(logger, verbose) + self._code_printer = AmiciCxxCodePrinter() + for fun in CUSTOM_FUNCTIONS: + self._code_printer.known_functions[fun['sympy']] = fun['c++'] + @log_execution_time('importing SbmlImporter', logger) - def import_from_sbml_importer(self, - si: 'sbml_import.SbmlImporter', - compute_cls: Optional[bool] = True) -> None: + def import_from_sbml_importer( + self, + si: 'sbml_import.SbmlImporter', + compute_cls: Optional[bool] = True + ) -> None: """ Imports a model specification from a :class:`amici.sbml_import.SbmlImporter` @@ -1052,6 +1059,8 @@ def import_from_sbml_importer(self, :param si: imported SBML model + :param compute_cls: + whether to compute conservation laws """ # get symbolic expression from SBML importers @@ -1263,7 +1272,6 @@ def add_conservation_law(self, self._states[ix].set_conservation_law(state_expr) - def get_observable_transformations(self) -> List[ObservableTransformation]: """ List of observable transformations @@ -1273,7 +1281,6 @@ def get_observable_transformations(self) -> List[ObservableTransformation]: """ return [obs.trafo for obs in self._observables] - def num_states_rdata(self) -> int: """ Number of states. @@ -1311,7 +1318,7 @@ def num_state_reinits(self) -> int: """ reinit_states = self.eq('x0_fixedParameters') solver_states = self.eq('x_solver') - return sum([1 for ix in reinit_states if ix in solver_states]) + return sum(ix in solver_states for ix in reinit_states) def num_obs(self) -> int: """ @@ -1697,10 +1704,9 @@ def _generate_sparse_symbol(self, name: str) -> None: self._syms[name] = [] for iy in range(self.num_obs()): symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, \ - sparse_matrix = csc_matrix(matrix[iy, :], - rownames=rownames, - colnames=colnames, - identifier=iy) + sparse_matrix = self._code_printer.csc_matrix( + matrix[iy, :], rownames=rownames, colnames=colnames, + identifier=iy) self._colptrs[name].append(symbol_col_ptrs) self._rowvals[name].append(symbol_row_vals) self._sparseeqs[name].append(sparse_list) @@ -1708,7 +1714,7 @@ def _generate_sparse_symbol(self, name: str) -> None: self._syms[name].append(sparse_matrix) else: symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, \ - sparse_matrix = csc_matrix( + sparse_matrix = self._code_printer.csc_matrix( matrix, rownames=rownames, colnames=colnames, pattern_only=name in nobody_functions ) @@ -1802,7 +1808,7 @@ def _compute_equation(self, name: str) -> None: self._x0_fixedParameters_idx = [ ix for ix, eq in enumerate(self.eq('x0')) - if any([sym in eq.free_symbols for sym in k]) + if any(sym in eq.free_symbols for sym in k) ] eq = self.eq('x0') self._eqs[name] = sp.Matrix([eq[ix] for ix in @@ -1968,25 +1974,24 @@ def _derivative(self, eq: str, var: str, name: str = None) -> None: if not name: name = f'd{eq}d{var}' - # automatically detect chainrule - chainvars = [] ignore_chainrule = { ('xdot', 'p'): 'w', # has generic implementation in c++ code ('xdot', 'x'): 'w', # has generic implementation in c++ code ('w', 'w'): 'tcl', # dtcldw = 0 ('w', 'x'): 'tcl', # dtcldx = 0 } - for cv in ['w', 'tcl']: - if var_in_function_signature(eq, cv) \ - and cv not in self._lock_total_derivative \ - and var is not cv \ - and min(self.sym(cv).shape) \ - and ( - (eq, var) not in ignore_chainrule - or ignore_chainrule[(eq, var)] != cv - ): - chainvars.append(cv) - + # automatically detect chainrule + chainvars = [ + cv for cv in ['w', 'tcl'] + if var_in_function_signature(eq, cv) + and cv not in self._lock_total_derivative + and var is not cv + and min(self.sym(cv).shape) + and ( + (eq, var) not in ignore_chainrule + or ignore_chainrule[(eq, var)] != cv + ) + ] if len(chainvars): self._lock_total_derivative += chainvars self._total_derivative(name, eq, chainvars, var) @@ -1998,11 +2003,7 @@ def _derivative(self, eq: str, var: str, name: str = None) -> None: needs_stripped_symbols = eq == 'xdot' and var != 'x' # partial derivative - if eq == 'Jy': - sym_eq = self.eq(eq).transpose() - else: - sym_eq = self.eq(eq) - + sym_eq = self.eq(eq).transpose() if eq == 'Jy' else self.eq(eq) if pysb is not None and needs_stripped_symbols: needs_stripped_symbols = not any( isinstance(sym, pysb.Component) @@ -2162,18 +2163,14 @@ def _multiplication(self, name: str, x: str, y: str, if sign not in [-1, 1]: raise TypeError(f'sign must be +1 or -1, was {sign}') - variables = dict() - for varname in [x, y]: - if var_in_function_signature(name, varname): - variables[varname] = self.sym(varname) - else: - variables[varname] = self.eq(varname) - - if transpose_x: - xx = variables[x].transpose() - else: - xx = variables[x] + variables = { + varname: self.sym(varname) + if var_in_function_signature(name, varname) + else self.eq(varname) + for varname in [x, y] + } + xx = variables[x].transpose() if transpose_x else variables[x] yy = variables[y] self._eqs[name] = sign * smart_multiply(xx, yy) @@ -2194,12 +2191,10 @@ def _equation_from_component(self, name: str, component: str) -> None: ) def get_conservation_laws(self) -> List[Tuple[sp.Symbol, sp.Basic]]: - """ Returns a list of states with conservation law set - + """Returns a list of states with conservation law set :return: list of state identifiers - """ return [ (state.get_id(), state._conservation_law) @@ -2259,10 +2254,10 @@ def state_has_fixed_parameter_initial_condition(self, ix: int) -> bool: ic = self._states[ix].get_val() if not isinstance(ic, sp.Basic): return False - return any([ + return any( fp in [c.get_id() for c in self._constants] for fp in ic.free_symbols - ]) + ) def state_has_conservation_law(self, ix: int) -> bool: """ @@ -2327,11 +2322,10 @@ def _expr_is_time_dependent(self, expr: sp.Expr) -> bool: # Check if any time-dependent states are in the expression. state_syms = [str(sym) for sym in self._states] - for state in expr_syms.intersection(state_syms): - if not self.state_is_constant(state_syms.index(state)): - return True - - return False + return any( + not self.state_is_constant(state_syms.index(state)) + for state in expr_syms.intersection(state_syms) + ) def _get_unique_root( self, @@ -2449,89 +2443,6 @@ def _process_heavisides( return dxdt -def _print_with_exception(math: sp.Expr) -> str: - """ - Generate C++ code for a symbolic expression - - :param math: - symbolic expression - - :return: - C++ code for the specified expression - """ - # get list of custom replacements - user_functions = {fun['sympy']: fun['c++'] for fun in CUSTOM_FUNCTIONS} - - try: - ret = cxxcode(math, standard='c++11', user_functions=user_functions) - ret = re.sub(r'(^|\W)M_PI(\W|$)', r'\1amici::pi\2', ret) - return ret - except TypeError as e: - raise ValueError( - f'Encountered unsupported function in expression "{math}": ' - f'{e}!' - ) - - -def _get_sym_lines_array(equations: sp.Matrix, - variable: str, - indent_level: int) -> List[str]: - """ - Generate C++ code for assigning symbolic terms in symbols to C++ array - `variable`. - - :param equations: - vectors of symbolic expressions - - :param variable: - name of the C++ array to assign to - - :param indent_level: - indentation level (number of leading blanks) - - :return: - C++ code as list of lines - - """ - - return [' ' * indent_level + f'{variable}[{index}] = ' - f'{_print_with_exception(math)};' - for index, math in enumerate(equations) - if not (math == 0 or math == 0.0)] - - -def _get_sym_lines_symbols(symbols: sp.Matrix, - equations: sp.Matrix, - variable: str, - indent_level: int) -> List[str]: - """ - Generate C++ code for where array elements are directly replaced with - their corresponding macro symbol - - :param symbols: - vectors of symbols that equations are assigned to - - :param equations: - vectors of expressions - - :param variable: - name of the C++ array to assign to, only used in comments - - :param indent_level: - indentation level (number of leading blanks) - - :return: - C++ code as list of lines - - """ - - return [f'{" " * indent_level}{sym} = {_print_with_exception(math)};' - f' // {variable}[{index}]'.replace('\n', - '\n' + ' ' * indent_level) - for index, (sym, math) in enumerate(zip(symbols, equations)) - if not (math == 0 or math == 0.0)] - - class ODEExporter: """ The ODEExporter class generates AMICI C++ files for ODE model as @@ -2882,7 +2793,7 @@ def _write_function_file(self, function: str) -> None: # Unfortunately we cannot check for `self.functions[sym]['body']` # here since it may not have been generated yet. for match in re.findall( - fr'const (realtype|double) \*([\w]+)[0]*[,\)]+', signature + r'const (realtype|double) \*([\w]+)[0]*[,\)]+', signature ): sym = match[1] if sym not in self.model.sym_names(): @@ -3071,8 +2982,10 @@ def _get_function_body(self, # dJydy is a list return lines - if not self.allow_reinit_fixpar_initcond \ - and function in ['sx0_fixedParameters', 'x0_fixedParameters']: + if not self.allow_reinit_fixpar_initcond and function in { + 'sx0_fixedParameters', + 'x0_fixedParameters', + }: return lines if function == 'sx0_fixedParameters': @@ -3100,7 +3013,7 @@ def _get_function_body(self, " sx0_fixedParameters[idx] = 0.0;", " }"]) - cases = dict() + cases = {} for ipar in range(self.model.num_par()): expressions = [] for index, formula in zip( @@ -3114,7 +3027,7 @@ def _get_function_body(self, f'reinitialization_state_idxs.cend(), {index}) != ' 'reinitialization_state_idxs.cend())', f' {function}[{index}] = ' - f'{_print_with_exception(formula)};' + f'{self.model._code_printer.doprint(formula)};' ]) cases[ipar] = expressions lines.extend(get_switch_statement('ip', cases, 1)) @@ -3129,12 +3042,15 @@ def _get_function_body(self, f'reinitialization_state_idxs.cend(), {index}) != ' 'reinitialization_state_idxs.cend())\n ' f'{function}[{index}] = ' - f'{_print_with_exception(formula)};') + f'{self.model._code_printer.doprint(formula)};') elif function in event_functions: - cases = {ie: _get_sym_lines_array(equations[ie], function, 0) - for ie in range(self.model.num_events()) - if not smart_is_zero_matrix(equations[ie])} + cases = { + ie: self.model._code_printer._get_sym_lines_array( + equations[ie], function, 0) + for ie in range(self.model.num_events()) + if not smart_is_zero_matrix(equations[ie]) + } lines.extend(get_switch_statement('ie', cases, 1)) elif function in event_sensi_functions: @@ -3142,31 +3058,37 @@ def _get_function_body(self, for ie, inner_equations in enumerate(equations): inner_lines = [] inner_cases = { - ipar: _get_sym_lines_array(inner_equations[:, ipar], - function, 0) + ipar: self.model._code_printer._get_sym_lines_array( + inner_equations[:, ipar], function, 0) for ipar in range(self.model.num_par()) - if not smart_is_zero_matrix(inner_equations[:, ipar])} + if not smart_is_zero_matrix(inner_equations[:, ipar]) + } inner_lines.extend(get_switch_statement( 'ip', inner_cases, 0)) outer_cases[ie] = copy.copy(inner_lines) lines.extend(get_switch_statement('ie', outer_cases, 1)) elif function in sensi_functions: - cases = {ipar: _get_sym_lines_array(equations[:, ipar], function, - 0) - for ipar in range(self.model.num_par()) - if not smart_is_zero_matrix(equations[:, ipar])} + cases = { + ipar: self.model._code_printer._get_sym_lines_array( + equations[:, ipar], function, 0) + for ipar in range(self.model.num_par()) + if not smart_is_zero_matrix(equations[:, ipar]) + } lines.extend(get_switch_statement('ip', cases, 1)) elif function in multiobs_functions: if function == 'dJydy': - cases = {iobs: _get_sym_lines_array(equations[iobs], function, - 0) - for iobs in range(self.model.num_obs()) - if not smart_is_zero_matrix(equations[iobs])} + cases = { + iobs: self.model._code_printer._get_sym_lines_array( + equations[iobs], function, 0) + for iobs in range(self.model.num_obs()) + if not smart_is_zero_matrix(equations[iobs]) + } else: cases = { - iobs: _get_sym_lines_array(equations[:, iobs], function, 0) + iobs: self.model._code_printer._get_sym_lines_array( + equations[:, iobs], function, 0) for iobs in range(self.model.num_obs()) if not smart_is_zero_matrix(equations[:, iobs]) } @@ -3178,10 +3100,12 @@ def _get_function_body(self, symbols = self.model.sparsesym(function) else: symbols = self.model.sym(function, stripped=True) - lines += _get_sym_lines_symbols(symbols, equations, function, 4) + lines += self.model._code_printer._get_sym_lines_symbols( + symbols, equations, function, 4) else: - lines += _get_sym_lines_array(equations, function, 4) + lines += self.model._code_printer._get_sym_lines_array( + equations, function, 4) return [line for line in lines if line] @@ -3247,9 +3171,10 @@ def _write_model_header_cpp(self) -> None: 'NK': str(self.model.num_const()), 'O2MODE': 'amici::SecondOrderMode::none', # using cxxcode ensures proper handling of nan/inf - 'PARAMETERS': _print_with_exception(self.model.val('p'))[1:-1], - 'FIXED_PARAMETERS': _print_with_exception(self.model.val('k'))[ - 1:-1], + 'PARAMETERS': self.model._code_printer.doprint( + self.model.val('p'))[1:-1], + 'FIXED_PARAMETERS': self.model._code_printer.doprint( + self.model.val('k'))[1:-1], 'PARAMETER_NAMES_INITIALIZER_LIST': self._get_symbol_name_initializer_list('p'), 'STATE_NAMES_INITIALIZER_LIST': @@ -3455,7 +3380,7 @@ def set_paths(self, output_dir: Optional[str] = None) -> None: relative or absolute path where the generated model code is to be placed. If ``None``, this will default to `amici-{self.model_name}` in the current working directory. - will be created if does not exists. + will be created if it does not exist. """ if output_dir is None: @@ -3487,7 +3412,7 @@ class TemplateAmici(Template): """ Template format used in AMICI (see string.template for more details). - :ivar delimiter: + :cvar delimiter: delimiter that identifies template variables """ @@ -3510,7 +3435,6 @@ def apply_template(source_file: str, :param template_data: template keywords to substitute (key is template variable without :attr:`TemplateAmici.delimiter`) - """ with open(source_file) as filein: src = TemplateAmici(filein.read()) @@ -3528,7 +3452,6 @@ def strip_pysb(symbol: sp.Basic) -> sp.Basic: :return: stripped expression - """ # strip pysb type and transform into a flat sympy.Symbol. # this ensures that the pysb type specific __repr__ is used when converting @@ -3551,7 +3474,6 @@ def get_function_extern_declaration(fun: str, name: str) -> str: :return: c++ function definition string - """ return \ f'extern void {fun}_{name}{functions[fun]["signature"]};' @@ -3574,7 +3496,6 @@ def get_sunindex_extern_declaration(fun: str, name: str, :return: c++ function declaration string - """ index_arg = ', int index' if fun in multiobs_functions else '' return \ @@ -3598,7 +3519,6 @@ def get_model_override_implementation(fun: str, name: str, :return: c++ function implementation string - """ impl = 'virtual void f{fun}{signature} override {{' @@ -3638,7 +3558,6 @@ def get_sunindex_override_implementation(fun: str, name: str, :return: c++ function implementation string - """ index_arg = ', int index' if fun in multiobs_functions else '' index_arg_eval = ', index' if fun in multiobs_functions else '' @@ -3672,7 +3591,7 @@ def remove_typedefs(signature: str) -> str: string that can be used to construct function calls with the same variable names and ordering as in the function signature """ - # remove * pefix for pointers (pointer must always be removed before + # remove * prefix for pointers (pointer must always be removed before # values otherwise we will inadvertently dereference values, # same applies for const specifications) # @@ -3695,130 +3614,6 @@ def remove_typedefs(signature: str) -> str: return signature -def get_switch_statement(condition: str, cases: Dict[int, List[str]], - indentation_level: Optional[int] = 0, - indentation_step: Optional[str] = ' ' * 4): - """ - Generate code for switch statement - - :param condition: - Condition for switch - - :param cases: - Cases as dict with expressions as keys and statement as - list of strings - - :param indentation_level: - indentation level - - :param indentation_step: - indentation whitespace per level - - :return: - Code for switch expression as list of strings - - """ - lines = list() - - if not cases: - return lines - - for expression, statements in cases.items(): - if statements: - lines.append((indentation_level + 1) * indentation_step - + f'case {expression}:') - for statement in statements: - lines.append((indentation_level + 2) * indentation_step - + statement) - lines.append((indentation_level + 2) * indentation_step + 'break;') - - if lines: - lines.insert(0, indentation_level * indentation_step - + f'switch({condition}) {{') - lines.append(indentation_level * indentation_step + '}') - - return lines - - -def csc_matrix(matrix: sp.Matrix, - rownames: List[sp.Symbol], - colnames: List[sp.Symbol], - identifier: Optional[int] = 0, - pattern_only: Optional[bool] = False) -> Tuple[ - List[int], List[int], sp.Matrix, List[str], sp.Matrix -]: - """ - Generates the sparse symbolic identifiers, symbolic identifiers, - sparse matrix, column pointers and row values for a symbolic - variable - - :param matrix: - dense matrix to be sparsified - - :param rownames: - ids of the variable of which the derivative is computed (assuming - matrix is the jacobian) - - :param colnames: - ids of the variable with respect to which the derivative is computed - (assuming matrix is the jacobian) - - :param identifier: - additional identifier that gets appended to symbol names to - ensure their uniqueness in outer loops - - :param pattern_only: - flag for computing sparsity pattern without whole matrix - - :return: - symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, - sparse_matrix - - """ - idx = 0 - - nrows, ncols = matrix.shape - - if not pattern_only: - sparse_matrix = sp.zeros(nrows, ncols) - symbol_list = [] - sparse_list = [] - symbol_col_ptrs = [] - symbol_row_vals = [] - - for col in range(0, ncols): - symbol_col_ptrs.append(idx) - for row in range(0, nrows): - if matrix[row, col] == 0: - continue - - symbol_row_vals.append(row) - idx += 1 - symbol_name = f'd{_print_with_exception(rownames[row])}' \ - f'_d{_print_with_exception(colnames[col])}' - if identifier: - symbol_name += f'_{identifier}' - symbol_list.append(symbol_name) - if pattern_only: - continue - - sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True) - sparse_list.append(matrix[row, col]) - - if idx == 0: - symbol_col_ptrs = [] # avoid bad memory access for empty matrices - else: - symbol_col_ptrs.append(idx) - - if pattern_only: - sparse_matrix = None - else: - sparse_list = sp.Matrix(sparse_list) - - return symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, \ - sparse_matrix - - def is_valid_identifier(x: str) -> bool: """ Check whether `x` is a valid identifier for conditions, parameters, @@ -3925,7 +3720,6 @@ def _monkeypatched(obj: object, name: str, patch: Any): :param patch: patched value - """ pre_patched_value = getattr(obj, name) setattr(obj, name, patch) @@ -3937,7 +3731,7 @@ def _monkeypatched(obj: object, name: str, patch: Any): def _custom_pow_eval_derivative(self, s): """ - Custom Pow derivative that removes a removeable singularity for + Custom Pow derivative that removes a removable singularity for self.base == 0 and self.base.diff(s) == 0. This function is intended to be monkeypatched into sp.Pow._eval_derivative. @@ -3946,7 +3740,6 @@ def _custom_pow_eval_derivative(self, s): :param s: variable with respect to which the derivative will be computed - """ dbase = self.base.diff(s) dexp = self.exp.diff(s) @@ -3960,25 +3753,3 @@ def _custom_pow_eval_derivative(self, s): (self.base, sp.And(sp.Eq(self.base, 0), sp.Eq(dbase, 0))), (part2, True) ) - - -def _custom_print_max(self, expr): - """ - Custom Max printing function, see https://github.com/sympy/sympy/pull/20558 - """ - from sympy import Max - if len(expr.args) == 1: - return self._print(expr.args[0]) - return "%smax(%s, %s)" % (self._ns, self._print(expr.args[0]), - self._print(Max(*expr.args[1:]))) - - -def _custom_print_min(self, expr): - """ - Custom Min printing function, see https://github.com/sympy/sympy/pull/20558 - """ - from sympy import Min - if len(expr.args) == 1: - return self._print(expr.args[0]) - return "%smin(%s, %s)" % (self._ns, self._print(expr.args[0]), - self._print(Min(*expr.args[1:]))) diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py new file mode 120000 index 0000000000..4a7a4c588b --- /dev/null +++ b/python/sdist/amici/cxxcodeprinter.py @@ -0,0 +1 @@ +../../amici/cxxcodeprinter.py \ No newline at end of file diff --git a/python/tests/test_ode_export.py b/python/tests/test_ode_export.py index 4287a42fdf..83843348e8 100644 --- a/python/tests/test_ode_export.py +++ b/python/tests/test_ode_export.py @@ -1,15 +1,16 @@ """Miscellaneous AMICI Python interface tests""" -import amici import sympy as sp +from amici.cxxcodeprinter import AmiciCxxCodePrinter def test_csc_matrix(): """Test sparse CSC matrix creation""" + printer = AmiciCxxCodePrinter() matrix = sp.Matrix([[1, 0], [2, 3]]) symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, sparse_matrix \ - = amici.ode_export.csc_matrix(matrix, rownames=['a1', 'a2'], - colnames=['b1', 'b2']) + = printer.csc_matrix(matrix, rownames=['a1', 'a2'], + colnames=['b1', 'b2']) assert symbol_col_ptrs == [0, 2, 3] assert symbol_row_vals == [0, 1, 1] @@ -20,9 +21,10 @@ def test_csc_matrix(): def test_csc_matrix_empty(): """Test sparse CSC matrix creation for empty matrix""" + printer = AmiciCxxCodePrinter() matrix = sp.Matrix() symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, sparse_matrix \ - = amici.ode_export.csc_matrix(matrix, rownames=[], colnames=[]) + = printer.csc_matrix(matrix, rownames=[], colnames=[]) assert symbol_col_ptrs == [] assert symbol_row_vals == [] @@ -33,10 +35,10 @@ def test_csc_matrix_empty(): def test_csc_matrix_vector(): """Test sparse CSC matrix creation from matrix slice""" - + printer = AmiciCxxCodePrinter() matrix = sp.Matrix([[1, 0], [2, 3]]) symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, sparse_matrix \ - = amici.ode_export.csc_matrix( + = printer.csc_matrix( matrix[:, 0], colnames=[sp.Symbol('b')], rownames=[sp.Symbol('a1'), sp.Symbol('a2')] ) @@ -49,7 +51,7 @@ def test_csc_matrix_vector(): # Test continuation of numbering of symbols symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, sparse_matrix \ - = amici.ode_export.csc_matrix( + = printer.csc_matrix( matrix[:, 1], colnames=[sp.Symbol('b')], rownames=[sp.Symbol('a1'), sp.Symbol('a2')], identifier=1 )