diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 15194dd432..56e70c266d 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -150,9 +150,14 @@ class Index(UnaryOperator): The index (if int) or indices (if slice) to extract from the symbol name : str, optional The name of the symbol + check_size : bool, optional + Whether to check if the slice size exceeds the child size. Default is True. + This should always be True when creating a new symbol so that the appropriate + check is performed, but should be False for creating a new copy to avoid + unnecessarily repeating the check. """ - def __init__(self, child, index, name=None): + def __init__(self, child, index, name=None, check_size=True): self.index = index if index == -1: self.slice = slice(index, None) @@ -172,10 +177,11 @@ def __init__(self, child, index, name=None): else: raise TypeError("index must be integer or slice") - if self.slice in (slice(0, 1), slice(-1, None)): - pass - elif self.slice.stop > child.size: - raise ValueError("slice size exceeds child size") + if check_size: + if self.slice in (slice(0, 1), slice(-1, None)): + pass + elif self.slice.stop > child.size: + raise ValueError("slice size exceeds child size") super().__init__(name, child) @@ -217,7 +223,7 @@ def _unary_evaluate(self, child): def _unary_new_copy(self, child): """ See :meth:`UnaryOperator._unary_new_copy()`. """ - return self.__class__(child, self.index) + return self.__class__(child, self.index, check_size=False) def evaluate_for_shape(self): return self._unary_evaluate(self.children[0].evaluate_for_shape())