Skip to content

Commit

Permalink
#858 add y_dot arg to all evaluate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 6, 2020
1 parent 4c971a6 commit 96c6a41
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def new_copy(self):
self.entries_string,
)

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
return self._entries
2 changes: 1 addition & 1 deletion pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Time()

def _base_evaluate(self, t, y=None, u=None):
def _base_evaluate(self, t, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
if t is None:
raise ValueError("t must be provided")
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _jac(self, variable):
""" See :meth:`pybamm.Symbol._jac()`. """
return pybamm.Scalar(0)

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
# u should be a dictionary
# convert 'None' to empty dictionary for more informative error
if u is None:
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def set_id(self):
(self.__class__, self.name) + tuple(self.domain) + tuple(str(self._value))
)

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
return self._value

Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def __init__(
auxiliary_domains=auxiliary_domains,
evaluation_array=evaluation_array)

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
if y is None:
raise TypeError("StateVector cannot evaluate input 'y=None'")
Expand Down
17 changes: 11 additions & 6 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def _jac(self, variable):
"""
raise NotImplementedError

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
"""evaluate expression tree
will raise a ``NotImplementedError`` if this member function has not
Expand All @@ -520,7 +520,10 @@ def _base_evaluate(self, t=None, y=None, u=None):
time at which to evaluate (default None)
y : numpy.array, optional
array to evaluate when solving (default None)
array with state values to evaluate when solving (default None)
y_dot : numpy.array, optional
array with time derivatives of state values to evaluate when solving (default None)
"""
raise NotImplementedError(
Expand All @@ -530,7 +533,7 @@ def _base_evaluate(self, t=None, y=None, u=None):
)
)

def evaluate(self, t=None, y=None, u=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, u=None, known_evals=None):
"""Evaluate expression tree (wrapper to allow using dict of known values).
If the dict 'known_evals' is provided, the dict is searched for self.id; if
self.id is in the keys, return that value; otherwise, evaluate using
Expand All @@ -541,7 +544,9 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
t : float or numeric type, optional
time at which to evaluate (default None)
y : numpy.array, optional
array to evaluate when solving (default None)
array with state values to evaluate when solving (default None)
y_dot : numpy.array, optional
array with time derivatives of state values to evaluate when solving (default None)
u : dict, optional
dictionary of inputs to use when solving (default None)
known_evals : dict, optional
Expand All @@ -556,10 +561,10 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
"""
if known_evals is not None:
if self.id not in known_evals:
known_evals[self.id] = self._base_evaluate(t, y, u)
known_evals[self.id] = self._base_evaluate(t, y, y_dot, u)
return known_evals[self.id], known_evals
else:
return self._base_evaluate(t, y, u)
return self._base_evaluate(t, y, y_dot, u)

def evaluate_for_shape(self):
"""Evaluate expression tree to find its shape. For symbols that cannot be
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return np.nan * np.ones((self.size, 1))

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
# u should be a dictionary
# convert 'None' to empty dictionary for more informative error
if u is None:
Expand Down

0 comments on commit 96c6a41

Please sign in to comment.