Skip to content

Commit

Permalink
#632 fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jun 14, 2020
1 parent 451f766 commit 6a0376c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
18 changes: 9 additions & 9 deletions examples/notebooks/models/compare-lithium-ion.ipynb

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions examples/notebooks/models/lead-acid.ipynb

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions examples/scripts/compare_lead_acid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

# load models
models = [
pybamm.lead_acid.LOQS(),
pybamm.lead_acid.FOQS(),
pybamm.lead_acid.CompositeExtended(),
pybamm.lead_acid.Full(),
# pybamm.lead_acid.LOQS(),
# pybamm.lead_acid.FOQS(),
pybamm.lead_acid.Composite(),
# pybamm.lead_acid.Full(),
]

# load parameter values and process models and geometry
param = models[0].default_parameter_values
param.update({"Current function [A]": 17, "Initial State of Charge": 1})
param.update({"Current function [A]": "[input]", "Initial State of Charge": 1})
for model in models:
param.process_model(model)

Expand All @@ -43,7 +43,9 @@
solutions = [None] * len(models)
t_eval = np.linspace(0, 3600 * 2, 1000)
for i, model in enumerate(models):
solution = model.default_solver.solve(model, t_eval)
solution = model.default_solver.solve(
model, t_eval, inputs={"Current function [A]": 1}
)
solutions[i] = solution

# plot
Expand Down
15 changes: 10 additions & 5 deletions pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,20 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None):
raise KeyError("Input parameter '{}' not found".format(self.name))

if isinstance(input_eval, numbers.Number):
input_shape = 1
input_size = 1
input_ndim = 0
else:
input_shape = input_eval.shape[0]
if input_shape == self._expected_size:
return input_eval
input_size = input_eval.shape[0]
input_ndim = len(input_eval.shape)
if input_size == self._expected_size:
if input_ndim == 1:
return input_eval[:, np.newaxis]
else:
return input_eval
else:
raise ValueError(
"Input parameter '{}' was given an object of size '{}'".format(
self.name, input_shape
self.name, input_size
)
+ " but was expecting an object of size '{}'.".format(
self._expected_size
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_set_expected_size(self):
self.assertEqual(a._expected_size, 10)
np.testing.assert_array_equal(a.evaluate(inputs="shape test"), np.ones((10, 1)))
y = np.linspace(0, 1, 10)
np.testing.assert_array_equal(a.evaluate(inputs={"a": y}), y)
np.testing.assert_array_equal(a.evaluate(inputs={"a": y}), y[:, np.newaxis])
with self.assertRaisesRegex(
ValueError,
"Input parameter 'a' was given an object of size '1' but was expecting an "
Expand Down

0 comments on commit 6a0376c

Please sign in to comment.