Skip to content

Commit

Permalink
#1477 check sensitivities with fd in integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Aug 2, 2021
1 parent 03528da commit 06984c4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
33 changes: 24 additions & 9 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,18 +608,33 @@ def jacp(*args, **kwargs):

# if we have changed the equations to include the explicit sensitivity
# equations, then we also need to update the mass matrix
n_inputs = model.len_rhs_sens // model.len_rhs
n_state_without_sens = model.len_rhs_and_alg
if calculate_sensitivities_explicit:
n_inputs = model.len_rhs_sens // model.len_rhs
model.mass_matrix_inv = pybamm.Matrix(
block_diag(
[model.mass_matrix_inv.entries] * (n_inputs + 1), format="csr"
if model.mass_matrix.shape[0] == n_state_without_sens:
model.mass_matrix_inv = pybamm.Matrix(
block_diag(
[model.mass_matrix_inv.entries] * (n_inputs + 1),
format="csr"
)
)
)
model.mass_matrix = pybamm.Matrix(
block_diag(
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
model.mass_matrix = pybamm.Matrix(
block_diag(
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
)
)
else:
# take care if calculate_sensitivites used then not used
n_state_with_sens = model.len_rhs_and_alg * (n_inputs + 1)
if model.mass_matrix.shape[0] == n_state_with_sens:
model.mass_matrix_inv = pybamm.Matrix(
model.mass_matrix_inv.entries[:n_state_without_sens,
:n_state_without_sens]
)
model.mass_matrix = pybamm.Matrix(
model.mass_matrix.entries[:n_state_without_sens,
:n_state_without_sens]
)
)

# Save CasADi functions for the CasADi solver
# Note: when we pass to casadi the ode part of the problem must be in
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,25 @@ def test_sensitivities(self):

self.test_processing_parameters()
self.test_processing_disc()

self.test_solving(inputs=inputs, calculate_sensitivities=True)

# check via finite differencing
h = 1e-6
inputs_plus = {param_name: neg_electrode_cond + 0.5 * h}
inputs_neg = {param_name: neg_electrode_cond - 0.5 * h}
sol_plus = self.solver.solve(
self.model, self.solution.all_ts[0], inputs=inputs_plus
)
sol_neg = self.solver.solve(
self.model, self.solution.all_ts[0], inputs=inputs_neg
)
n = self.solution.sensitivities[param_name].shape[0]
np.testing.assert_array_almost_equal(
self.solution.sensitivities[param_name],
((sol_plus.y - sol_neg.y) / h).reshape((n, 1))
)

if (
isinstance(
self.model, (pybamm.lithium_ion.BaseModel, pybamm.lead_acid.BaseModel)
Expand Down

0 comments on commit 06984c4

Please sign in to comment.