Skip to content

Commit

Permalink
#1477 sorting out processed variable
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 16, 2021
1 parent ac94921 commit f5699c4
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 112 deletions.
34 changes: 9 additions & 25 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@ class BaseSolver(object):
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
sensitivity : str, optional
Whether (and how) to calculate sensitivities when solving. Options are:
- None (default): the individual solver is responsible for
calculating the sensitivity wrt these parameters, and providing the result in
the solution instance returned. At the moment this is only implemented for the
IDAKLU solver.\
- "explicit forward": explicitly formulate the sensitivity equations for
the chosen input parameters. . At the moment this is only
implemented using convert_to_format = 'casadi'. \
- see individual solvers for other options
"""

def __init__(
Expand Down Expand Up @@ -231,6 +221,8 @@ def set_up(self, model, inputs=None, t_eval=None,
else:
calculate_sensitivites = []

self.calculate_sensitivites = calculate_sensitivites

calculate_sensitivites_explicit = False
if calculate_sensitivites and not isinstance(self, pybamm.IDAKLUSolver):
calculate_sensitivites_explicit = True
Expand Down Expand Up @@ -360,12 +352,13 @@ def jacp(*args, **kwargs):
# Add sensitivity vectors to the rhs and algebraic equations
jacp = None
if calculate_sensitivites_explicit:
print('CASADI EXPLICIT', name, model.len_rhs)
# The formulation is as per Park, S., Kato, D., Gima, Z., Klein, R.,
# & Moura, S. (2018). Optimal experimental design for
# parameterization of an electrochemical lithium-ion battery model.
# Journal of The Electrochemical Society, 165(7), A1309.". See #1100
# for details
if name == "rhs" and model.len_rhs > 0:
if name == "RHS" and model.len_rhs > 0:
report(
"Creating explicit forward sensitivity equations for rhs using CasADi")
df_dx = casadi.jacobian(func, y_diff)
Expand Down Expand Up @@ -621,7 +614,7 @@ def jacp(*args, **kwargs):

# if we have changed the equations to include the explicit sensitivity
# equations, then we also need to update the mass matrix
if self.sensitivity == "explicit forward":
if calculate_sensitivites_explicit:
n_inputs = len(calculate_sensitivites)
model.mass_matrix_inv = pybamm.Matrix(
block_diag(
Expand Down Expand Up @@ -693,27 +686,21 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
Whether to update the rhs. True for 'solve', False for 'step'.
"""
# Make inputs symbolic if calculating sensitivities with casadi
if self.sensitivity == "casadi":
symbolic_inputs = casadi.MX.sym(
"inputs", casadi.vertcat(*inputs.values()).shape[0]
)
else:
symbolic_inputs = inputs

if self.algebraic_solver is True:
# Don't update model.y0
return None
elif len(model.algebraic) == 0:
if update_rhs is True:
# Recalculate initial conditions for the rhs equations
y0 = model.init_eval(symbolic_inputs)
y0 = model.init_eval(inputs)
else:
# Don't update model.y0
return None
else:
if update_rhs is True:
# Recalculate initial conditions for the rhs equations
y0_from_inputs = model.init_eval(symbolic_inputs)
y0_from_inputs = model.init_eval(inputs)
# Reuse old solution for algebraic equations
y0_from_model = model.y0
len_rhs = model.len_rhs
Expand All @@ -726,10 +713,7 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
)
y0 = self.calculate_consistent_state(model, 0, inputs)
# Make y0 a function of inputs if doing symbolic with casadi
if self.sensitivity == "casadi":
model.y0 = casadi.Function("y0", [symbolic_inputs], [y0])
else:
model.y0 = y0
model.y0 = y0

