
Commit

refactor: Jax implementation, FittingProblem.evaluate, adds Fisher Information to OptimisationResult
BradyPlanden committed Nov 4, 2024
1 parent e1b99d7 commit 8f11efe
Showing 13 changed files with 155 additions and 126 deletions.
15 changes: 5 additions & 10 deletions examples/scripts/jax-solver-example.py
@@ -1,5 +1,3 @@
import time

import numpy as np
import pybamm

@@ -10,8 +8,8 @@

# The IDAKLU solver and its jaxified version perform very well on the DFN, with and
# without gradient calculations
solver = pybamm.IDAKLUSolver()
model = pybop.lithium_ion.SPM(parameter_set=parameter_set, solver=solver)
solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solver)

# Fitting parameters
parameters = pybop.Parameters(
@@ -28,7 +26,7 @@
)

# Define test protocol and generate data
t_eval = np.linspace(0, 300, 100)
t_eval = np.linspace(0, 600, 600)
values = model.predict(
initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)
@@ -47,19 +45,16 @@

# By selecting a Jax-based cost function, the IDAKLU solver will be
# jaxified (wrapped in a Jax-compiled expression) and used for optimisation
cost = pybop.JaxLogNormalLikelihood(problem, sigma0=0.002)
cost = pybop.JaxLogNormalLikelihood(problem, sigma0=2e-3)

# Gradient-based optimiser; change to `pybop.XNES` for a gradient-free example
optim = pybop.XNES(
optim = pybop.IRPropMin(
cost,
max_unchanged_iterations=20,
max_iterations=100,
)

start_time = time.time()
results = optim.run()
print(results)
print(f"Total time: {time.time() - start_time}")

# Plot convergence
pybop.plot.convergence(optim)
11 changes: 7 additions & 4 deletions examples/scripts/maximum_likelihood.py
@@ -1,4 +1,5 @@
import numpy as np
import pybamm

import pybop

@@ -7,10 +8,12 @@
parameter_set.update(
{
"Negative electrode active material volume fraction": 0.63,
"Positive electrode active material volume fraction": 0.51,
"Positive electrode active material volume fraction": 0.62,
}
)
model = pybop.lithium_ion.SPM(parameter_set=parameter_set)
options = {"max_num_steps": int(1e6), "max_error_test_failures": 60}
solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, options=options)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solver)

# Fitting parameters
parameters = pybop.Parameters(
@@ -57,8 +60,8 @@ def noise(sigma):
signal = ["Voltage [V]", "Bulk open-circuit voltage [V]"]
# Generate problem, cost function, and optimisation class
problem = pybop.FittingProblem(model, parameters, dataset, signal=signal)
likelihood = pybop.GaussianLogLikelihood(problem, sigma0=sigma * 4)
optim = pybop.IRPropMin(
likelihood = pybop.JaxGaussianLogLikelihoodKnownSigma(problem, sigma0=sigma)
optim = pybop.XNES(
likelihood,
max_unchanged_iterations=20,
min_iterations=20,
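
(Editorial aside: this script fits against data corrupted with Gaussian noise of a known level sigma, with the noise helper defined outside the shown hunks. A small self-contained sketch of that data-generating assumption, using an illustrative noise level and a stand-in voltage trace:)

import numpy as np

rng = np.random.default_rng(8)
sigma = 2e-3                                        # illustrative known noise level; the script's value is not shown here
clean = np.linspace(4.2, 3.0, 600)                  # stand-in for a simulated voltage trace
noisy = clean + rng.normal(0.0, sigma, clean.shape)
# `noisy` plays the role of the dataset target, and `sigma` is what gets passed as sigma0 to the likelihood
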
10 changes: 5 additions & 5 deletions pybop/__init__.py
@@ -117,6 +117,11 @@
)
from .costs._weighted_cost import WeightedCost

#
# Experimental
#
from .experimental import BaseJaxCost, JaxSumSquaredError, JaxLogNormalLikelihood, JaxGaussianLogLikelihoodKnownSigma

#
# Optimiser classes
#
@@ -174,11 +179,6 @@
from . import plot as plot
from .samplers.mcmc_summary import PosteriorSummary

#
# Experimental
#
from .experimental import BaseJaxCost, JaxSumSquaredError, JaxLogNormalLikelihood

#
# Remove any imported modules, so we don't expose them as part of pybop
#
2 changes: 1 addition & 1 deletion pybop/experimental/__init__.py
@@ -1 +1 @@
from .jax_costs import BaseJaxCost, JaxLogNormalLikelihood, JaxSumSquaredError
from .jax_costs import BaseJaxCost, JaxLogNormalLikelihood, JaxSumSquaredError, JaxGaussianLogLikelihoodKnownSigma
81 changes: 59 additions & 22 deletions pybop/experimental/jax_costs.py
@@ -15,9 +15,10 @@ class BaseJaxCost(BaseCost):

def __init__(self, problem: BaseProblem):
super().__init__(problem)

if isinstance(self.problem.model.solver, IDAKLUSolver):
self.problem.model.jaxify_solver(t_eval=self.problem.domain_data)
self.model = self.problem.model
self.n_data = self.problem.n_data
if isinstance(self.model.solver, IDAKLUSolver):
self.model.jaxify_solver(t_eval=self.problem.domain_data)

def __call__(
self,
@@ -43,15 +44,16 @@ def __call__(
The Sum of Squared Error.
"""
inputs = self.parameters.verify(inputs)
self._update_solver_sensitivities(calculate_grad)
if calculate_grad != self.model.calculate_sensitivities:
self._update_solver_sensitivities(calculate_grad)

if calculate_grad:
y, dy = jax.value_and_grad(self.evaluate)(inputs)
return y, np.asarray(
list(dy.values())
) # Convert grad to numpy for optimisers
else:
return self.evaluate(inputs)
return np.asarray(self.evaluate(inputs))
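
(Editorial aside, not part of the commit: the pattern above, jax.value_and_grad over an Inputs dict with the gradient pytree flattened to NumPy for the optimisers, can be illustrated with a short self-contained sketch; the cost and parameter names below are purely illustrative.)

import jax
import jax.numpy as jnp
import numpy as np

def toy_cost(inputs: dict) -> jnp.ndarray:
    # Illustrative scalar cost over two named parameters
    return jnp.square(inputs["a"] - 1.0) + jnp.square(inputs["b"] + 2.0)

inputs = {"a": 0.5, "b": 0.0}
value, grads = jax.value_and_grad(toy_cost)(inputs)
grad_array = np.asarray(list(grads.values()))  # dict of per-parameter grads -> NumPy array
print(value, grad_array)
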

def _update_solver_sensitivities(self, calculate_grad: bool) -> None:
"""
@@ -60,18 +62,26 @@ def _update_solver_sensitivities(self, calculate_grad: bool) -> None:
Args:
calculate_grad (bool): Whether gradient calculation is required.
"""
model = self.problem.model
if calculate_grad != model.calculate_sensitivities:
model.jaxify_solver(
t_eval=self.problem.domain_data, calculate_sensitivities=calculate_grad
)

self.model.jaxify_solver(
t_eval=self.problem.domain_data, calculate_sensitivities=calculate_grad
)

@staticmethod
def check_sigma0(sigma0):
if not isinstance(sigma0, (int, float)) or sigma0 <= 0:
raise ValueError("sigma0 must be a positive number")
return float(sigma0)

def observed_fisher(self, inputs: Inputs):
"""
Compute the observed Fisher information matrix (FIM)
for the given inputs. This is approximated from the
gradient (score), since the Hessian is not available.
"""
_, grad = self.__call__(inputs, calculate_grad=True)
return jnp.square(grad) / self.n_data
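
(Editorial aside: under the score-based approximation used in observed_fisher above, the returned diagonal entries correspond to

\[
\hat{\mathcal{I}}_{jj}(\theta) \approx \frac{1}{n}\left(\frac{\partial C(\theta)}{\partial \theta_j}\right)^{2},
\]

where n is self.n_data and the partial derivatives are the gradient returned by __call__ with calculate_grad=True; for the likelihood classes below, C is the log-likelihood.)
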



class JaxSumSquaredError(BaseJaxCost):
"""
@@ -83,9 +93,9 @@ def __init__(self, problem: BaseProblem):

def evaluate(self, inputs):
# Calculate residuals and error
y = self.problem.jax_evaluate(inputs)
r = jnp.asarray([y - self._target[signal] for signal in self.signal])
return jnp.sum(r**2, axis=1).item()
y = self.problem.evaluate(inputs)
r = jnp.asarray([y[s] - self._target[s] for s in self.signal])
return jnp.sum(r**2)


class JaxLogNormalLikelihood(BaseJaxCost, BaseLikelihood):
@@ -106,9 +116,7 @@ def __init__(self, problem: BaseProblem, sigma0=0.02):
self.sigma = self.check_sigma0(sigma0)
self.sigma2 = jnp.square(self.sigma)
self._offset = 0.5 * self.n_data * jnp.log(2 * jnp.pi)
self._target_as_array = jnp.asarray(
[self._target[signal] for signal in self.signal]
)
self._target_as_array = jnp.asarray([self._target[s] for s in self.signal])
self._log_target_sum = jnp.sum(jnp.log(self._target_as_array))
self._precompute()

@@ -121,9 +129,38 @@ def evaluate(self, inputs):
"""
Evaluates the log-normal likelihood.
"""
y = self.problem.jax_evaluate(inputs)
e = jnp.log(self._target_as_array) - jnp.log(y)
likelihood = self._constant_term - jnp.sum(jnp.square(e), axis=1) / (
2 * self.sigma2
)
return likelihood.item()
y = self.problem.evaluate(inputs)
e = jnp.asarray([jnp.log(y[s]) - jnp.log(self._target[s]) for s in self.signal])
likelihood = self._constant_term - jnp.sum(jnp.square(e)) / (2 * self.sigma2)
return likelihood
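
(Editorial aside: written out, the quantity evaluated above is the log-normal log-likelihood

\[
\log L(\theta) = -\frac{n}{2}\log\!\left(2\pi\sigma^{2}\right) - \sum_{i=1}^{n}\log d_{i} - \frac{1}{2\sigma^{2}}\sum_{i=1}^{n}\bigl(\log d_{i} - \log y_{i}(\theta)\bigr)^{2},
\]

with d_i the target data and y_i the model output; the first two terms are presumably folded, together with the n log(sigma) contribution, into the precomputed _constant_term in _precompute, which sits outside the shown hunks.)
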


class JaxGaussianLogLikelihoodKnownSigma(BaseJaxCost, BaseLikelihood):
"""
A Jax implementation of the Gaussian log-likelihood function.
This class assumes the underlying observed data are sampled
from a Gaussian distribution with known noise, `sigma0`.

Parameters
----------
problem: BaseProblem
The problem to fit, of type `pybop.BaseProblem`
sigma0: float, optional
The standard deviation of the noise in the measured data
"""

def __init__(self, problem: BaseProblem, sigma0=0.02):
super().__init__(problem)
self.sigma = self.check_sigma0(sigma0)
self.sigma2 = jnp.square(self.sigma)
self._offset = -0.5 * self.n_data * jnp.log(2 * jnp.pi * self.sigma2)
self._multip = -1 / (2.0 * self.sigma2)


def evaluate(self, inputs):
"""
Evaluates the Gaussian log-likelihood with known sigma.
"""
y = self.problem.evaluate(inputs)
e = jnp.asarray([y[s] - self._target[s] for s in self.signal])
likelihood = jnp.sum(self._offset + self._multip * jnp.sum(jnp.square(e)))
return likelihood
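
(Editorial aside: the expression above is the Gaussian log-likelihood with known noise level,

\[
\log L(\theta) = -\frac{n}{2}\log\!\left(2\pi\sigma_0^{2}\right) - \frac{1}{2\sigma_0^{2}}\sum_{i=1}^{n}\bigl(y_{i}(\theta) - d_{i}\bigr)^{2},
\]

which matches the precomputed _offset and _multip terms in __init__ above.)
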

6 changes: 4 additions & 2 deletions pybop/models/base_model.py
@@ -778,13 +778,15 @@ def jaxify_solver(self, t_eval, calculate_sensitivities=False):
self._IDAKLU_stored = self._solver.copy()
self._solver = self._solver.jaxify(
model=self._built_model,
t_eval=t_eval,
t_eval=[t_eval[0], t_eval[-1]],
t_interp=t_eval,
calculate_sensitivities=calculate_sensitivities,
)
elif isinstance(self._solver, pybamm.IDAKLUJax):
self._solver = self._IDAKLU_stored.jaxify(
model=self._built_model,
t_eval=t_eval,
t_eval=[t_eval[0], t_eval[-1]],
t_interp=t_eval,
calculate_sensitivities=calculate_sensitivities,
)
else:
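
(Editorial aside on the hunk above: after this change t_eval only carries the integration window endpoints, while t_interp carries the full grid of output times. A minimal sketch of the same call pattern directly against PyBaMM, assuming a model discretised via pybamm.Simulation(...).build(); treat the exact keyword set as an assumption rather than a guaranteed API.)

import numpy as np
import pybamm

t_interp = np.linspace(0, 600, 600)                # times the data were sampled on
solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, output_variables=["Voltage [V]"])

sim = pybamm.Simulation(pybamm.lithium_ion.SPM(), solver=solver)
sim.build()                                        # produces the discretised model the solver needs

jax_solver = solver.jaxify(
    model=sim.built_model,
    t_eval=[t_interp[0], t_interp[-1]],            # integration window only
    t_interp=t_interp,                             # interpolation/output times
    calculate_sensitivities=False,
)
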
10 changes: 9 additions & 1 deletion pybop/optimisers/base_optimiser.py
@@ -1,11 +1,13 @@
import warnings
from typing import Optional, Union

import jax.numpy as jnp
import numpy as np
from scipy.optimize import OptimizeResult

from pybop import (
BaseCost,
BaseJaxCost,
BaseLikelihood,
DesignCost,
Inputs,
@@ -215,7 +217,7 @@ def log_update(self, x=None, x_best=None, cost=None, cost_best=None):

def convert_to_list(array_like):
"""Helper function to convert input to a list, if necessary."""
if isinstance(array_like, (list, tuple, np.ndarray)):
if isinstance(array_like, (list, tuple, np.ndarray, jnp.ndarray)):
return list(array_like)
elif isinstance(array_like, (int, float)):
return [array_like]
@@ -316,6 +318,7 @@ def __init__(
):
self.x = x
self.cost = cost
self.fisher = None
self.final_cost = (
final_cost if final_cost is not None else self._calculate_final_cost()
)
@@ -332,6 +335,10 @@ def __init__(
self._validate_parameters()
self.check_physical_viability(self.x)

# Calculate Fisher information if the cost is Jax-based
if isinstance(cost, BaseJaxCost):
self.fisher = cost.observed_fisher(self.x)


def _calculate_final_cost(self) -> float:
"""
Calculate the final cost using the cost function and optimised parameters.
@@ -400,6 +407,7 @@ def __str__(self) -> str:
f"OptimisationResult:\n"
f" Initial parameters: {self.x0}\n"
f" Optimised parameters: {self.x}\n"
f" Diagonal Fisher Information entries: {self.fisher}\n"
f" Final cost: {self.final_cost}\n"
f" Optimisation time: {self.time} seconds\n"
f" Number of iterations: {self.n_iterations}\n"
2 changes: 1 addition & 1 deletion pybop/optimisers/base_pints_optimiser.py
@@ -262,7 +262,7 @@ def fun(x):
x=xs,
x_best=self.pints_optimiser.x_best(),
cost=_fs if self.minimising else [-x for x in _fs],
cost_best=fb if self.minimising else -fb,
cost_best=[fb] if self.minimising else [-fb],
)

# Check stopping criteria:
7 changes: 7 additions & 0 deletions pybop/plot/problem.py
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np

from pybop import DesignProblem, FittingProblem
@@ -40,6 +41,12 @@ def quick(problem, problem_inputs: Inputs = None, show=True, **layout_kwargs):
model_output = problem.evaluate(problem_inputs)
target_output = problem.get_target()

# Convert model_output to np if Jax array
if isinstance(model_output[problem.signal[0]], jnp.ndarray):
model_output = {

signal: np.asarray(model_output[signal]) for signal in problem.signal
}

# Create a plot for each output
figure_list = []
for i in problem.signal:
28 changes: 3 additions & 25 deletions pybop/problems/base_problem.py
@@ -1,6 +1,5 @@
from typing import Optional

import jax.numpy as jnp
import numpy as np
from pybamm import IDAKLUSolver

@@ -73,6 +72,9 @@
self._target = None
self.verbose = False
self.failure_output = np.asarray([np.inf])
self.exception = [
"These parameter values are infeasible."
] # TODO: Update to a utility function and add to it on exception creation
if isinstance(self._model, BaseModel):
self.eis = self.model.eis
self.domain = "Frequency [Hz]" if self.eis else "Time [s]"
@@ -148,30 +150,6 @@ def evaluateS1(self, inputs: Inputs):
"""
raise NotImplementedError

def jax_evaluate(
self,
inputs: Inputs,
) -> jnp.ndarray:
"""
Evaluate the model with the given parameters and return the signal
with a Jax model and solver.
Parameters
----------
inputs : Inputs
Parameters for evaluation of the model.
Returns
-------
y : jnp.ndarray
The model output y(t) simulated with given inputs.
"""

y = jnp.squeeze(
self.model.solver.get_vars(self.signal)(self.domain_data, inputs)
)
return y

def get_target(self):
"""
Return the target dataset.
