Skip to content

Commit

Permalink
moves x0 checks, construction to BaseSampler, updates tests, fixes tr…
Browse files Browse the repository at this point in the history
…ansformations in samplers
  • Loading branch information
BradyPlanden committed Oct 23, 2024
1 parent 44a03f2 commit 6754c53
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 41 deletions.
31 changes: 8 additions & 23 deletions pybop/samplers/base_pints_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
cov0: Initial standard deviation for the chains.
kwargs: Additional keyword arguments.
"""
super().__init__(log_pdf, x0, cov0)
super().__init__(log_pdf, x0, chains, cov0)

# Set kwargs
self._max_iterations = kwargs.get("max_iterations", 500)
Expand All @@ -61,11 +61,6 @@ def __init__(
self._iteration = 0
self._loop_iters = 0
self._warm_up = warm_up
self.n_parameters = (
self._log_pdf[0].n_parameters
if isinstance(self._log_pdf, list)
else self._log_pdf.n_parameters
)

# Check log_pdf
if isinstance(self._log_pdf, BaseCost):
Expand All @@ -85,17 +80,6 @@ def __init__(

self._multi_log_pdf = True

# Number of chains
self._n_chains = chains
if self._n_chains < 1:
raise ValueError("Number of chains must be greater than 0")

# Check initial conditions
if self._x0.size != self.n_parameters:
raise ValueError("x0 must have the same number of parameters as log_pdf")
if len(self._x0) != self._n_chains or len(self._x0) == 1:
self._x0 = np.tile(self._x0, (self._n_chains, 1))

# Single chain vs multiple chain samplers
self._single_chain = issubclass(self.sampler, SingleChainMCMC)

Expand Down Expand Up @@ -283,13 +267,14 @@ def _check_stopping_criteria(self):
raise ValueError("At least one stopping criterion must be set.")

def _create_evaluator(self):
f = self._log_pdf
# Check for sensitivities from sampler and set evaluator
common_args = {"apply_transform": True}

if self._needs_sensitivities:
if not self._multi_log_pdf:
f = partial(f, calculate_grad=True)
else:
f = [partial(pdf, calculate_grad=True) for pdf in f]
common_args["calculate_grad"] = True
if not self._multi_log_pdf:
f = partial(self._log_pdf, **common_args)
else:
f = [partial(pdf, **common_args) for pdf in self._log_pdf]

if self._parallel:
if not self._multi_log_pdf:
Expand Down
40 changes: 30 additions & 10 deletions pybop/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,50 @@ class BaseSampler:
Base class for Monte Carlo samplers.
"""

def __init__(self, log_pdf: LogPosterior, x0, cov0: Union[np.ndarray, float]):
def __init__(
self, log_pdf: LogPosterior, x0, chains: int, cov0: Union[np.ndarray, float]
):
"""
Initialise the base sampler.
Parameters
----------------
log_pdf (pybop.LogPosterior or List[pybop.LogPosterior]): The posterior or PDF to be sampled.
chains (int): Number of chains to be used.
x0: List-like initial condition for Monte Carlo sampling.
cov0: The covariance matrix to be sampled.
"""
self._log_pdf = log_pdf
self._cov0 = cov0

# Number of chains
self._n_chains = chains
if self._n_chains < 1:
raise ValueError("Number of chains must be greater than 0")

# Set up parameters based on log_pdf
self.parameters = (
log_pdf.parameters if isinstance(log_pdf, LogPosterior) else Parameters()
)
if isinstance(log_pdf, LogPosterior):
self.parameters = log_pdf.parameters
self.n_parameters = log_pdf.n_parameters
elif isinstance(log_pdf, (list, np.ndarray)) and isinstance(
log_pdf[0], LogPosterior
):
self.parameters = log_pdf[0].parameters
self.n_parameters = log_pdf[0].n_parameters
else:
self.parameters = Parameters()
self.n_parameters = 0

# Initialize x0
self._x0 = (
self.parameters.initial_value()
if x0 is None
else np.asarray([x0], dtype=float)
)
# Check initial conditions
if x0 is not None and len(x0) != self.n_parameters:
raise ValueError("x0 must have the same number of parameters as log_pdf")

# Set initial values, if x0 is None, initial values are unmodified.
self.parameters.update(initial_values=x0 if x0 is not None else None)
self._x0 = self.parameters.reset_initial_value(apply_transform=True)

if len(self._x0) != self._n_chains or len(self._x0) == 1:
self._x0 = np.tile(self._x0, (self._n_chains, 1))

def run(self) -> np.ndarray:
"""
Expand Down
20 changes: 13 additions & 7 deletions tests/integration/test_monte_carlo_thevenin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setup(self):
self.sigma0 = 1e-3
self.ground_truth = np.clip(
np.asarray([0.05, 0.05]) + np.random.normal(loc=0.0, scale=0.01, size=2),
a_min=0.0,
a_min=1e-4,
a_max=0.1,
)
self.fast_samplers = [
Expand Down Expand Up @@ -57,10 +57,14 @@ def model(self):
def parameters(self):
return pybop.Parameters(
pybop.Parameter(
"R0 [Ohm]", prior=pybop.Uniform(1e-2, 8e-2), bounds=[1e-2, 8e-2]
"R0 [Ohm]",
prior=pybop.Uniform(1e-2, 8e-2),
bounds=[1e-4, 1e-1],
),
pybop.Parameter(
"R1 [Ohm]", prior=pybop.Uniform(1e-2, 8e-2), bounds=[1e-2, 8e-2]
"R1 [Ohm]",
prior=pybop.Uniform(1e-2, 8e-2),
bounds=[1e-4, 1e-1],
),
)

Expand Down Expand Up @@ -123,8 +127,8 @@ def test_sampling_thevenin(self, sampler, posterior, map_estimate):
common_args = {
"log_pdf": posterior,
"chains": 1,
"warm_up": 450,
"cov0": [1e-3, 1e-3] if sampler in self.fast_samplers else [0.1, 0.1],
"warm_up": 550,
"cov0": [1e-3, 1e-3],
"max_iterations": 1000,
"x0": x0,
}
Expand All @@ -140,7 +144,9 @@ def test_sampling_thevenin(self, sampler, posterior, map_estimate):
ess = summary.effective_sample_size()
np.testing.assert_array_less(0, ess)
if not isinstance(sampler, RelativisticMCMC):
np.testing.assert_array_less(summary.rhat(), 1.05)
np.testing.assert_array_less(
summary.rhat(), 1.2
) # Large rhat, to enable faster tests

# Assert both final sample and posterior mean
x = np.mean(chains, axis=1)
Expand All @@ -152,7 +158,7 @@ def get_data(self, model, init_soc):
initial_state = {"Initial SoC": init_soc}
experiment = pybop.Experiment(
[
("Discharge at 0.5C for 4 minutes (8 second period)",),
("Discharge at 0.5C for 6 minutes (20 second period)",),
]
)
sim = model.predict(initial_state=initial_state, experiment=experiment)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_set_parallel(self, log_posterior, x0, chains):

@pytest.mark.unit
def test_base_sampler(self, log_posterior, x0):
sampler = pybop.BaseSampler(log_posterior, x0, cov0=0.1)
sampler = pybop.BaseSampler(log_posterior, x0, chains=1, cov0=0.1)
with pytest.raises(NotImplementedError):
sampler.run()

Expand Down

0 comments on commit 6754c53

Please sign in to comment.