Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1343 update get_termination_reason to return solution #1344

Merged
merged 4 commits into from
Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

## Bug fixes

- Fixed a bug where the event time and state were no longer returned as part of the solution ([#1344](https://github.com/pybamm-team/PyBaMM/pull/1344))
- Fixed a bug in `CasadiSolver` safe mode which crashed when there were extrapolation events but no termination events ([#1321](https://github.com/pybamm-team/PyBaMM/pull/1321))
- When an `Interpolant` is extrapolated an error is raised for `CasadiSolver` (and a warning is raised for the other solvers) ([#1315](https://github.com/pybamm-team/PyBaMM/pull/1315))
- Fixed `Simulation` and `model.new_copy` to fix a bug where changes to the model were overwritten ([#1278](https://github.com/pybamm-team/PyBaMM/pull/1278))
Expand Down
114 changes: 71 additions & 43 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ def solve(
t_eval = np.array([0])
else:
raise ValueError("t_eval cannot be None")

# If t_eval is provided as [t0, tf] return the solution at 100 points
elif isinstance(t_eval, list):
if len(t_eval) == 1 and self.algebraic_solver is True:
Expand Down Expand Up @@ -611,10 +612,8 @@ def solve(
'when model in format "jax".'
)

# Set up
timer = pybamm.Timer()

# Set up (if not done already)
timer = pybamm.Timer()
if model not in self.models_set_up:
# It is assumed that when len(inputs_list) > 1, model set
# up (initial condition, time-scale and length-scale) does
Expand Down Expand Up @@ -723,9 +722,9 @@ def solve(
)
end_indices.append(len(t_eval_dimensionless))

# integrate separately over each time segment and accumulate into the solution
# Integrate separately over each time segment and accumulate into the solution
# object, restarting the solver at each discontinuity (and recalculating a
# consistent state afterwards if a dae)
# consistent state afterwards if a DAE)
old_y0 = model.y0
solutions = None
for start_index, end_index in zip(start_indices, end_indices):
Expand Down Expand Up @@ -756,8 +755,7 @@ def solve(
p.close()
p.join()
# Setting the solve time for each segment.
# pybamm.Solution.append assumes attribute
# solve_time.
# pybamm.Solution.__add__ assumes attribute solve_time.
solve_time = timer.time()
for sol in new_solutions:
sol.solve_time = solve_time
Expand All @@ -780,41 +778,59 @@ def solve(
model.y0 = self.calculate_consistent_state(
model, t_eval_dimensionless[end_index], ext_and_inputs_list[0]
)

solve_time = timer.time()
for i, solution in enumerate(solutions):
# Assign times
solution.set_up_time = set_up_time
solution.solve_time = solve_time

# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
if extrapolation:
warnings.warn(
"While solving {} extrapolation occurred for {}".format(
model.name, extrapolation
),
pybamm.SolverWarning,
for i, solution in enumerate(solutions):
# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
if extrapolation:
warnings.warn(
"While solving {} extrapolation occurred for {}".format(
model.name, extrapolation
),
pybamm.SolverWarning,
)
# Identify the event that caused termination and update the solution to
# include the event time and state
solutions[i], termination = self.get_termination_reason(
solution, model.events
)
# Assign times
solutions[i].set_up_time = set_up_time
# all solutions get the same solve time, but their integration time
# will be different (see https://github.com/pybamm-team/PyBaMM/pull/1261)
solutions[i].solve_time = solve_time

# Identify the event that caused termination
termination = self.get_termination_reason(solutions[0], model.events)

# restore old y0
# Restore old y0
model.y0 = old_y0

pybamm.logger.info("Finish solving {} ({})".format(model.name, termination))
pybamm.logger.info(
(
"Set-up time: {}, Solve time: {} (of which integration time: {}), "
"Total time: {}"
).format(
solutions[0].set_up_time,
solutions[0].solve_time,
solutions[0].integration_time,
solutions[0].total_time,
# Report times
if len(solutions) == 1:
pybamm.logger.info("Finish solving {} ({})".format(model.name, termination))
pybamm.logger.info(
(
"Set-up time: {}, Solve time: {} (of which integration time: {}), "
"Total time: {}"
).format(
solutions[0].set_up_time,
solutions[0].solve_time,
solutions[0].integration_time,
solutions[0].total_time,
)
)
else:
pybamm.logger.info("Finish solving {} for all inputs".format(model.name))
pybamm.logger.info(
(
"Set-up time: {}, Solve time: {} (of which integration time: {}), "
"Total time: {}"
).format(
solutions[0].set_up_time,
sum([sol.solve_time for sol in solutions]),
sum([sol.integration_time for sol in solutions]),
sum([sol.total_time for sol in solutions]),
)
)
)

# Raise error if solutions[0] only contains one timestep (except for algebraic
# solvers, where we may only expect one time in the solution)
Expand All @@ -828,6 +844,7 @@ def solve(
"Check whether simulation terminated too early."
)

# Return solution(s)
if ninputs == 1:
return solutions[0]
else:
Expand Down Expand Up @@ -925,6 +942,7 @@ def step(
"parameter and the value has changed between "
"steps!".format(domain)
)

# Run set up on first step
if old_solution is None:
pybamm.logger.verbose(
Expand Down Expand Up @@ -954,9 +972,6 @@ def step(
)
timer.reset()
solution = self._integrate(model, t_eval, ext_and_inputs)

# Assign times
solution.set_up_time = set_up_time
solution.solve_time = timer.time()

# Check if extrapolation occurred
Expand All @@ -969,9 +984,14 @@ def step(
pybamm.SolverWarning,
)

# Identify the event that caused termination
termination = self.get_termination_reason(solution, model.events)
# Identify the event that caused termination and update the solution to
# include the event time and state
solution, termination = self.get_termination_reason(solution, model.events)

# Assign setup time
solution.set_up_time = set_up_time

# Report times
pybamm.logger.verbose("Finish stepping {} ({})".format(model.name, termination))
pybamm.logger.verbose(
(
Expand All @@ -984,6 +1004,8 @@ def step(
solution.total_time,
)
)

# Return solution
if save is False or old_solution is None:
return solution
else:
Expand All @@ -992,7 +1014,8 @@ def step(
def get_termination_reason(self, solution, events):
"""
Identify the cause for termination. In particular, if the solver terminated
due to an event, (try to) pinpoint which event was responsible.
due to an event, (try to) pinpoint which event was responsible. If an event
occurs the event time and state are added to the solution object.
Note that the current approach (evaluating all the events and then finding which
one is smallest at the final timestep) is pretty crude, but is the easiest one
that works for all the different solvers.
Expand All @@ -1005,7 +1028,10 @@ def get_termination_reason(self, solution, events):
Dictionary of events
"""
if solution.termination == "final time":
return "the solver successfully reached the end of the integration interval"
return (
solution,
"the solver successfully reached the end of the integration interval",
)
elif solution.termination == "event":
# Get final event value
final_event_values = {}
Expand Down Expand Up @@ -1039,7 +1065,9 @@ def get_termination_reason(self, solution, events):
event_sol.integration_time = 0
solution = solution + event_sol

return solution.termination
return solution, solution.termination
elif solution.termination == "success":
return solution, solution.termination

def check_extrapolation(self, solution, events):
"""
Expand Down
14 changes: 10 additions & 4 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def test_model_solver_events(self):
solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_array_less(solution.y.full()[0], 1.5)
np.testing.assert_array_less(solution.y.full()[-1], 2.5 + 1e-10)
np.testing.assert_array_less(solution.y.full()[0, :-1], 1.5)
np.testing.assert_array_less(solution.y.full()[-1, :-1], 2.5)
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y.full()[:, -1])
np.testing.assert_array_almost_equal(
solution.y.full()[0], np.exp(0.1 * solution.t), decimal=5
)
Expand Down Expand Up @@ -277,8 +279,12 @@ def test_model_step_events(self):
while time < end_time:
step_solution = step_solver.step(step_solution, model, dt=dt, npts=10)
time += dt
np.testing.assert_array_less(step_solution.y.full()[0], 1.5)
np.testing.assert_array_less(step_solution.y.full()[-1], 2.5001)
np.testing.assert_array_less(step_solution.y.full()[0, :-1], 1.5)
np.testing.assert_array_less(step_solution.y.full()[-1, :-1], 2.5)
np.testing.assert_equal(step_solution.t_event[0], step_solution.t[-1])
np.testing.assert_array_equal(
step_solution.y_event[:, 0], step_solution.y.full()[:, -1]
)
np.testing.assert_array_almost_equal(
step_solution.y.full()[0], np.exp(0.1 * step_solution.t), decimal=5
)
Expand Down
61 changes: 41 additions & 20 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,10 @@ def test_model_solver_ode_events_python(self):
t_eval = np.linspace(0, 10, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[0], 1.25 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 1.5)
np.testing.assert_array_less(solution.y[0, :-1], 1.25)
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])

def test_model_solver_ode_jacobian_python(self):
model = pybamm.BaseModel()
Expand Down Expand Up @@ -251,10 +253,12 @@ def test_model_solver_dae_events_python(self):
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[-1], 2.5 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 1.5)
np.testing.assert_array_less(solution.y[-1, :-1], 2.5)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t))
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])

def test_model_solver_dae_nonsmooth_python(self):
model = pybamm.BaseModel()
Expand Down Expand Up @@ -323,8 +327,8 @@ def nonsmooth_mult(t):

# check solution
for solution in [solution1, solution2]:
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[-1], 2.5 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 1.5)
np.testing.assert_array_less(solution.y[-1, :-1], 2.5)
var1_soln = np.exp(0.2 * solution.t)
y0 = np.exp(0.2 * discontinuity)
var1_soln[solution.t > discontinuity] = y0 * np.exp(
Expand Down Expand Up @@ -390,8 +394,8 @@ def test_model_solver_dae_multiple_nonsmooth_python(self):

# check solution
for solution in [solution1, solution2]:
np.testing.assert_array_less(solution.y[0], 0.55 + 1e-6)
np.testing.assert_array_less(solution.y[-1], 1.2 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 0.55)
np.testing.assert_array_less(solution.y[-1, :-1], 1.2)
var1_soln = (solution.t % a) ** 2 / 2 + a ** 2 / 2 * (solution.t // a)
var2_soln = 2 * var1_soln
np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06)
Expand Down Expand Up @@ -571,8 +575,10 @@ def test_model_solver_ode_events_casadi(self):
t_eval = np.linspace(0, 10, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[0], 1.25 + 1e-6)
np.testing.assert_array_less(solution.y[0:, -1], 1.5)
np.testing.assert_array_less(solution.y[0:, -1], 1.25 + 1e-6)
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])

def test_model_solver_dae_events_casadi(self):
# Create model
Expand All @@ -597,8 +603,10 @@ def test_model_solver_dae_events_casadi(self):
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model_disc, t_eval)
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[-1], 2.5 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 1.5)
np.testing.assert_array_less(solution.y[-1, :-1], 2.5)
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t))

Expand Down Expand Up @@ -627,8 +635,11 @@ def test_model_solver_dae_inputs_events(self):
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model, t_eval, inputs={"rate 1": 0.1, "rate 2": 2})
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[-1], 2.5 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 1.5)
np.testing.assert_array_less(solution.y[-1, :-1], 2.5)
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])
np.testing.assert_equal(solution.t_event[0], solution.t[-1])

np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t))

Expand Down Expand Up @@ -732,8 +743,12 @@ def test_model_step_events(self):
while time < end_time:
step_solution = step_solver.step(step_solution, model, dt=dt, npts=10)
time += dt
np.testing.assert_array_less(step_solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(step_solution.y[-1], 2.5 + 1e-6)
np.testing.assert_array_less(step_solution.y[0, :-1], 1.5)
np.testing.assert_array_less(step_solution.y[-1, :-1], 2.5)
np.testing.assert_equal(step_solution.t_event[0], step_solution.t[-1])
np.testing.assert_array_equal(
step_solution.y_event[:, 0], step_solution.y[:, -1]
)
np.testing.assert_array_almost_equal(
step_solution.y[0], np.exp(0.1 * step_solution.t), decimal=5
)
Expand Down Expand Up @@ -773,8 +788,12 @@ def test_model_step_nonsmooth_events(self):
while time < end_time:
step_solution = step_solver.step(step_solution, model, dt=dt, npts=10)
time += dt
np.testing.assert_array_less(step_solution.y[0], 0.55 + 1e-6)
np.testing.assert_array_less(step_solution.y[-1], 1.2 + 1e-6)
np.testing.assert_array_less(step_solution.y[0, :-1], 0.55)
np.testing.assert_array_less(step_solution.y[-1, :-1], 1.2)
np.testing.assert_equal(step_solution.t_event[0], step_solution.t[-1])
np.testing.assert_array_equal(
step_solution.y_event[:, 0], step_solution.y[:, -1]
)
var1_soln = (step_solution.t % a) ** 2 / 2 + a ** 2 / 2 * (step_solution.t // a)
var2_soln = 2 * var1_soln
np.testing.assert_array_almost_equal(step_solution.y[0], var1_soln, decimal=5)
Expand Down Expand Up @@ -856,8 +875,10 @@ def nonsmooth_rate(t):

# check solution
for solution in [solution1, solution2]:
np.testing.assert_array_less(solution.y[0], 1.5 + 1e-6)
np.testing.assert_array_less(solution.y[-1], 2.5 + 1e-6)
np.testing.assert_array_less(solution.y[0, :-1], 1.5)
np.testing.assert_array_less(solution.y[-1, :-1], 2.5)
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])
var1_soln = np.exp(0.2 * solution.t)
y0 = np.exp(0.2 * discontinuity)
var1_soln[solution.t > discontinuity] = y0 * np.exp(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def test_model_solver_with_event_python(self):
self.assertLess(len(solution.t), len(t_eval))
np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1])
np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t))
np.testing.assert_equal(solution.t_event[0], solution.t[-1])
np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1])

def test_model_solver_ode_with_jacobian_python(self):
# Create model
Expand Down