Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1066 add numpy function sqrt, sin, cos and exp to convert_to_casadi #1067

Merged
merged 7 commits into from
Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

## Bug fixes

- Allowed for pybamm functions exp, sin, cos, sqrt to be used in expression trees that
are converted to casadi format ([#1067](https://github.com/pybamm-team/PyBaMM/pull/1067)
- Fix a bug where variables that depend on y and z were transposed in `QuickPlot` ([#1055](https://github.com/pybamm-team/PyBaMM/pull/1055))

## Breaking changes
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/compare_lithium_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm

pybamm.set_logging_level("INFO")
# pybamm.set_logging_level("INFO")

# load models
models = [
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def entries_string(self, value):
if issparse(entries):
self._entries_string = str(entries.__dict__)
else:
self._entries_string = entries.tostring()
self._entries_string = entries.tobytes()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def entries_string(self, value):
self._entries_string = value
else:
entries = self.data
self._entries_string = entries.tostring()
self._entries_string = entries.tobytes()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
Expand Down
22 changes: 22 additions & 0 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.mmax(*converted_children)
elif symbol.function == np.abs:
return casadi.fabs(*converted_children)
elif symbol.function == np.sqrt:
return casadi.sqrt(*converted_children)
elif symbol.function == np.sin:
return casadi.sin(*converted_children)
elif symbol.function == np.arcsinh:
return casadi.arcsinh(*converted_children)
elif symbol.function == np.arccosh:
return casadi.arccosh(*converted_children)
elif symbol.function == np.tanh:
return casadi.tanh(*converted_children)
elif symbol.function == np.cosh:
return casadi.cosh(*converted_children)
elif symbol.function == np.sinh:
return casadi.sinh(*converted_children)
elif symbol.function == np.cos:
return casadi.cos(*converted_children)
elif symbol.function == np.exp:
return casadi.exp(*converted_children)
elif symbol.function == np.log:
return casadi.log(*converted_children)
elif symbol.function == np.sign:
return casadi.sign(*converted_children)
elif isinstance(symbol.function, (PchipInterpolator, CubicSpline)):
return casadi.interpolant("LUT", "bspline", [symbol.x], symbol.y)(
*converted_children
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def shape(self):
# Default behaviour is to try to evaluate the object directly
# Try with some large y, to avoid having to unpack (slow)
try:
y = np.linspace(0.1, 0.9, int(1e4))
y = np.nan * np.ones((1000, 1))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
# If that fails, fall back to calculating how big y should really be
except ValueError:
Expand All @@ -753,7 +753,7 @@ def shape(self):
len(x._evaluation_array) for x in state_vectors_in_node
)
# Pick a y that won't cause RuntimeWarnings
y = np.linspace(0.1, 0.9, min_y_size)
y = np.nan * np.ones((min_y_size, 1))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")

# Return shape of evaluated object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def test_convert_scalar_symbols(self):
self.assertEqual(abs(c).to_casadi(), casadi.MX(1))

# function
def sin(x):
return np.sin(x)
def square_plus_one(x):
return x ** 2 + 1

f = pybamm.Function(sin, b)
self.assertEqual(f.to_casadi(), casadi.MX(np.sin(1)))
f = pybamm.Function(square_plus_one, b)
self.assertEqual(f.to_casadi(), 2)

def myfunction(x, y):
return x + y
Expand Down Expand Up @@ -95,6 +95,12 @@ def test_special_functions(self):
self.assert_casadi_equal(
pybamm.Function(np.abs, c).to_casadi(), casadi.MX(3), evalf=True
)
for np_fun in [np.sqrt, np.tanh, np.cosh, np.sinh,
np.exp, np.log, np.sign, np.sin, np.cos,
np.arccosh, np.arcsinh]:
self.assert_casadi_equal(
pybamm.Function(np_fun, c).to_casadi(), casadi.MX(np_fun(3)), evalf=True
)

def test_interpolation(self):
x = np.linspace(0, 1)[:, np.newaxis]
Expand Down