Skip to content

Commit

Permalink
Merge pull request #2141 from pybamm-team/issue-2115-bool
Browse files Browse the repository at this point in the history
#2115 Boolean operator not implemented
  • Loading branch information
valentinsulzer authored Jul 6, 2022
2 parents 212c306 + fecbd28 commit 6a51492
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

## Bug fixes

- Raise explicit `NotImplementedError` if trying to call `bool()` on a pybamm Symbol (e.g. in an if statement condition) ([#2141](https://github.com/pybamm-team/PyBaMM/pull/2141))
- Fixed bug causing cut-off voltage to change after setting up a simulation with a model ([#2138](https://github.com/pybamm-team/PyBaMM/pull/2138))
- A single solution cycle can now be used as a starting solution for a simulation ([#2138](https://github.com/pybamm-team/PyBaMM/pull/2138))

Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def diff(self, variable):
) * child.diff(variable)

# remove None entries
partial_derivatives = list(filter(None, partial_derivatives))
partial_derivatives = [x for x in partial_derivatives if x is not None]

derivative = sum(partial_derivatives)
if derivative == 0:
Expand Down
3 changes: 3 additions & 0 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,9 @@ def __mod__(self, other):
"""return an :class:`Modulo` object."""
return pybamm.simplify_if_constant(pybamm.Modulo(self, other))

def __bool__(self):
raise NotImplementedError("Boolean operator not defined for Symbols.")

def diff(self, variable):
"""
Differentiate a symbol with respect to a variable. For any symbol that can be
Expand Down
8 changes: 4 additions & 4 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def info(self, symbol_name):
div = "-----------------------------------------"
symbol = find_symbol_in_model(self, symbol_name)

if not symbol:
if symbol is None:
return None

print(div)
Expand Down Expand Up @@ -1034,22 +1034,22 @@ def find_symbol_in_tree(tree, name):
elif len(tree.children) > 0:
for child in tree.children:
child_return = find_symbol_in_tree(child, name)
if child_return:
if child_return is not None:
return child_return


def find_symbol_in_dict(dic, name):
for tree in dic.values():
tree_return = find_symbol_in_tree(tree, name)
if tree_return:
if tree_return is not None:
return tree_return


def find_symbol_in_model(model, name):
dics = [model.rhs, model.algebraic, model.variables]
for dic in dics:
dic_return = find_symbol_in_dict(dic, name)
if dic_return:
if dic_return is not None:
return dic_return


Expand Down
12 changes: 8 additions & 4 deletions pybamm/models/submodels/particle/base_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ def _get_standard_concentration_variables(

# Get surface concentration if not provided as fundamental variable to
# solve for
c_s_surf = c_s_surf or pybamm.surf(c_s)
if c_s_surf is None:
c_s_surf = pybamm.surf(c_s)
c_s_surf_av = pybamm.x_average(c_s_surf)

c_scale = self.domain_param.c_max

# Get average concentration(s) if not provided as fundamental variable to
# solve for
c_s_xav = c_s_xav or pybamm.x_average(c_s)
c_s_rav = c_s_rav or pybamm.r_average(c_s)
c_s_av = c_s_av or pybamm.r_average(c_s_xav)
if c_s_xav is None:
c_s_xav = pybamm.x_average(c_s)
if c_s_rav is None:
c_s_rav = pybamm.r_average(c_s)
if c_s_av is None:
c_s_av = pybamm.r_average(c_s_xav)

variables = {
self.domain + " particle concentration": c_s,
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,15 @@ def test_is_matrix_zero(self):
self.assertFalse(pybamm.is_matrix_zero(b))
self.assertFalse(pybamm.is_matrix_zero(c))

def test_bool(self):
a = pybamm.Symbol("a")
with self.assertRaisesRegex(NotImplementedError, "Boolean"):
bool(a)
# if statement calls Boolean
with self.assertRaisesRegex(NotImplementedError, "Boolean"):
if a > 1:
print("a is greater than 1")


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit 6a51492

Please sign in to comment.