From 35bcb78e6b4f9671b802569a1e93b8d0ce980753 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Fri, 13 Sep 2024 16:44:30 +0100 Subject: [PATCH] feat: support sensitivities for pybamm.Simulation and pybamm.Experiment (#4415) * main changes relate to updating the `BaseSolver.step` function to support this * `BaseSolver.step` now can use the input Solution to initialise the sensitivities for the new step --------- Co-authored-by: Eric G. Kratz --- CHANGELOG.md | 3 + src/pybamm/simulation.py | 50 +++- src/pybamm/solvers/base_solver.py | 158 ++++++++++-- .../c_solvers/idaklu/IDAKLUSolverOpenMP.inl | 6 +- src/pybamm/solvers/casadi_algebraic_solver.py | 2 +- src/pybamm/solvers/casadi_solver.py | 8 +- src/pybamm/solvers/idaklu_solver.py | 2 +- src/pybamm/solvers/processed_variable.py | 70 +++--- src/pybamm/solvers/scipy_solver.py | 2 +- src/pybamm/solvers/solution.py | 227 +++++++++++++----- .../test_simulation_with_experiment.py | 77 ++++++ tests/unit/test_simulation.py | 43 ++++ tests/unit/test_solvers/test_casadi_solver.py | 2 +- tests/unit/test_solvers/test_solution.py | 36 ++- 14 files changed, 553 insertions(+), 133 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 701603584c..80cda39db7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) +## Features +- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415)) + ## Optimizations - Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416)) diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py index 0aa85d1c20..da0ac08316 100644 --- a/src/pybamm/simulation.py +++ b/src/pybamm/simulation.py @@ -174,7 +174,7 @@ def _set_random_seed(self): % (2**32) ) - def set_up_and_parameterise_experiment(self): + def set_up_and_parameterise_experiment(self, solve_kwargs=None): """ Create and parameterise the models for each step in the experiment. @@ -182,6 +182,46 @@ def set_up_and_parameterise_experiment(self): reduces simulation time since the model formulation is efficient. """ parameter_values = self._parameter_values.copy() + + # some parameters are used to control the experiment, and should not be + # input parameters + restrict_list = {"Initial temperature [K]", "Ambient temperature [K]"} + for step in self.experiment.steps: + if issubclass(step.__class__, pybamm.experiment.step.BaseStepImplicit): + restrict_list.update(step.get_parameter_values([]).keys()) + elif issubclass(step.__class__, pybamm.experiment.step.BaseStepExplicit): + restrict_list.update(["Current function [A]"]) + for key in restrict_list: + if key in parameter_values.keys() and isinstance( + parameter_values[key], pybamm.InputParameter + ): + raise pybamm.ModelError( + f"Cannot use '{key}' as an input parameter in this experiment. " + f"This experiment is controlled via the following parameters: {restrict_list}. " + f"None of these parameters are able to be input parameters." + ) + + if ( + solve_kwargs is not None + and "calculate_sensitivities" in solve_kwargs + and solve_kwargs["calculate_sensitivities"] + ): + for step in self.experiment.steps: + if any( + [ + isinstance( + term, + pybamm.experiment.step.step_termination.BaseTermination, + ) + for term in step.termination + ] + ): + pybamm.logger.warning( + f"Step '{step}' has a termination condition based on an event. Sensitivity calculation will be inaccurate " + "if the time of each step event changes rapidly with respect to the parameters. " + ) + break + # Set the initial temperature to be the temperature of the first step # We can set this globally for all steps since any subsequent steps will either # start at the temperature at the end of the previous step (if non-isothermal @@ -303,7 +343,7 @@ def build(self, initial_soc=None, inputs=None): # rebuilt model so clear solver setup self._solver._model_set_up = {} - def build_for_experiment(self, initial_soc=None, inputs=None): + def build_for_experiment(self, initial_soc=None, inputs=None, solve_kwargs=None): """ Similar to :meth:`Simulation.build`, but for the case of simulating an experiment, where there may be several models and solvers to build. @@ -314,7 +354,7 @@ def build_for_experiment(self, initial_soc=None, inputs=None): if self.steps_to_built_models: return else: - self.set_up_and_parameterise_experiment() + self.set_up_and_parameterise_experiment(solve_kwargs) # Can process geometry with default parameter values (only electrical # parameters change between parameter values) @@ -497,7 +537,9 @@ def solve( elif self.operating_mode == "with experiment": callbacks.on_experiment_start(logs) - self.build_for_experiment(initial_soc=initial_soc, inputs=inputs) + self.build_for_experiment( + initial_soc=initial_soc, inputs=inputs, solve_kwargs=kwargs + ) if t_eval is not None: pybamm.logger.warning( "Ignoring t_eval as solution times are specified by the experiment" diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 9c0d94f1a9..1df9aef35f 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -670,6 +670,33 @@ def calculate_consistent_state(self, model, time=0, inputs=None): y0 = root_sol.all_ys[0] return y0 + def _solve_process_calculate_sensitivities_arg( + inputs, model, calculate_sensitivities + ): + # get a list-only version of calculate_sensitivities + if isinstance(calculate_sensitivities, bool): + if calculate_sensitivities: + calculate_sensitivities_list = [p for p in inputs.keys()] + else: + calculate_sensitivities_list = [] + else: + calculate_sensitivities_list = calculate_sensitivities + + calculate_sensitivities_list.sort() + if not hasattr(model, "calculate_sensitivities"): + model.calculate_sensitivities = [] + + # Check that calculate_sensitivites have not been updated + sensitivities_have_changed = ( + calculate_sensitivities_list != model.calculate_sensitivities + ) + + # save sensitivity parameters so we can identify them later on + # (FYI: this is used in the Solution class) + model.calculate_sensitivities = calculate_sensitivities_list + + return calculate_sensitivities_list, sensitivities_have_changed + def solve( self, model, @@ -700,7 +727,11 @@ def solve( calculate_sensitivities : list of str or bool, optional Whether the solver calculates sensitivities of all input parameters. Defaults to False. If only a subset of sensitivities are required, can also pass a - list of input parameter names + list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances + so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you + require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities + do not take into account the movement of step-transitions wrt input parameters, so do not use this feature + if the timings of your experimental protocol change rapidly with respect to your input parameters. t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to None. Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). @@ -722,15 +753,6 @@ def solve( """ pybamm.logger.info(f"Start solving {model.name} with {self.name}") - # get a list-only version of calculate_sensitivities - if isinstance(calculate_sensitivities, bool): - if calculate_sensitivities: - calculate_sensitivities_list = [p for p in inputs.keys()] - else: - calculate_sensitivities_list = [] - else: - calculate_sensitivities_list = calculate_sensitivities - # Make sure model isn't empty self._check_empty_model(model) @@ -772,6 +794,12 @@ def solve( self._set_up_model_inputs(model, inputs) for inputs in inputs_list ] + calculate_sensitivities_list, sensitivities_have_changed = ( + BaseSolver._solve_process_calculate_sensitivities_arg( + model_inputs_list[0], model, calculate_sensitivities + ) + ) + # (Re-)calculate consistent initialization # Assuming initial conditions do not depend on input parameters # when len(inputs_list) > 1, only `model_inputs_list[0]` @@ -792,13 +820,8 @@ def solve( "for initial conditions." ) - # Check that calculate_sensitivites have not been updated - calculate_sensitivities_list.sort() - if hasattr(model, "calculate_sensitivities"): - model.calculate_sensitivities.sort() - else: - model.calculate_sensitivities = [] - if calculate_sensitivities_list != model.calculate_sensitivities: + # if any setup configuration has changed, we need to re-set up + if sensitivities_have_changed: self._model_set_up.pop(model, None) # CasadiSolver caches its integrators using model, so delete this too if isinstance(self, pybamm.CasadiSolver): @@ -1066,6 +1089,58 @@ def _check_events_with_initialization(t_eval, model, inputs_dict): f"Events {event_names} are non-positive at initial conditions" ) + def _set_sens_initial_conditions_from( + self, solution: pybamm.Solution, model: pybamm.BaseModel + ) -> tuple: + """ + A restricted version of BaseModel.set_initial_conditions_from that only extracts the + sensitivities from a solution object, and only for a model that has been descretised. + This is used when setting the initial conditions for a sensitivity model. + + Parameters + ---------- + solution : :class:`pybamm.Solution` + The solution to use to initialize the model + + model: :class:`pybamm.BaseModel` + The model whose sensitivities to set + + Returns + ------- + + initial_conditions : tuple of ndarray + The initial conditions for the sensitivities, each element of the tuple + corresponds to an input parameter + """ + + ninputs = len(model.calculate_sensitivities) + initial_conditions = tuple([] for _ in range(ninputs)) + solution = solution.last_state + for var in model.initial_conditions: + final_state = solution[var.name] + final_state = final_state.sensitivities + final_state_eval = tuple( + final_state[key] for key in model.calculate_sensitivities + ) + + scale, reference = var.scale.value, var.reference.value + for i in range(ninputs): + scaled_final_state_eval = (final_state_eval[i] - reference) / scale + initial_conditions[i].append(scaled_final_state_eval) + + # Also update the concatenated initial conditions if the model is already + # discretised + # Unpack slices for sorting + y_slices = {var: slce for var, slce in model.y_slices.items()} + slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()] + + # sort equations according to slices + concatenated_initial_conditions = [ + casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))]) + for init in initial_conditions + ] + return concatenated_initial_conditions + def process_t_interp(self, t_interp): # set a variable for this no_interp = (not self.supports_interp) and ( @@ -1092,6 +1167,7 @@ def step( npts=None, inputs=None, save=True, + calculate_sensitivities=False, t_interp=None, ): """ @@ -1117,6 +1193,14 @@ def step( Any input parameters to pass to the model when solving save : bool, optional Save solution with all previous timesteps. Defaults to True. + calculate_sensitivities : list of str or bool, optional + Whether the solver calculates sensitivities of all input parameters. Defaults to False. + If only a subset of sensitivities are required, can also pass a + list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances + so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you + require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities + do not take into account the movement of step-transitions wrt input parameters, so do not use this feature + if the timings of your experimental protocol change rapidly with respect to your input parameters. t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to None. Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). @@ -1188,8 +1272,15 @@ def step( # Set up inputs model_inputs = self._set_up_model_inputs(model, inputs) + # process calculate_sensitivities argument + calculate_sensitivities_list, sensitivities_have_changed = ( + BaseSolver._solve_process_calculate_sensitivities_arg( + model_inputs, model, calculate_sensitivities + ) + ) + first_step_this_model = model not in self._model_set_up - if first_step_this_model: + if first_step_this_model or sensitivities_have_changed: if len(self._model_set_up) > 0: existing_model = next(iter(self._model_set_up)) raise RuntimeError( @@ -1208,18 +1299,45 @@ def step( ): pybamm.logger.verbose(f"Start stepping {model.name} with {self.name}") + using_sensitivities = len(model.calculate_sensitivities) > 0 + if isinstance(old_solution, pybamm.EmptySolution): if not first_step_this_model: # reset y0 to original initial conditions self.set_up(model, model_inputs, ics_only=True) elif old_solution.all_models[-1] == model: - # initialize with old solution - model.y0 = old_solution.all_ys[-1][:, -1] + last_state = old_solution.last_state + model.y0 = last_state.all_ys[0] + if using_sensitivities and isinstance(last_state._all_sensitivities, dict): + full_sens = last_state._all_sensitivities["all"][0] + model.y0S = tuple(full_sens[:, i] for i in range(full_sens.shape[1])) + else: _, concatenated_initial_conditions = model.set_initial_conditions_from( old_solution, return_type="ics" ) model.y0 = concatenated_initial_conditions.evaluate(0, inputs=model_inputs) + if using_sensitivities: + model.y0S = self._set_sens_initial_conditions_from(old_solution, model) + + # hopefully we'll get rid of explicit sensitivities soon so we can remove this + explicit_sensitivities = model.len_rhs_sens > 0 or model.len_alg_sens > 0 + if ( + explicit_sensitivities + and using_sensitivities + and not isinstance(old_solution, pybamm.EmptySolution) + and not old_solution.all_models[-1] == model + ): + y0_list = [] + if model.len_rhs > 0: + y0_list.append(model.y0[: model.len_rhs]) + for s in model.y0S: + y0_list.append(s[: model.len_rhs]) + if model.len_alg > 0: + y0_list.append(model.y0[model.len_rhs :]) + for s in model.y0S: + y0_list.append(s[model.len_rhs :]) + model.y0 = casadi.vertcat(*y0_list) set_up_time = timer.time() diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl index 7ed4dcfad8..fd8eb38257 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -445,7 +445,7 @@ Solution IDAKLUSolverOpenMP::solve( } if (sensitivity) { - CheckErrors(IDAGetSens(ida_mem, &t_val, yyS)); + CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS)); } // Store Consistent initialization @@ -478,7 +478,7 @@ Solution IDAKLUSolverOpenMP::solve( bool hit_adaptive = save_adaptive_steps && retval == IDA_SUCCESS; if (sensitivity) { - CheckErrors(IDAGetSens(ida_mem, &t_val, yyS)); + CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS)); } if (hit_tinterp) { @@ -499,7 +499,7 @@ Solution IDAKLUSolverOpenMP::solve( // Reset the states and sensitivities at t = t_val CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy)); if (sensitivity) { - CheckErrors(IDAGetSens(ida_mem, &t_val, yyS)); + CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS)); } } diff --git a/src/pybamm/solvers/casadi_algebraic_solver.py b/src/pybamm/solvers/casadi_algebraic_solver.py index cf44912952..2dd6f2d341 100644 --- a/src/pybamm/solvers/casadi_algebraic_solver.py +++ b/src/pybamm/solvers/casadi_algebraic_solver.py @@ -170,7 +170,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): model, inputs_dict, termination="final time", - sensitivities=explicit_sensitivities, + all_sensitivities=explicit_sensitivities, ) sol.integration_time = integration_time return sol diff --git a/src/pybamm/solvers/casadi_solver.py b/src/pybamm/solvers/casadi_solver.py index b4ac9d1561..89e20631dd 100644 --- a/src/pybamm/solvers/casadi_solver.py +++ b/src/pybamm/solvers/casadi_solver.py @@ -193,7 +193,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): y0, model, inputs_dict, - sensitivities=False, + all_sensitivities=False, ) solution.solve_time = 0 solution.integration_time = 0 @@ -478,7 +478,7 @@ def integer_bisect(): np.array([t_event]), y_event[:, np.newaxis], "event", - sensitivities=bool(model.calculate_sensitivities), + all_sensitivities=False, ) solution.integration_time = ( coarse_solution.integration_time + dense_step_sol.integration_time @@ -696,7 +696,7 @@ def _run_integrator( y_sol, model, inputs_dict, - sensitivities=extract_sensitivities_in_solution, + all_sensitivities=extract_sensitivities_in_solution, check_solution=False, ) sol.integration_time = integration_time @@ -736,7 +736,7 @@ def _run_integrator( y_sol, model, inputs_dict, - sensitivities=extract_sensitivities_in_solution, + all_sensitivities=extract_sensitivities_in_solution, check_solution=False, ) sol.integration_time = integration_time diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 41e0c8855f..08f86b3264 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -818,7 +818,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): np.array([sol.t[-1]]), np.transpose(y_event)[:, np.newaxis], termination, - sensitivities=yS_out, + all_sensitivities=yS_out, ) newsol.integration_time = integration_time if not self.output_variables: diff --git a/src/pybamm/solvers/processed_variable.py b/src/pybamm/solvers/processed_variable.py index 8c1190c2f4..2464466348 100644 --- a/src/pybamm/solvers/processed_variable.py +++ b/src/pybamm/solvers/processed_variable.py @@ -66,7 +66,7 @@ def __init__( # Sensitivity starts off uninitialized, only set when called self._sensitivities = None - self.solution_sensitivities = solution.sensitivities + self.all_solution_sensitivities = solution._all_sensitivities # Store time self.t_pts = solution.t @@ -404,7 +404,7 @@ def sensitivities(self): return {} # Otherwise initialise and return sensitivities if self._sensitivities is None: - if self.solution_sensitivities != {}: + if self.all_solution_sensitivities: self.initialise_sensitivity_explicit_forward() else: raise ValueError( @@ -417,48 +417,54 @@ def sensitivities(self): def initialise_sensitivity_explicit_forward(self): "Set up the sensitivity dictionary" - inputs_stacked = self.all_inputs_casadi[0] - - # Set up symbolic variables - t_casadi = casadi.MX.sym("t") - y_casadi = casadi.MX.sym("y", self.all_ys[0].shape[0]) - p_casadi = { - name: casadi.MX.sym(name, value.shape[0]) - for name, value in self.all_inputs[0].items() - } - - p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()]) - # Convert variable to casadi format for differentiating - var_casadi = self.base_variables[0].to_casadi( - t_casadi, y_casadi, inputs=p_casadi - ) - dvar_dy = casadi.jacobian(var_casadi, y_casadi) - dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked) - - # Convert to functions and evaluate index-by-index - dvar_dy_func = casadi.Function( - "dvar_dy", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dy] - ) - dvar_dp_func = casadi.Function( - "dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp] - ) - for index, (ts, ys) in enumerate(zip(self.all_ts, self.all_ys)): + all_S_var = [] + for ts, ys, inputs_stacked, inputs, base_variable, dy_dp in zip( + self.all_ts, + self.all_ys, + self.all_inputs_casadi, + self.all_inputs, + self.base_variables, + self.all_solution_sensitivities["all"], + ): + # Set up symbolic variables + t_casadi = casadi.MX.sym("t") + y_casadi = casadi.MX.sym("y", ys.shape[0]) + p_casadi = { + name: casadi.MX.sym(name, value.shape[0]) + for name, value in inputs.items() + } + + p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()]) + + # Convert variable to casadi format for differentiating + var_casadi = base_variable.to_casadi(t_casadi, y_casadi, inputs=p_casadi) + dvar_dy = casadi.jacobian(var_casadi, y_casadi) + dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked) + + # Convert to functions and evaluate index-by-index + dvar_dy_func = casadi.Function( + "dvar_dy", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dy] + ) + dvar_dp_func = casadi.Function( + "dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp] + ) for idx, t in enumerate(ts): u = ys[:, idx] next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked) next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked) - if index == 0 and idx == 0: + if idx == 0: dvar_dy_eval = next_dvar_dy_eval dvar_dp_eval = next_dvar_dp_eval else: dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval) dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval) - # Compute sensitivity - dy_dp = self.solution_sensitivities["all"] - S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval + # Compute sensitivity + S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval + all_S_var.append(S_var) + S_var = casadi.vertcat(*all_S_var) sensitivities = {"all": S_var} # Add the individual sensitivity diff --git a/src/pybamm/solvers/scipy_solver.py b/src/pybamm/solvers/scipy_solver.py index 226b096887..daa8f706de 100644 --- a/src/pybamm/solvers/scipy_solver.py +++ b/src/pybamm/solvers/scipy_solver.py @@ -150,7 +150,7 @@ def event_fn(t, y): t_event, y_event, termination, - sensitivities=bool(model.calculate_sensitivities), + all_sensitivities=bool(model.calculate_sensitivities), ) sol.integration_time = integration_time return sol diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index c3c8451634..74d9ce7baf 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -2,6 +2,7 @@ # Solution class # import casadi +import copy import json import numbers import numpy as np @@ -57,11 +58,10 @@ class Solution: the event happens. termination : str String to indicate why the solution terminated - - sensitivities: bool or dict + all_sensitivities: bool or dict of lists True if sensitivities included as the solution of the explicit forwards equations. False if no sensitivities included/wanted. Dict if sensitivities are - provided as a dict of {parameter: sensitivities} pairs. + provided as a dict of {parameter: [sensitivities]} pairs. """ @@ -74,7 +74,7 @@ def __init__( t_event=None, y_event=None, termination="final time", - sensitivities=False, + all_sensitivities=False, check_solution=True, ): if not isinstance(all_ts, list): @@ -98,7 +98,18 @@ def __init__( else: self.all_inputs = all_inputs - self.sensitivities = sensitivities + if isinstance(all_sensitivities, bool): + self._all_sensitivities = all_sensitivities + elif isinstance(all_sensitivities, dict): + self._all_sensitivities = {} + for key, value in all_sensitivities.items(): + if isinstance(value, list): + self._all_sensitivities[key] = value + else: + self._all_sensitivities[key] = [value] + + else: + raise TypeError("sensitivities arg needs to be a bool or dict") # Check no ys are too large if check_solution: @@ -134,47 +145,31 @@ def __init__( # Solution now uses CasADi pybamm.citations.register("Andersson2019") - def extract_explicit_sensitivities(self): - # if we got here, we haven't set y yet - self.set_y() + def has_sensitivities(self) -> bool: + if isinstance(self._all_sensitivities, bool): + return self._all_sensitivities + elif isinstance(self._all_sensitivities, dict): + return len(self._all_sensitivities) > 0 - # extract sensitivities from full y solution - self._y, self._sensitivities = self._extract_explicit_sensitivities( - self.all_models[0], self.y, self.t, self.all_inputs[0] - ) + def extract_explicit_sensitivities(self): + self._all_sensitivities = {} - # make sure we remove all sensitivities from all_ys + # extract sensitivities from each sub-solution for index, (model, ys, ts, inputs) in enumerate( zip(self.all_models, self.all_ys, self.all_ts, self.all_inputs) ): - self._all_ys[index], _ = self._extract_explicit_sensitivities( + self._all_ys[index], sens_segment = self._extract_explicit_sensitivities( model, ys, ts, inputs ) + for key, value in sens_segment.items(): + if key in self._all_sensitivities: + self._all_sensitivities[key] = self._all_sensitivities[key] + [ + value + ] + else: + self._all_sensitivities[key] = [value] - def _extract_explicit_sensitivities(self, model, y, t_eval, inputs): - """ - given a model and a solution y, extracts the sensitivities - - Parameters - -------- - model : :class:`pybamm.BaseModel` - A model that has been already setup by this base solver - y: ndarray - The solution of the full explicit sensitivity equations - t_eval: ndarray - The evaluation times - inputs: dict - parameter inputs - - Returns - ------- - y: ndarray - The solution of the ode/dae in model - sensitivities: dict of (string: ndarray) - A dictionary of parameter names, and the corresponding solution of - the sensitivity equations - """ - + def _extract_sensitivity_matrix(self, model, y): n_states = model.len_rhs_and_alg n_rhs = model.len_rhs n_alg = model.len_alg @@ -185,7 +180,6 @@ def _extract_explicit_sensitivities(self, model, y, t_eval, inputs): n_p = model.len_alg_sens // model.len_alg len_rhs_and_sens = model.len_rhs + model.len_rhs_sens - n_t = len(t_eval) # y gets the part of the solution vector that correspond to the # actual ODE/DAE solution @@ -211,6 +205,8 @@ def _extract_explicit_sensitivities(self, model, y, t_eval, inputs): y_full = y.full() else: y_full = y + + n_t = y.shape[1] ode_sens = y_full[n_rhs:len_rhs_and_sens, :].reshape(n_p, n_rhs, n_t) alg_sens = y_full[len_rhs_and_sens + n_alg :, :].reshape(n_p, n_alg, n_t) # 2. Concatenate into a single 3D matrix with shape (n_p, n_states, n_t) @@ -221,6 +217,44 @@ def _extract_explicit_sensitivities(self, model, y, t_eval, inputs): n_t * n_states, n_p ) + # convert back to casadi (todo: this is not very efficient, should refactor + # to avoid this) + full_sens_matrix = casadi.DM(full_sens_matrix) + + y_dae = np.vstack( + [ + y[: model.len_rhs, :], + y[len_rhs_and_sens : len_rhs_and_sens + model.len_alg, :], + ] + ) + return y_dae, full_sens_matrix + + def _extract_explicit_sensitivities(self, model, y, t_eval, inputs): + """ + given a model and a solution y, extracts the sensitivities + + Parameters + -------- + model : :class:`pybamm.BaseModel` + A model that has been already setup by this base solver + y: ndarray + The solution of the full explicit sensitivity equations + t_eval: ndarray + The evaluation times + inputs: dict + parameter inputs + + Returns + ------- + y: ndarray + The solution of the ode/dae in model + sensitivities: dict of (string: ndarray) + A dictionary of parameter names, and the corresponding solution of + the sensitivity equations + """ + + y_dae, full_sens_matrix = self._extract_sensitivity_matrix(model, y) + # Save the full sensitivity matrix sensitivity = {"all": full_sens_matrix} @@ -234,12 +268,6 @@ def _extract_explicit_sensitivities(self, model, y, t_eval, inputs): sensitivity[name] = full_sens_matrix[:, start:end] start = end - y_dae = np.vstack( - [ - y[: model.len_rhs, :], - y[len_rhs_and_sens : len_rhs_and_sens + model.len_alg, :], - ] - ) return y_dae, sensitivity @property @@ -262,31 +290,56 @@ def y(self): try: return self._y except AttributeError: - self.set_y() - # if y is evaluated before sensitivities then need to extract them - if isinstance(self._sensitivities, bool) and self._sensitivities: + if isinstance(self._all_sensitivities, bool) and self._all_sensitivities: self.extract_explicit_sensitivities() + self.set_y() + return self._y @property def sensitivities(self): """Values of the sensitivities. Returns a dict of param_name: np_array""" - if isinstance(self._sensitivities, bool): - if self._sensitivities: - self.extract_explicit_sensitivities() - else: - self._sensitivities = {} + try: + return self._sensitivities + except AttributeError: + self.set_sensitivities() return self._sensitivities @sensitivities.setter def sensitivities(self, value): - """Updates the sensitivity""" + """Updates the sensitivity if False or True. Raises an error if sensitivities are a dict""" # sensitivities must be a dict or bool - if not isinstance(value, (bool, dict)): - raise TypeError("sensitivities arg needs to be a bool or dict") - self._sensitivities = value + if not isinstance(value, bool): + raise TypeError("sensitivities arg needs to be a bool") + + if isinstance(self._all_sensitivities, dict): + raise NotImplementedError( + "Setting sensitivities is not supported if sensitivities are " + "already provided as a dict of {parameter: sensitivities} pairs." + ) + + self._all_sensitivities = value + + def set_sensitivities(self): + if not self.has_sensitivities(): + self._sensitivities = {} + return + + # extract sensitivities if they are not already extracted + if isinstance(self._all_sensitivities, bool) and self._all_sensitivities: + self.extract_explicit_sensitivities() + + is_casadi = isinstance( + next(iter(self._all_sensitivities.values()))[0], (casadi.DM, casadi.MX) + ) + self._sensitivities = {} + for key, sens in self._all_sensitivities.items(): + if is_casadi: + self._sensitivities[key] = casadi.vertcat(*sens) + else: + self._sensitivities[key] = np.vstack(sens) def set_y(self): try: @@ -374,6 +427,13 @@ def first_state(self): than the full solution when only the first state is needed (e.g. to initialize a model with the solution) """ + if isinstance(self._all_sensitivities, bool): + sensitivities = self._all_sensitivities + elif isinstance(self._all_sensitivities, dict): + sensitivities = {} + n_states = self.all_models[0].len_rhs_and_alg + for key in self._all_sensitivities: + sensitivities[key] = self._all_sensitivities[key][0][-n_states:, :] new_sol = Solution( self.all_ts[0][:1], self.all_ys[0][:, :1], @@ -382,6 +442,7 @@ def first_state(self): None, None, "final time", + all_sensitivities=sensitivities, ) new_sol._all_inputs_casadi = self.all_inputs_casadi[:1] new_sol._sub_solutions = self.sub_solutions[:1] @@ -399,6 +460,13 @@ def last_state(self): than the full solution when only the final state is needed (e.g. to initialize a model with the solution) """ + if isinstance(self._all_sensitivities, bool): + sensitivities = self._all_sensitivities + elif isinstance(self._all_sensitivities, dict): + sensitivities = {} + n_states = self.all_models[-1].len_rhs_and_alg + for key in self._all_sensitivities: + sensitivities[key] = self._all_sensitivities[key][-1][-n_states:, :] new_sol = Solution( self.all_ts[-1][-1:], self.all_ys[-1][:, -1:], @@ -407,10 +475,10 @@ def last_state(self): self.t_event, self.y_event, self.termination, + all_sensitivities=sensitivities, ) new_sol._all_inputs_casadi = self.all_inputs_casadi[-1:] new_sol._sub_solutions = self.sub_solutions[-1:] - new_sol.solve_time = 0 new_sol.integration_time = 0 new_sol.set_up_time = 0 @@ -457,7 +525,7 @@ def set_summary_variables(self, all_summary_variables): def update(self, variables): """Add ProcessedVariables to the dictionary of variables in the solution""" # make sure that sensitivities are extracted if required - if isinstance(self._sensitivities, bool) and self._sensitivities: + if isinstance(self._all_sensitivities, bool) and self._all_sensitivities: self.extract_explicit_sensitivities() # Convert single entry to list @@ -758,6 +826,30 @@ def __add__(self, other): all_ts = self.all_ts + other.all_ts all_ys = self.all_ys + other.all_ys + # sensitivities can be: + # - bool if not using sensitivities or using explicit sensitivities which still + # need to be extracted + # - dict if sensitivities are provided as a dict of {parameter: sensitivities} + # both self and other should have the same type of sensitivities + # OR both can be either False or {} (i.e. no sensitivities) + if isinstance(self._all_sensitivities, bool) and isinstance( + other._all_sensitivities, bool + ): + all_sensitivities = self._all_sensitivities or other._all_sensitivities + elif isinstance(self._all_sensitivities, dict) and isinstance( + other._all_sensitivities, dict + ): + all_sensitivities = self._all_sensitivities + # we can assume that the keys are the same for both solutions + for key in other._all_sensitivities: + all_sensitivities[key] = ( + all_sensitivities[key] + other._all_sensitivities[key] + ) + elif not self._all_sensitivities and not other._all_sensitivities: + all_sensitivities = {} + else: + raise ValueError("Sensitivities must be of the same type") + new_sol = Solution( all_ts, all_ys, @@ -766,15 +858,19 @@ def __add__(self, other): other.t_event, other.y_event, other.termination, - bool(self.sensitivities), + all_sensitivities=all_sensitivities, ) new_sol.closest_event_idx = other.closest_event_idx new_sol._all_inputs_casadi = self.all_inputs_casadi + other.all_inputs_casadi - # Set solution time - new_sol.solve_time = self.solve_time + other.solve_time - new_sol.integration_time = self.integration_time + other.integration_time + # Add timers (if available) + for attr in ["solve_time", "integration_time", "set_up_time"]: + if ( + getattr(self, attr, None) is not None + and getattr(other, attr, None) is not None + ): + setattr(new_sol, attr, getattr(self, attr) + getattr(other, attr)) # Set sub_solutions new_sol._sub_solutions = self.sub_solutions + other.sub_solutions @@ -787,12 +883,14 @@ def __radd__(self, other): def copy(self): new_sol = self.__class__( self.all_ts, - self.all_ys, + # need to copy y in case it is modified by extract explicit sensitivities + [copy.copy(y) for y in self.all_ys], self.all_models, self.all_inputs, self.t_event, self.y_event, self.termination, + self._all_sensitivities, ) new_sol._all_inputs_casadi = self.all_inputs_casadi new_sol._sub_solutions = self.sub_solutions @@ -902,6 +1000,7 @@ def make_cycle_solution( sum_sols.t_event, sum_sols.y_event, sum_sols.termination, + sum_sols._all_sensitivities, ) cycle_solution._all_inputs_casadi = sum_sols.all_inputs_casadi cycle_solution._sub_solutions = sum_sols.sub_solutions diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index 3507d6e5c1..4f981ba04c 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -201,6 +201,83 @@ def test_run_experiment_cccv_solvers(self): ) self.assertEqual(solutions[1].termination, "final time") + @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") + def test_solve_with_sensitivities_and_experiment(self): + experiment_2step = pybamm.Experiment( + [ + ( + "Discharge at C/20 for 1 hour", + "Charge at 1 A until 4.1 V", + "Hold at 4.1 V until C/2", + "Discharge at 2 W for 30 min", + "Discharge at 2 W for 30 min", # repeat to cover this case (changes initialisation) + ), + ] + * 2, + ) + + solutions = [] + for solver in [ + pybamm.CasadiSolver(), + pybamm.IDAKLUSolver(), + pybamm.ScipySolver(), + ]: + for calculate_sensitivities in [False, True]: + model = pybamm.lithium_ion.SPM() + param = model.default_parameter_values + input_param_name = "Negative electrode active material volume fraction" + input_param_value = param[input_param_name] + param.update({input_param_name: "[input]"}) + sim = pybamm.Simulation( + model, + experiment=experiment_2step, + solver=solver, + parameter_values=param, + ) + solution = sim.solve( + inputs={input_param_name: input_param_value}, + calculate_sensitivities=calculate_sensitivities, + ) + solutions.append(solution) + + # check solutions are the same, leave out the last solution point as it is slightly different + # for each solve due to numerical errors + # TODO: scipy solver does not work for this experiment, with or without sensitivities, + # so we skip this test for now + for i in range(1, len(solutions) - 2): + np.testing.assert_allclose( + solutions[0]["Voltage [V]"].data[:-1], + solutions[i]["Voltage [V]"](solutions[0].t[:-1]), + rtol=5e-2, + equal_nan=True, + ) + + # check sensitivities are roughly the same. Sundials isn't doing error control on the sensitivities + # by default, and the solution can be quite coarse for quickly changing sensitivities + sens_casadi = ( + solutions[1]["Voltage [V]"] + .sensitivities[input_param_name][:-2] + .full() + .flatten() + ) + sens_idaklu = np.interp( + solutions[1].t[:-2], + solutions[3].t, + solutions[3]["Voltage [V]"] + .sensitivities[input_param_name] + .full() + .flatten(), + ) + rtol = 1e-1 + atol = 1e-2 + error = np.sqrt( + np.sum( + ((sens_casadi - sens_idaklu) / (rtol * np.abs(sens_casadi) + atol)) ** 2 + ) + / len(sens_casadi) + ) + self.assertLess(error, 1.0) + def test_run_experiment_drive_cycle(self): drive_cycle = np.array([np.arange(10), np.arange(10)]).T experiment = pybamm.Experiment( diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py index fc9fec9745..becd70cbe4 100644 --- a/tests/unit/test_simulation.py +++ b/tests/unit/test_simulation.py @@ -314,6 +314,17 @@ def ocv_with_parameter(sto): sim.solve([0, 3600], inputs={"a": 1}, initial_soc=0.8) assert sim._built_initial_soc == 0.8 + def test_restricted_input_params(self): + model = pybamm.lithium_ion.SPM() + parameter_values = model.default_parameter_values + parameter_values.update({"Initial temperature [K]": "[input]"}) + experiment = pybamm.Experiment(["Discharge at 1C until 2.5 V"]) + sim = pybamm.Simulation( + model, parameter_values=parameter_values, experiment=experiment + ) + with pytest.raises(pybamm.ModelError, match="Initial temperature"): + sim.solve([0, 3600]) + def test_esoh_with_input_param(self): # Test that initial soc works with a relevant input parameter model = pybamm.lithium_ion.DFN({"working electrode": "positive"}) @@ -336,6 +347,38 @@ def test_solve_with_inputs(self): sim.solution.all_inputs[0]["Current function [A]"], 1 ) + def test_solve_with_sensitivities(self): + model = pybamm.lithium_ion.SPM() + param = model.default_parameter_values + param.update({"Current function [A]": "[input]"}) + sim = pybamm.Simulation(model, parameter_values=param) + h = 1e-6 + sol1 = sim.solve( + t_eval=[0, 600], + inputs={"Current function [A]": 1}, + calculate_sensitivities=True, + ) + + # check that the sensitivities are stored + assert "Current function [A]" in sol1.sensitivities + + sol2 = sim.solve(t_eval=[0, 600], inputs={"Current function [A]": 1 + h}) + + # check that the sensitivities are not stored + assert "Current function [A]" not in sol2.sensitivities + + # check that the sensitivities are roughly correct + np.testing.assert_array_almost_equal( + sol1["Terminal voltage [V]"].entries + + h + * sol1["Terminal voltage [V]"] + .sensitivities["Current function [A]"] + .full() + .flatten(), + sol2["Terminal voltage [V]"].entries, + decimal=5, + ) + def test_step_with_inputs(self): dt = 0.001 model = pybamm.lithium_ion.SPM() diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 3e1023c7d4..e6f631392a 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -1022,7 +1022,7 @@ def test_solve_sensitivity_algebraic(self): model, t_eval, inputs={"p": 0.1}, calculate_sensitivities=True ) np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], 0.1 * solution.t) + np.testing.assert_allclose(np.array(solution.y)[0], 0.1 * solution.t) np.testing.assert_allclose( solution.sensitivities["p"], solution.t.reshape(-1, 1), atol=1e-7 ) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 995898e8dd..5a584fabbf 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -52,6 +52,20 @@ def test_errors(self): pybamm.Solution(ts, bad_ys, model, {}) self.assertIn("exceeds the maximum", captured.records[0].getMessage()) + with self.assertRaisesRegex( + TypeError, "sensitivities arg needs to be a bool or dict" + ): + pybamm.Solution(ts, bad_ys, model, {}, all_sensitivities="bad") + + sol = pybamm.Solution(ts, bad_ys, model, {}, all_sensitivities={}) + with self.assertRaisesRegex(TypeError, "sensitivities arg needs to be a bool"): + sol.sensitivities = "bad" + with self.assertRaisesRegex( + NotImplementedError, + "Setting sensitivities is not supported if sensitivities are already provided as a dict", + ): + sol.sensitivities = True + def test_add_solutions(self): # Set up first solution t1 = np.linspace(0, 1) @@ -89,7 +103,7 @@ def test_add_solutions(self): # Add solution already contained in existing solution t3 = np.array([2]) - y3 = np.ones((20, 1)) + y3 = np.ones((1, 1)) sol3 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {"a": 3}) self.assertEqual((sol_sum + sol3).all_ts, sol_sum.copy().all_ts) @@ -111,6 +125,23 @@ def test_add_solutions(self): ): 2 + sol3 + sol1 = pybamm.Solution( + t1, + y1, + pybamm.BaseModel(), + {}, + all_sensitivities={"test": [np.ones((1, 3))]}, + ) + sol2 = pybamm.Solution(t2, y2, pybamm.BaseModel(), {}, all_sensitivities=True) + with self.assertRaisesRegex( + ValueError, "Sensitivities must be of the same type" + ): + sol3 = sol1 + sol2 + sol1 = pybamm.Solution(t1, y3, pybamm.BaseModel(), {}, all_sensitivities=False) + sol2 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {}, all_sensitivities={}) + sol3 = sol1 + sol2 + self.assertFalse(sol3._all_sensitivities) + def test_add_solutions_different_models(self): # Set up first solution t1 = np.linspace(0, 1) @@ -146,7 +177,8 @@ def test_copy(self): sol_copy = sol1.copy() self.assertEqual(sol_copy.all_ts, sol1.all_ts) - self.assertEqual(sol_copy.all_ys, sol1.all_ys) + for ys_copy, ys1 in zip(sol_copy.all_ys, sol1.all_ys): + np.testing.assert_array_equal(ys_copy, ys1) self.assertEqual(sol_copy.all_inputs, sol1.all_inputs) self.assertEqual(sol_copy.all_inputs_casadi, sol1.all_inputs_casadi) self.assertEqual(sol_copy.set_up_time, sol1.set_up_time)