From 06984c4279510e285dbf4f33505c70f95fd3363e Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Mon, 2 Aug 2021 18:04:32 +0100 Subject: [PATCH] #1477 check sensitivities with fd in integration tests --- pybamm/solvers/base_solver.py | 33 ++++++++++++++----- .../test_models/standard_model_tests.py | 17 ++++++++++ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index deccecbf10..576a397c58 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -608,18 +608,33 @@ def jacp(*args, **kwargs): # if we have changed the equations to include the explicit sensitivity # equations, then we also need to update the mass matrix + n_inputs = model.len_rhs_sens // model.len_rhs + n_state_without_sens = model.len_rhs_and_alg if calculate_sensitivities_explicit: - n_inputs = model.len_rhs_sens // model.len_rhs - model.mass_matrix_inv = pybamm.Matrix( - block_diag( - [model.mass_matrix_inv.entries] * (n_inputs + 1), format="csr" + if model.mass_matrix.shape[0] == n_state_without_sens: + model.mass_matrix_inv = pybamm.Matrix( + block_diag( + [model.mass_matrix_inv.entries] * (n_inputs + 1), + format="csr" + ) ) - ) - model.mass_matrix = pybamm.Matrix( - block_diag( - [model.mass_matrix.entries] * (n_inputs + 1), format="csr" + model.mass_matrix = pybamm.Matrix( + block_diag( + [model.mass_matrix.entries] * (n_inputs + 1), format="csr" + ) + ) + else: + # take care if calculate_sensitivites used then not used + n_state_with_sens = model.len_rhs_and_alg * (n_inputs + 1) + if model.mass_matrix.shape[0] == n_state_with_sens: + model.mass_matrix_inv = pybamm.Matrix( + model.mass_matrix_inv.entries[:n_state_without_sens, + :n_state_without_sens] + ) + model.mass_matrix = pybamm.Matrix( + model.mass_matrix.entries[:n_state_without_sens, + :n_state_without_sens] ) - ) # Save CasADi functions for the CasADi solver # Note: when we pass to casadi the ode part of the problem must be in diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index c61b068495..839e60f1d0 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -101,8 +101,25 @@ def test_sensitivities(self): self.test_processing_parameters() self.test_processing_disc() + self.test_solving(inputs=inputs, calculate_sensitivities=True) + # check via finite differencing + h = 1e-6 + inputs_plus = {param_name: neg_electrode_cond + 0.5 * h} + inputs_neg = {param_name: neg_electrode_cond - 0.5 * h} + sol_plus = self.solver.solve( + self.model, self.solution.all_ts[0], inputs=inputs_plus + ) + sol_neg = self.solver.solve( + self.model, self.solution.all_ts[0], inputs=inputs_neg + ) + n = self.solution.sensitivities[param_name].shape[0] + np.testing.assert_array_almost_equal( + self.solution.sensitivities[param_name], + ((sol_plus.y - sol_neg.y) / h).reshape((n, 1)) + ) + if ( isinstance( self.model, (pybamm.lithium_ion.BaseModel, pybamm.lead_acid.BaseModel)