Skip to content

Commit

Permalink
Merge pull request #700 from pybamm-team/issue-697-abs-tol-vec
Browse files Browse the repository at this point in the history
Issue 697 abs tol vec
  • Loading branch information
Scottmar93 authored Nov 4, 2019
2 parents e8d2b41 + 6a759f1 commit cb106df
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 10 deletions.
14 changes: 5 additions & 9 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
residual_type res, jacobian_type jac, jac_get_type gjd,
jac_get_type gjrv, jac_get_type gjcp, int nnz, event_type event,
int number_of_events, int use_jacobian, np_array rhs_alg_id,
double abs_tol, double rel_tol)
np_array atol_np, double rel_tol)
{
auto t = t_np.unchecked<1>();
auto y0 = y0_np.unchecked<1>();
auto yp0 = yp0_np.unchecked<1>();
auto atol = atol_np.unchecked<1>();

int number_of_states;
number_of_states = y0_np.request().size;
Expand All @@ -240,11 +241,13 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
// set initial value
yval = N_VGetArrayPointer(yy);
ypval = N_VGetArrayPointer(yp);
atval = N_VGetArrayPointer(avtol);
int i;
for (i = 0; i < number_of_states; i++)
{
yval[i] = y0[i];
ypval[i] = yp0[i];
atval[i] = atol[i];
}

// allocate memory for solver
Expand All @@ -256,13 +259,6 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,

// set tolerances
rtol = RCONST(rel_tol);
atval = N_VGetArrayPointer(avtol);

for (i = 0; i < number_of_states; i++)
{
atval[i] =
RCONST(abs_tol); // nb: this can be set differently for each state
}

IDASVtolerances(ida_mem, rtol, avtol);

Expand Down Expand Up @@ -369,7 +365,7 @@ PYBIND11_MODULE(idaklu, m)
py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("get_jac_data"),
py::arg("get_jac_row_vals"), py::arg("get_jac_col_ptr"), py::arg("nnz"),
py::arg("events"), py::arg("number_of_events"), py::arg("use_jacobian"),
py::arg("rhs_alg_id"), py::arg("rtol"), py::arg("atol"),
py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"),
py::return_value_policy::take_ownership);

py::class_<Solution>(m, "solution")
Expand Down
82 changes: 81 additions & 1 deletion pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,85 @@ def __init__(
super().__init__("ida", rtol, atol, root_method, root_tol, max_steps)
self.name = "IDA KLU solver"

def set_atol_by_variable(self, variables_with_tols, model):
"""
A method to set the absolute tolerances in the solver by state variable.
This method modifies self._atol.
Parameters
----------
variables_with_tols : dict
A dictionary with keys that are strings indicating the variable you
wish to set the tolerance of and values that are the tolerances.
model : :class:`pybamm.BaseModel`
The model that is going to be solved.
"""

size = model.concatenated_initial_conditions.size
self._check_atol_type(size)
for var, tol in variables_with_tols.items():
variable = model.variables[var]
if isinstance(variable, pybamm.StateVector):
self.set_state_vec_tol(variable, tol)
elif isinstance(variable, pybamm.Concatenation):
for child in variable.children:
if isinstance(child, pybamm.StateVector):
self.set_state_vec_tol(child, tol)
else:
raise pybamm.SolverError(
"""Can only set tolerances for state variables
or concatenations of state variables"""
)
else:
raise pybamm.SolverError(
"""Can only set tolerances for state variables or
concatenations of state variables"""
)

def set_state_vec_tol(self, state_vec, tol):
"""
A method to set the tolerances in the atol vector of a specific
state variable. This method modifies self._atol
Parameters
----------
state_vec : :class:`pybamm.StateVector`
The state vector to apply to the tolerance to
tol: float
The tolerance value
"""
slices = state_vec.y_slices[0]
self._atol[slices] = tol

def _check_atol_type(self, size):
"""
This method checks that the atol vector is of the right shape and
type.
Parameters
----------
size: int
The length of the atol vector
"""

if isinstance(self._atol, float):
self._atol = self._atol * np.ones(size)
elif isinstance(self._atol, list):
self._atol = np.array(self._atol)
elif isinstance(self._atol, np.ndarray):
pass
else:
raise pybamm.SolverError(
"Absolute tolerances must be a numpy array, float, or list"
)

if self._atol.size != size:
raise pybamm.SolverError(
"""Absolute tolerances must be either a scalar or a numpy arrray
of the same shape at y0"""
)

def integrate(self, residuals, y0, t_eval, events, mass_matrix, jacobian):
"""
Solve a DAE model defined by residuals with initial conditions y0.
Expand Down Expand Up @@ -76,6 +155,7 @@ def integrate(self, residuals, y0, t_eval, events, mass_matrix, jacobian):
pybamm.SolverError("KLU requires events to be provided")

rtol = self._rtol
self._check_atol_type(y0.size)
atol = self._atol

if jacobian:
Expand Down Expand Up @@ -149,8 +229,8 @@ def rootfn(t, y):
num_of_events,
use_jac,
ids,
rtol,
atol,
rtol,
)

t = sol.t
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/test_solvers/test_idaklu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ def test_on_spme(self):
solution = pybamm.IDAKLUSolver().solve(model, t_eval)
np.testing.assert_array_less(1, solution.t.size)

def test_set_tol_by_variable(self):
model = pybamm.lithium_ion.SPMe()
geometry = model.default_geometry
param = model.default_parameter_values
param.process_model(model)
param.process_geometry(geometry)
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
t_eval = np.linspace(0, 0.2, 100)
solver = pybamm.IDAKLUSolver()

variable_tols = {"Electrolyte concentration": 1e-3}
solver.set_atol_by_variable(variable_tols, model)

solver.solve(model, t_eval)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ def alg(t, y):
true_solution = 0.1 * solution.t
np.testing.assert_array_almost_equal(solution.y[0, :], true_solution)

def test_set_atol(self):
model = pybamm.lithium_ion.SPMe()
geometry = model.default_geometry
param = model.default_parameter_values
param.process_model(model)
param.process_geometry(geometry)
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
solver = pybamm.IDAKLUSolver()

variable_tols = {"Electrolyte concentration": 1e-3}
solver.set_atol_by_variable(variable_tols, model)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit cb106df

Please sign in to comment.