Skip to content

Commit

Permalink
#1477 fix bug in jax evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Aug 18, 2021
1 parent df0ff95 commit 2cb99ad
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,10 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
result = result.reshape(result.shape[0], -1)

return result
if known_evals is not None:
return result, known_evals
else:
return result


class EvaluatorJaxSensitivities:
Expand All @@ -704,4 +707,7 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
# execute code
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)

return result
if known_evals is not None:
return result, known_evals
else:
return result

0 comments on commit 2cb99ad

Please sign in to comment.