Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 329 deepcopy #403

Merged
merged 14 commits into from
May 21, 2019
108 changes: 40 additions & 68 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def process_symbol(self, symbol):
"""
pybamm.logger.debug("Discretise {!s}".format(symbol))

if symbol.domain != []:
spatial_method = self._spatial_methods[symbol.domain[0]]

if isinstance(symbol, pybamm.BinaryOperator):
# Pre-process children
left, right = symbol.children
Expand All @@ -363,94 +366,63 @@ def process_symbol(self, symbol):
if symbol.domain == []:
return symbol.__class__(disc_left, disc_right)
else:
return self._spatial_methods[symbol.domain[0]].process_binary_operators(
return spatial_method.process_binary_operators(
symbol, left, right, disc_left, disc_right
)
elif isinstance(symbol, pybamm.UnaryOperator):
child = symbol.child
disc_child = self.process_symbol(child)
if child.domain != []:
child_spatial_method = self._spatial_methods[child.domain[0]]
if isinstance(symbol, pybamm.Gradient):
return child_spatial_method.gradient(child, disc_child, self._bcs)

elif isinstance(symbol, pybamm.Divergence):
return child_spatial_method.divergence(child, disc_child, self._bcs)

elif isinstance(symbol, pybamm.IndefiniteIntegral):
return child_spatial_method.indefinite_integral(
child.domain, child, disc_child
)

elif isinstance(symbol, pybamm.Gradient):
child = symbol.children[0]
discretised_child = self.process_symbol(child)
return self._spatial_methods[child.domain[0]].gradient(
child, discretised_child, self._bcs
)
elif isinstance(symbol, pybamm.Integral):
return child_spatial_method.integral(child.domain, child, disc_child)

elif isinstance(symbol, pybamm.Divergence):
child = symbol.children[0]
discretised_child = self.process_symbol(child)
return self._spatial_methods[child.domain[0]].divergence(
child, discretised_child, self._bcs
)
elif isinstance(symbol, pybamm.Broadcast):
# Broadcast new_child to the domain specified by symbol.domain
# Different discretisations may broadcast differently
if symbol.domain == []:
symbol = disc_child * pybamm.Vector(np.array([1]))
else:
symbol = spatial_method.broadcast(disc_child, symbol.domain)
return symbol

elif isinstance(symbol, pybamm.IndefiniteIntegral):
child = symbol.children[0]
discretised_child = self.process_symbol(child)
return self._spatial_methods[child.domain[0]].indefinite_integral(
child.domain, child, discretised_child
)
elif isinstance(symbol, pybamm.BoundaryOperator):
return child_spatial_method.boundary_value_or_flux(symbol, disc_child)

elif isinstance(symbol, pybamm.Integral):
child = symbol.children[0]
discretised_child = self.process_symbol(child)
return self._spatial_methods[child.domain[0]].integral(
child.domain, child, discretised_child
)

elif isinstance(symbol, pybamm.Broadcast):
# Process child first
new_child = self.process_symbol(symbol.children[0])
# Broadcast new_child to the domain specified by symbol.domain
# Different discretisations may broadcast differently
if symbol.domain == []:
symbol = new_child * pybamm.Vector(np.array([1]))
else:
symbol = self._spatial_methods[symbol.domain[0]].broadcast(
new_child, symbol.domain
)
return symbol

elif isinstance(symbol, pybamm.BoundaryOperator):
child = symbol.children[0]
discretised_child = self.process_symbol(child)
return self._spatial_methods[child.domain[0]].boundary_value_or_flux(
symbol, discretised_child
)

elif isinstance(symbol, pybamm.Function):
new_child = self.process_symbol(symbol.children[0])
return pybamm.Function(symbol.func, new_child)

elif isinstance(symbol, pybamm.UnaryOperator):
new_child = self.process_symbol(symbol.children[0])
return symbol.__class__(new_child)
return symbol._unary_new_copy(disc_child)

elif isinstance(symbol, pybamm.Variable):
return pybamm.StateVector(self._y_slices[symbol.id], domain=symbol.domain)

elif isinstance(symbol, pybamm.SpatialVariable):
return self._spatial_methods[symbol.domain[0]].spatial_variable(symbol)
return spatial_method.spatial_variable(symbol)

elif isinstance(symbol, pybamm.Concatenation):
new_children = [self.process_symbol(child) for child in symbol.children]
new_symbol = pybamm.DomainConcatenation(new_children, self.mesh)

return new_symbol

elif isinstance(symbol, pybamm.Scalar):
return pybamm.Scalar(symbol.value, symbol.name, symbol.domain)

elif isinstance(symbol, pybamm.Array):
return symbol.__class__(symbol.entries, symbol.name, symbol.domain)

elif isinstance(symbol, pybamm.StateVector):
return symbol.__class__(symbol.y_slice, symbol.name, symbol.domain)

elif isinstance(symbol, pybamm.Time):
return pybamm.Time()

else:
raise NotImplementedError(
"Cannot discretise symbol of type '{}'".format(type(symbol))
)
# Backup option: return new copy of the object
try:
return symbol.new_copy()
except NotImplementedError:
raise NotImplementedError(
"Cannot discretise symbol of type '{}'".format(type(symbol))
)

def concatenate(self, *symbols):
return pybamm.NumpyConcatenation(*symbols)
Expand Down
4 changes: 4 additions & 0 deletions pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def set_id(self):
(self.__class__, self.name, self.entries_string) + tuple(self.domain)
)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return self.__class__(self.entries, self.name, self.domain, self.entries_string)

def _base_evaluate(self, t=None, y=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
return self._entries
10 changes: 10 additions & 0 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def _binary_evaluate(self, left, right):
""" Perform binary operation on nodes 'left' and 'right'. """
raise NotImplementedError

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
# process children
new_left = self.left.new_copy()
new_right = self.right.new_copy()
# make new symbol, ensure domain remains the same
out = self.__class__(new_left, new_right)
out.domain = self.domain
return out

def evaluate(self, t=None, y=None, known_evals=None):
""" See :meth:`pybamm.Symbol.evaluate()`. """
if known_evals is not None:
Expand Down
7 changes: 6 additions & 1 deletion pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,9 @@ def __init__(self, child, domain, name=None):
def _unary_simplify(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """

