From 06ccfe87bcfcafd92e02091655c2c074983f5239 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Sat, 7 Mar 2020 08:27:39 +0000 Subject: [PATCH] #858 fixes for diff --- pybamm/__init__.py | 3 ++- pybamm/expression_tree/symbol.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index f2d3d14bc6..5e9790389c 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -108,6 +108,7 @@ def version(formatted=False): ) from .expression_tree.scalar import Scalar from .expression_tree.variable import Variable, ExternalVariable, VariableDot +from .expression_tree.variable import VariableBase from .expression_tree.independent_variable import ( IndependentVariable, Time, @@ -115,7 +116,7 @@ def version(formatted=False): ) from .expression_tree.independent_variable import t from .expression_tree.vector import Vector -from .expression_tree.state_vector import StateVector, StateVectorDot +from .expression_tree.state_vector import StateVectorBase, StateVector, StateVectorDot from .expression_tree.exceptions import ( DomainError, diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 79496fa9c5..eb863f0d88 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -485,6 +485,12 @@ def diff(self, variable): return pybamm.Scalar(1) elif any(variable.id == x.id for x in self.pre_order()): return self._diff(variable) + elif variable.id == pybamm.t.id and \ + any( + isinstance(x, (pybamm.VariableBase, pybamm.StateVectorBase)) + for x in self.pre_order() + ): + return self._diff(variable) else: return pybamm.Scalar(0) @@ -770,5 +776,3 @@ def test_shape(self): self.shape_for_testing except ValueError as e: raise pybamm.ShapeError("Cannot find shape (original error: {})".format(e)) - -