Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add perturbation to initial conditions in casadi solver #2356

Merged
merged 10 commits into from
Oct 13, 2022
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

## Bug fixes

- For simulations with events that cause the simulation to stop early, the sensitivities could be evaluated incorrectly to zero ([#2331](https://github.com/pybamm-team/PyBaMM/pull/2337))
- For simulations with events that cause the simulation to stop early, the sensitivities could be evaluated incorrectly to zero ([#2337](https://github.com/pybamm-team/PyBaMM/pull/2337))

## Optimizations

- Added small perturbation to initial conditions for casadi solver. This seems to help the solver converge better in some cases ([#2356](https://github.com/pybamm-team/PyBaMM/pull/2356))
- Added `ExplicitTimeIntegral` functionality to move variables which do not appear anywhere on the rhs to a new location, and to integrate those variables explicitly when `get` is called by the solution object. ([#2348](https://github.com/pybamm-team/PyBaMM/pull/2348))
- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211))

- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210))
- Added `ExplicitTimeIntegral` functionality to move variables which do not appear anywhere on the rhs to a new location, and to integrate those variables explicitly when `get` is called by the solution object. ([#2348](https://github.com/pybamm-team/PyBaMM/pull/2348))

## Breaking change

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
[![black code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)

<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->

[![All Contributors](https://img.shields.io/badge/all_contributors-46-orange.svg)](#-contributors)

<!-- ALL-CONTRIBUTORS-BADGE:END -->

</div>
Expand Down
34 changes: 30 additions & 4 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class CasadiSolver(pybamm.BaseSolver):
return_solution_if_failed_early : bool, optional
Whether to return a Solution object if the solver fails to reach the end of
the simulation, but managed to take some successful steps. Default is False.
perturb_algebraic_initial_conditions : bool, optional
Whether to perturb algebraic initial conditions to avoid a singularity. This
can sometimes slow down the solver, but is kept True as default for "safe" mode
as it seems to be more robust (False by default for other modes).
"""

def __init__(
Expand All @@ -81,6 +85,7 @@ def __init__(
extra_options_setup=None,
extra_options_call=None,
return_solution_if_failed_early=False,
perturb_algebraic_initial_conditions=None,
):
super().__init__(
"problem dependent",
Expand All @@ -106,6 +111,17 @@ def __init__(
self.extrap_tol = extrap_tol
self.return_solution_if_failed_early = return_solution_if_failed_early

# Decide whether to perturb algebraic initial conditions, True by default for
# "safe" mode, False by default for other modes
if perturb_algebraic_initial_conditions is None:
if mode == "safe":
self.perturb_algebraic_initial_conditions = True
else:
self.perturb_algebraic_initial_conditions = False
else:
self.perturb_algebraic_initial_conditions = (
perturb_algebraic_initial_conditions
)
self.name = "CasADi solver with '{}' mode".format(mode)

# Initialize
Expand Down Expand Up @@ -658,14 +674,22 @@ def _run_integrator(
integrator = self.integrators[model]["no grid"]

len_rhs = model.concatenated_rhs.size
len_alg = model.concatenated_algebraic.size

# Check y0 to see if it includes sensitivities
if explicit_sensitivities:
num_parameters = model.len_rhs_sens // model.len_rhs
len_rhs = len_rhs * (num_parameters + 1)
len_alg = len_alg * (num_parameters + 1)

y0_diff = y0[:len_rhs]
y0_alg = y0[len_rhs:]
if self.perturb_algebraic_initial_conditions and len_alg > 0:
# Add a tiny perturbation to the algebraic initial conditions
# For some reason this helps with convergence
# The actual value of the initial conditions for the algebraic variables
# doesn't matter
y0_alg = y0_alg * (1 + 1e-6 * casadi.DM(np.random.rand(len_alg)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to allow users to fiddle with this value or is it very sensitive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not that sensitive. It's just the first thing I tried and it worked for the specific case I was testing. We can see if other cases come up whether changing that number makes a difference

pybamm.logger.spam("Finished preliminary setup for integrator run")

# Solve
Expand All @@ -680,9 +704,10 @@ def _run_integrator(
casadi_sol = integrator(
x0=y0_diff, z0=y0_alg, p=inputs_with_tmin, **self.extra_options_call
)
except RuntimeError as e:
except RuntimeError as error:
# If it doesn't work raise error
raise pybamm.SolverError(e.args[0])
pybamm.logger.debug(f"Casadi integrator failed with error {error}")
raise pybamm.SolverError(error.args[0])
pybamm.logger.debug("Finished casadi integrator")
integration_time = timer.time()
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
Expand Down Expand Up @@ -711,9 +736,10 @@ def _run_integrator(
casadi_sol = integrator(
x0=x, z0=z, p=inputs_with_tlims, **self.extra_options_call
)
except RuntimeError as e:
except RuntimeError as error:
# If it doesn't work raise error
raise pybamm.SolverError(e.args[0])
pybamm.logger.debug(f"Casadi integrator failed with error {error}")
raise pybamm.SolverError(error.args[0])
integration_time = timer.time()
x = casadi_sol["xf"]
z = casadi_sol["zf"]
Expand Down
5 changes: 0 additions & 5 deletions tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ def test_solving(
self.solver.rtol = 1e-8
self.solver.atol = 1e-8

# Somehow removing an equation makes the solver fail at
# the low tolerances
if isinstance(self.model, pybamm.lithium_ion.NewmanTobias):
self.solver.rtol = 1e-7

Crate = abs(
self.parameter_values["Current function [A]"]
/ self.parameter_values["Nominal cell capacity [A.h]"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,11 @@ def test_run_experiment_skip_steps(self):
model, parameter_values=parameter_values, experiment=experiment2
)
sol2 = sim2.solve()
np.testing.assert_array_equal(
np.testing.assert_array_almost_equal(
sol["Terminal voltage [V]"].data, sol2["Terminal voltage [V]"].data
)
for idx1, idx2 in [(1, 0), (2, 1), (4, 2)]:
np.testing.assert_array_equal(
np.testing.assert_array_almost_equal(
sol.cycles[0].steps[idx1]["Terminal voltage [V]"].data,
sol2.cycles[0].steps[idx2]["Terminal voltage [V]"].data,
)
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def test_model_solver(self):
disc = pybamm.Discretisation()
model_disc = disc.process_model(model, inplace=False)
# Solve
solver = pybamm.CasadiSolver(mode="fast", rtol=1e-8, atol=1e-8)
solver = pybamm.CasadiSolver(
mode="fast",
rtol=1e-8,
atol=1e-8,
perturb_algebraic_initial_conditions=False, # added for coverage
)
t_eval = np.linspace(0, 1, 100)
solution = solver.solve(model_disc, t_eval)
np.testing.assert_array_equal(solution.t, t_eval)
Expand Down