return self.__class__(child, self.domain)
return Broadcast(child, self.domain)

def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """

return Broadcast(child, self.domain)
15 changes: 15 additions & 0 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ def evaluate(self, t=None, y=None, known_evals=None):
children_eval[idx] = child.evaluate(t, y)
return self._concatenation_evaluate(children_eval)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
new_children = [child.new_copy() for child in self.children]
return self._concatenation_new_copy(new_children)

def _concatenation_new_copy(self, children):
""" See :meth:`pybamm.Symbol.new_copy()`. """
new_symbol = self.__class__(*children)
return new_symbol

def _concatenation_simplify(self, children):
""" See :meth:`pybamm.Symbol.simplify()`. """
new_symbol = self.__class__(*children)
Expand Down Expand Up @@ -212,6 +222,11 @@ def jac(self, variable):
else:
return SparseStack(*[child.jac(variable) for child in children])

def _concatenation_new_copy(self, children):
""" See :meth:`pybamm.Symbol.new_copy()`. """
new_symbol = self.__class__(children, self.mesh, self)
return new_symbol

def _concatenation_simplify(self, children):
""" See :meth:`pybamm.Symbol.simplify()`. """
# Simplify Concatenation of StateVectors to a single StateVector
Expand Down
8 changes: 8 additions & 0 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class Time(IndependentVariable):
def __init__(self):
super().__init__("time")

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Time()

def _base_evaluate(self, t, y=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
if t is None:
Expand Down Expand Up @@ -76,6 +80,10 @@ def __init__(self, name, domain=[], coord_sys=None):
"domain cannot be particle if name is '{}'".format(name)
)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return SpatialVariable(self.name, self.domain, self.coord_sys)


# the independent variable time
t = Time()
11 changes: 7 additions & 4 deletions pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class Parameter(pybamm.Symbol):
def __init__(self, name, domain=[]):
super().__init__(name, domain=domain)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Parameter(self.name, self.domain)


class FunctionParameter(pybamm.UnaryOperator):
"""A node in the expression tree representing a function parameter
Expand Down Expand Up @@ -55,7 +59,6 @@ def diff(self, variable):
# when the parameters are set
return FunctionParameter(self.name, self.orphans[0], diff_variable=variable)

