Skip to content

Commit

Permalink
adds coverage, updates multi_fitting.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyPlanden committed Oct 23, 2024
1 parent 6754c53 commit 878b6a4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 10 deletions.
19 changes: 13 additions & 6 deletions examples/scripts/multi_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.SPM(parameter_set=parameter_set)

# Create initial SOC, experiment objects
init_soc = [{"Initial SoC": 0.8}, {"Initial SoC": 0.6}]
experiment = [
pybop.Experiment([("Discharge at 0.5C for 2 minutes (4 second period)")]),
pybop.Experiment([("Discharge at 1C for 1 minutes (4 second period)")]),
]

# Fitting parameters
parameters = pybop.Parameters(
pybop.Parameter(
Expand All @@ -21,9 +28,8 @@
)

# Generate a dataset and a fitting problem
sigma = 0.001
experiment = pybop.Experiment([("Discharge at 0.5C for 2 minutes (4 second period)")])
values = model.predict(initial_state={"Initial SoC": 0.8}, experiment=experiment)
sigma = 0.002
values = model.predict(initial_state=init_soc[0], experiment=experiment[0])
dataset_1 = pybop.Dataset(
{
"Time [s]": values["Time [s]"].data,
Expand All @@ -36,8 +42,7 @@

# Generate a second dataset and problem
model = model.new_copy()
experiment = pybop.Experiment([("Discharge at 1C for 1 minutes (4 second period)")])
values = model.predict(initial_state={"Initial SoC": 0.8}, experiment=experiment)
values = model.predict(initial_state=init_soc[1], experiment=experiment[1])
dataset_2 = pybop.Dataset(
{
"Time [s]": values["Time [s]"].data,
Expand All @@ -53,9 +58,11 @@

# Generate the cost function and optimisation class
cost = pybop.SumSquaredError(problem)
optim = pybop.IRPropMin(
optim = pybop.CuckooSearch(
cost,
verbose=True,
sigma0=0.05,
max_unchanged_iterations=20,
max_iterations=100,
)

Expand Down
7 changes: 4 additions & 3 deletions pybop/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from pints import ParallelEvaluator

from pybop import LogPosterior, Parameters
from pybop import LogPosterior


class BaseSampler:
Expand Down Expand Up @@ -43,8 +43,9 @@ def __init__(
self.parameters = log_pdf[0].parameters
self.n_parameters = log_pdf[0].n_parameters
else:
self.parameters = Parameters()
self.n_parameters = 0
raise ValueError(
"log_pdf must be a LogPosterior or List[LogPosterior]"
) # TODO: Update for more general sampling

# Check initial conditions
if x0 is not None and len(x0) != self.n_parameters:
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_monte_carlo_thevenin.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def map_estimate(self, posterior):
common_args = {
"max_iterations": 100,
"max_unchanged_iterations": 35,
"absolute_tolerance": 1e-7,
"sigma0": [3e-4, 3e-4],
}
optim = pybop.CMAES(posterior, **common_args)
Expand All @@ -123,7 +124,7 @@ def map_estimate(self, posterior):
)
@pytest.mark.integration
def test_sampling_thevenin(self, sampler, posterior, map_estimate):
x0 = np.clip(map_estimate + np.random.normal(0, 1e-3, size=2), 1e-3, 1e-1)
x0 = np.clip(map_estimate + np.random.normal(0, 1e-3, size=2), 1e-4, 1e-1)
common_args = {
"log_pdf": posterior,
"chains": 1,
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
import re
from unittest.mock import call, patch

import numpy as np
Expand Down Expand Up @@ -396,6 +397,12 @@ def test_base_sampler(self, log_posterior, x0):
with pytest.raises(NotImplementedError):
sampler.run()

with pytest.raises(
ValueError,
match=re.escape("log_pdf must be a LogPosterior or List[LogPosterior]"),
):
pybop.BaseSampler(pybop.WeightedCost(log_posterior), x0, chains=1, cov0=0.1)

@pytest.mark.unit
def test_MCMC_sampler(self, log_posterior, x0, chains):
with pytest.raises(TypeError):
Expand Down

0 comments on commit 878b6a4

Please sign in to comment.