def calculate_consistent_state(self, model, time=0, inputs=None):
"""
Expand Down
12 changes: 3 additions & 9 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):


# are we solving explicit forward equations?
explicit_sensitivities = self.sensitivity == 'explicit forward'
explicit_sensitivities = bool(self.calculate_sensitivites)

# Record whether there are any symbolic inputs
inputs_dict = inputs_dict or {}
Expand Down Expand Up @@ -603,7 +603,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
pybamm.logger.debug("Running CasADi integrator")

# are we solving explicit forward equations?
explicit_sensitivities = self.sensitivity == 'explicit forward'
explicit_sensitivities = bool(self.calculate_sensitivites)

if use_grid is True:
t_eval_shifted = t_eval - t_eval[0]
Expand All @@ -613,12 +613,6 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
integrator = self.integrators[model]["no grid"]

symbolic_inputs = casadi.MX.sym("inputs", inputs.shape[0])
# If doing sensitivity with casadi, evaluate with symbolic inputs
# Otherwise, evaluate with actual inputs
if self.sensitivity == "casadi":
inputs_eval = symbolic_inputs
else:
inputs_eval = inputs

len_rhs = model.concatenated_rhs.size

Expand Down Expand Up @@ -656,7 +650,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
for i in range(len(t_eval) - 1):
t_min = t_eval[i]
t_max = t_eval[i + 1]
inputs_with_tlims = casadi.vertcat(inputs_eval, t_min, t_max)
inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max)
timer = pybamm.Timer()
casadi_sol = integrator(
x0=x, z0=z, p=inputs_with_tlims, **self.extra_options_call
Expand Down
62 changes: 32 additions & 30 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):

self.all_ts = solution.all_ts
self.all_ys = solution.all_ys
self.all_inputs = solution.all_inputs
self.all_inputs_casadi = solution.all_inputs_casadi

self.mesh = base_variables[0].mesh
Expand All @@ -51,8 +52,8 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
self.u_sol = solution.y

# Sensitivity starts off uninitialized, only set when called
self._sensitivity = None
self.all_sensitivities = solution.all_sensitivities
self._sensitivities = None
self.solution_sensitivities = solution.sensitivities

# Set timescale
self.timescale = solution.timescale_eval
Expand Down Expand Up @@ -488,52 +489,44 @@ def data(self):
"""Same as entries, but different name"""
return self.entries


class Interpolant0D:
def __init__(self, entries):
self.entries = entries

def __call__(self, t):
return self.entries

@property
def sensitivity(self):
def sensitivities(self):
"""
Returns a dictionary of sensitivity for each input parameter.
Returns a dictionary of sensitivities for each input parameter.
The keys are the input parameters, and the value is a matrix of size
(n_x * n_t, n_p), where n_x is the number of states, n_t is the number of time
points, and n_p is the size of the input parameter
"""
# No sensitivity if there are no inputs
if len(self.inputs) == 0:
# No sensitivities if there are no inputs
if len(self.all_inputs[0]) == 0:
return {}
# Otherwise initialise and return sensitivity
if self._sensitivity is None:
if self.solution_sensitivity != {}:
# Otherwise initialise and return sensitivities
if self._sensitivities is None:
if self.solution_sensitivities != {}:
self.initialise_sensitivity_explicit_forward()
else:
raise ValueError(
"Cannot compute sensitivities. The 'sensitivity' argument of the "
"solver should be changed from 'None' to allow sensitivity "
"Cannot compute sensitivities. The 'sensitivities' argument of the "
"solver.solve should be changed from 'None' to allow sensitivities "
"calculations. Check solver documentation for details."
)
return self._sensitivity
return self._sensitivities

def initialise_sensitivity_explicit_forward(self):
"Set up the sensitivity dictionary"
inputs_stacked = casadi.vertcat(*[p for p in self.inputs.values()])
inputs_stacked = self.all_inputs_casadi[0]

# Set up symbolic variables
t_casadi = casadi.MX.sym("t")
y_casadi = casadi.MX.sym("y", self.u_sol.shape[0])
p_casadi = {
name: casadi.MX.sym(name, value.shape[0])
for name, value in self.inputs.items()
for name, value in self.all_inputs[0].items()
}
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])

# Convert variable to casadi format for differentiating
var_casadi = self.base_variable.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
var_casadi = self.base_variables[0].to_casadi(t_casadi, y_casadi, inputs=p_casadi)
dvar_dy = casadi.jacobian(var_casadi, y_casadi)
dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked)

Expand All @@ -544,8 +537,8 @@ def initialise_sensitivity_explicit_forward(self):
dvar_dp_func = casadi.Function(
"dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp]
)
for idx in range(len(self.t_sol)):
t = self.t_sol[idx]
for idx in range(len(self.all_ts[0])):
t = self.all_ts[0][idx]
u = self.u_sol[:, idx]
inp = inputs_stacked[:, idx]
next_dvar_dy_eval = dvar_dy_func(t, u, inp)
Expand All @@ -558,20 +551,29 @@ def initialise_sensitivity_explicit_forward(self):
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)

# Compute sensitivity
dy_dp = self.solution_sensitivity["all"]
dy_dp = self.solution_sensitivities["all"]
S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval

sensitivity = {"all": S_var}
sensitivities = {"all": S_var}

# Add the individual sensitivity
start = 0
for name, inp in self.inputs.items():
for name, inp in self.all_inputs[0].items():
end = start + inp.shape[0]
sensitivity[name] = S_var[:, start:end]
sensitivities[name] = S_var[:, start:end]
start = end

# Save attribute
self._sensitivity = sensitivity
self._sensitivities = sensitivities


class Interpolant0D:
def __init__(self, entries):
self.entries = entries

def __call__(self, t):
return self.entries


class Interpolant1D:
def __init__(self, pts_for_interp, entries_for_interp):
Expand Down
Loading

0 comments on commit f5699c4

Please sign in to comment.