def _unary_simplify(self, child):
""" See :meth:`UnaryOperator._unary_simplify()`. """
child = self.child.simplify()
return pybamm.FunctionParameter(self.name, child, self.diff_variable)
def _unary_new_copy(self, child):
""" See :meth:`UnaryOperator._unary_new_copy()`. """
return FunctionParameter(self.name, child, diff_variable=self.diff_variable)
4 changes: 4 additions & 0 deletions pybamm/expression_tree/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ def _base_evaluate(self, t=None, y=None):
def jac(self, variable):
""" See :meth:`pybamm.Symbol.jac()`. """
return pybamm.Scalar(0)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Scalar(self.value, self.name, self.domain)
43 changes: 16 additions & 27 deletions pybamm/expression_tree/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,17 @@ def flatten(
or isinstance(other_child, (pybamm.Scalar, pybamm.Vector))
):
left, right = child.orphans
if child == left_child and this_class == pybamm.Multiplication:
if (
child == left_child
and this_class == pybamm.Multiplication
and isinstance(other_child, pybamm.Vector)
):
# change (m @ v1) * v2 -> v2 * m @ v so can simplify correctly
# (#341)
numerator.append(other_child)
numerator_types.append(previous_class)
flatten(
this_class, child.__class__, left, right, in_numerator, False
this_class, child.__class__, left, right, in_numerator, True
)
break
if child == left_child:
Expand Down Expand Up @@ -554,41 +558,26 @@ def simplify(symbol):
# make new symbol, ensure domain remains the same
# _binary_simplify defined in derived classes for specific rules
new_symbol = symbol._binary_simplify(new_left, new_right)
new_symbol.domain = symbol.domain
return simplify_if_constant(new_symbol)

elif isinstance(symbol, pybamm.UnaryOperator):
new_child = simplify(symbol.child)
new_symbol = symbol._unary_simplify(new_child)
new_symbol.domain = symbol.domain
return simplify_if_constant(new_symbol)

elif isinstance(symbol, pybamm.Concatenation):
new_children = [child.simplify() for child in symbol.cached_children]
new_children = [simplify(child) for child in symbol.cached_children]
new_symbol = symbol._concatenation_simplify(new_children)

return simplify_if_constant(new_symbol)

# Other cases: return new variable to avoid tree internal corruption
elif isinstance(symbol, (pybamm.Parameter, pybamm.Variable)):
return symbol.__class__(symbol.name, symbol.domain)

elif isinstance(symbol, pybamm.StateVector):
return pybamm.StateVector(symbol.y_slice, symbol.name)

elif isinstance(symbol, pybamm.Scalar):
return pybamm.Scalar(symbol.value, symbol.name, symbol.domain)

elif isinstance(symbol, pybamm.Array):
return symbol.__class__(
symbol.entries, symbol.name, symbol.domain, symbol.entries_string
)

elif isinstance(symbol, pybamm.SpatialVariable):
return pybamm.SpatialVariable(symbol.name, symbol.domain, symbol.coord_sys)

elif isinstance(symbol, pybamm.Time):
return pybamm.Time()

else:
raise NotImplementedError(
"Cannot simplify symbol of type '{}'".format(type(symbol))
)
# Backup option: return new copy of the object
try:
return symbol.new_copy()
except NotImplementedError:
raise NotImplementedError(
"Cannot simplify symbol of type '{}'".format(type(symbol))
)
21 changes: 14 additions & 7 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,10 @@ def set_id(self):
@property
def orphans(self):
"""
Returning deepcopies of the children, with parents removed to avoid corrupting
Returning new copies of the children, with parents removed to avoid corrupting
the expression tree internal data
"""
orp = []
for child in self.children:
new_child = copy.deepcopy(child)
new_child.parent = None
orp.append(new_child)
return tuple(orp)
return tuple([child.new_copy() for child in self.children])

def render(self):
"""print out a visual representation of the tree (this node and its
Expand Down Expand Up @@ -476,6 +471,18 @@ def simplify(self):
""" Simplify the expression tree. See :meth:`pybamm.simplify()`. """
return pybamm.simplify(self)

def new_copy(self):
"""
Make a new copy of a symbol, to avoid Tree corruption errors while bypassing
copy.deepcopy(), which is slow.
"""
raise NotImplementedError(
"""method self.new_copy() not implemented
for symbol {!s} of type {}""".format(
self, type(self)
)
)

@property
def size(self):
"""
Expand Down
Loading