Skip to content

Commit

Permalink
Merge pull request #2378 from pybamm-team/issue-2365-faster-tree-search
Browse files Browse the repository at this point in the history
#2365 replace tree search
  • Loading branch information
valentinsulzer authored Oct 21, 2022
2 parents ead95e0 + 1f40974 commit 6928516
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 43 deletions.
1 change: 0 additions & 1 deletion pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
rmse,
load,
is_constant_and_can_evaluate,
tree_search,
)
from .util import (
get_parameters_filepath,
Expand Down
59 changes: 29 additions & 30 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def process_model(
model,
inplace=True,
check_model=True,
check_for_independent_variables=True,
remove_independent_variables_from_rhs=True,
):
"""Discretise a model.
Currently inplace, could be changed to return a new model.
Expand All @@ -118,7 +118,7 @@ def process_model(
option to False. When developing, testing or debugging it is recommended
to leave this option as True as it may help to identify any errors.
Default is True.
check_for_independent_variables : bool, optional
remove_independent_variables_from_rhs : bool, optional
If True, model checks to see whether any variables from the RHS are used
in any other equation. If a variable meets all of the following criteria
(not used anywhere in the model, len(rhs)>1), then the variable
Expand Down Expand Up @@ -162,8 +162,8 @@ def process_model(
# set variables (we require the full variable not just id)

# Search Equations for Independence
if check_for_independent_variables:
model = self.check_for_independent_variables(model)
if remove_independent_variables_from_rhs:
model = self.remove_independent_variables_from_rhs(model)
variables = list(model.rhs.keys()) + list(model.algebraic.keys())
# Find those RHS's that are constant
if self.spatial_methods == {} and any(var.domain != [] for var in variables):
Expand Down Expand Up @@ -1158,43 +1158,42 @@ def check_variables(self, model):
)
)

def search_for_independent_var(self, model, var):
def is_variable_independent(self, var, all_vars_in_eqns):
pybamm.logger.verbose("Removing independent blocks.")
boundary_variables = list(model.boundary_conditions.keys())
boundary_variable_keys = []
for condition in boundary_variables:
keys_for_condition = list(model.boundary_conditions[condition].keys())
boundary_variable_keys.append(keys_for_condition)
rhs_variables = list(model.rhs.keys())
algebraic_variables = list(model.algebraic.keys())
this_var_list = []
if not isinstance(var, pybamm.Variable):
return model, False
for tree in rhs_variables:
pybamm.tree_search(model.rhs[tree], var, this_var_list)
for tree in algebraic_variables:
pybamm.tree_search(model.algebraic[tree], var, this_var_list)
for (keys, tree) in zip(boundary_variable_keys, boundary_variables):
for key in keys:
pybamm.tree_search(
model.boundary_conditions[tree][key][0], var, this_var_list
)
for name in model.variables.keys():
for rhs_child in model.variables[name].children:
pybamm.tree_search(rhs_child, var, this_var_list)
this_var_is_independent = not any(this_var_list)
return False

this_var_is_independent = not (var.name in all_vars_in_eqns)
not_in_y_slices = not (var in list(self.y_slices.keys()))
not_in_discretised = not (var in list(self._discretised_symbols.keys()))
is_0D = len(var.domain) == 0
this_var_is_independent = (
this_var_is_independent and not_in_y_slices and not_in_discretised and is_0D
)
return model, this_var_is_independent
return this_var_is_independent

def check_for_independent_variables(self, model):
def remove_independent_variables_from_rhs(self, model):
rhs_vars_to_search_over = list(model.rhs.keys())
unpacker = pybamm.SymbolUnpacker(pybamm.Variable)
eqns_to_check = (
list(model.rhs.values())
+ list(model.algebraic.values())
+ [
x[side][0]
for x in model.boundary_conditions.values()
for side in x.keys()
]
# only check children of variables, this will skip the variable itself
# and catch any other cases
+ [child for var in model.variables.values() for child in var.children]
)
all_vars_in_eqns = unpacker.unpack_list_of_symbols(eqns_to_check)
all_vars_in_eqns = [var.name for var in all_vars_in_eqns]

for var in rhs_vars_to_search_over:
model, this_var_is_independent = self.search_for_independent_var(model, var)
this_var_is_independent = self.is_variable_independent(
var, all_vars_in_eqns
)
if this_var_is_independent:
if len(model.rhs) != 1:
pybamm.logger.info("removing variable {} from rhs".format(var))
Expand Down
11 changes: 0 additions & 11 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,6 @@
JAXLIB_VERSION = "0.1.70"


def tree_search(tree, item, solutions):
for child in tree.children:
tree_search(child, item, solutions)
if (child == item) or (child.name == item.name):
solutions.append(True)
else:
solutions.append(False)
solutions.append((tree == item) or (tree.name == item.name))
return None


def root_dir():
"""return the root directory of the PyBaMM install directory"""
return str(pathlib.Path(pybamm.__path__[0]).parent)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_input_params(self):
model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1}

disc = pybamm.Discretisation()
disc.process_model(model, check_for_independent_variables=False)
disc.process_model(model, remove_independent_variables_from_rhs=False)

solver = pybamm.IDAKLUSolver(root_method=root_method)

Expand Down

0 comments on commit 6928516

Please sign in to comment.