Skip to content

Commit

Permalink
Merge pull request #1502 from priyanshuone6/sympy-binary
Browse files Browse the repository at this point in the history
Add to_equation in Binary and Scalar
  • Loading branch information
valentinsulzer authored Jun 11, 2021
2 parents f6cb07a + 3893835 commit 68bc122
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 119 deletions.
2 changes: 1 addition & 1 deletion docs/source/expression_tree/binary_operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Binary Operators
.. autoclass:: pybamm.Inner
:members:

.. autoclass:: pybamm.Heaviside
.. autoclass:: pybamm.expression_tree.binary_operators._Heaviside
:members:

.. autoclass:: pybamm.EqualHeaviside
Expand Down
26 changes: 17 additions & 9 deletions pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
# NumpyArray class
#
import numpy as np
import sympy
from scipy.sparse import csr_matrix, issparse

import pybamm
from scipy.sparse import issparse, csr_matrix


class Array(pybamm.Symbol):
"""node in the expression tree that holds an tensor type variable
"""
Node in the expression tree that holds an tensor type variable
(e.g. :class:`numpy.array`)
Parameters
Expand Down Expand Up @@ -54,12 +57,12 @@ def entries(self):

@property
def ndim(self):
""" returns the number of dimensions of the tensor"""
"""returns the number of dimensions of the tensor."""
return self._entries.ndim

@property
def shape(self):
""" returns the number of entries along each dimension"""
"""returns the number of entries along each dimension."""
return self._entries.shape

@property
Expand All @@ -86,19 +89,19 @@ def entries_string(self, value):
self._entries_string = (entries.tobytes(),)

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
"""See :meth:`pybamm.Symbol.set_id()`."""
self._id = hash(
(self.__class__, self.name) + self.entries_string + tuple(self.domain)
)

def _jac(self, variable):
""" See :meth:`pybamm.Symbol._jac()`. """
"""See :meth:`pybamm.Symbol._jac()`."""
# Return zeros of correct size
jac = csr_matrix((self.size, variable.evaluation_array.count(True)))
return pybamm.Matrix(jac)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
"""See :meth:`pybamm.Symbol.new_copy()`."""
return self.__class__(
self.entries,
self.name,
Expand All @@ -108,13 +111,18 @@ def new_copy(self):
)

def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
return self._entries

def is_constant(self):
""" See :meth:`pybamm.Symbol.is_constant()`. """
"""See :meth:`pybamm.Symbol.is_constant()`."""
return True

def to_equation(self):
"""Returns the value returned by the node when evaluated."""
entries_list = self.entries.tolist()
return sympy.Array(entries_list)


def linspace(start, stop, num=50, **kwargs):
"""
Expand Down
Loading

0 comments on commit 68bc122

Please sign in to comment.