Variational inference with PyMC (#1306)
* variational inference fit
* remove variational from sample
* make pymc object accessible
* save as McmcPtResult
* tests added
* add warning in write_result()
Showing 6 changed files with 420 additions and 0 deletions.
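For orientation, a minimal usage sketch of the new `variational_fit` entry point; the toy objective, bounds, and iteration counts below are illustrative and not part of this commit:

    import numpy as np
    import pypesto
    import pypesto.optimize as optimize
    from pypesto.variational import variational_fit

    # toy problem: quadratic objective with analytic gradient
    objective = pypesto.Objective(fun=lambda x: np.sum(x**2), grad=lambda x: 2 * x)
    problem = pypesto.Problem(objective=objective, lb=[-5, -5], ub=[5, 5])

    # optimize first so variational_fit can start from the best point found
    result = optimize.minimize(problem=problem, n_starts=5)

    # fit a mean-field (ADVI) approximation, then draw samples from it
    result = variational_fit(
        problem=problem,
        n_iterations=10_000,
        method="advi",
        n_samples=1_000,
        result=result,
    )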
@@ -0,0 +1,9 @@
"""
Variational inference
=====================

Find the best variational approximation in a given family to a distribution from which we can sample.
"""

from .pymc import PymcVariational
from .variational_inference import variational_fit
@@ -0,0 +1,196 @@
"""Pymc v4 Sampler for Variational Inference."""

import logging
from typing import Optional

import numpy as np
import pytensor.tensor as pt
from scipy import stats

from ..objective import FD
from ..result import McmcPtResult
from ..sample.pymc import PymcObjectiveOp, PymcSampler
from ..sample.sampler import SamplerImportError

logger = logging.getLogger(__name__)


# implementation based on the pymc sampler code in pypesto and:
# https://www.pymc.io/projects/examples/en/latest/variational_inference/variational_api_quickstart.html

class PymcVariational(PymcSampler):
    """Wrapper around Pymc v4 variational inference.

    Parameters
    ----------
    **kwargs:
        Options are directly passed on to `pymc.fit`.
    """

    def fit(
        self,
        n_iterations: int,
        method: str = "advi",
        random_seed: Optional[int] = None,
        start_sigma: Optional[dict[str, np.ndarray]] = None,
        inf_kwargs: Optional[dict] = None,
        beta: float = 1.0,
        **kwargs,
    ):
        """
        Fit the variational approximation to the problem.

        Parameters
        ----------
        n_iterations:
            Number of iterations.
        method: str or :class:`Inference` of pymc
            The string name is case-insensitive and one of:

            - 'advi' for ADVI
            - 'fullrank_advi' for FullRankADVI
            - 'svgd' for Stein Variational Gradient Descent
            - 'asvgd' for Amortized Stein Variational Gradient Descent
        random_seed: int
            Random seed for reproducibility.
        start_sigma: dict[str, np.ndarray]
            Starting standard deviation for inference; only available for
            method 'advi'.
        inf_kwargs: dict
            Additional kwargs passed to pymc.Inference.
        beta:
            Inverse temperature (e.g. in parallel tempering).
        """
        try:
            import pymc
        except ImportError:
            raise SamplerImportError("pymc") from None

        problem = self.problem
        if not problem.objective.has_grad:
            logger.info(
                "The objective function does not provide gradients. "
                "Finite differences will be used."
            )
            problem.objective = FD(obj=problem.objective)
        log_post = PymcObjectiveOp.create_instance(problem.objective, beta)

        x0 = None
        x_names_free = problem.get_reduced_vector(problem.x_names)
        if self.x0 is not None:
            x0 = {
                x_name: val
                for x_name, val in zip(problem.x_names, self.x0)
                if x_name in x_names_free
            }

        # create model context
        with pymc.Model():
            # parameter bounds as uniform prior
            _k = [
                pymc.Uniform(x_name, lower=lb, upper=ub)
                for x_name, lb, ub in zip(
                    x_names_free,
                    problem.lb,
                    problem.ub,
                )
            ]

            # convert parameters to a PyTensor tensor variable
            theta = pt.as_tensor_variable(_k)

            # define the distribution with the log-posterior as density
            pymc.Potential("potential", log_post(theta))

            # record function values
            pymc.Deterministic("loggyposty", log_post(theta))

            # perform the actual fit
            data = pymc.fit(
                n=int(n_iterations),
                method=method,
                random_seed=random_seed,
                start=x0,
                start_sigma=start_sigma,
                inf_kwargs=inf_kwargs,
                **kwargs,
            )

        self.data = data

    def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult:
        """
        Sample from the variational approximation and return an McmcPtResult object.

        Parameters
        ----------
        n_samples:
            Number of samples to be computed.
        beta:
            Inverse temperature (not used here; kept for interface
            compatibility).
        """
        # get the InferenceData object
        pymc_data = self.data.sample(n_samples)
        x_names_free = self.problem.get_reduced_vector(self.problem.x_names)
        post_samples = np.concatenate(
            [pymc_data.posterior[name].values for name in x_names_free]
        ).T
        return McmcPtResult(
            trace_x=post_samples[np.newaxis, :],
            trace_neglogpost=pymc_data.posterior.loggyposty.values,
            trace_neglogprior=np.full(
                pymc_data.posterior.loggyposty.values.shape, np.nan
            ),
            betas=np.array([1.0] * post_samples.shape[0]),
            burn_in=0,
            auto_correlation=0,
            effective_sample_size=n_samples,
            message="variational inference results",
        )

    def get_variational_parameters(self) -> tuple[list, list]:
        """Get the internal pymc variational parameters."""
        return (
            [param.name for param in self.data.params],
            [param.eval() for param in self.data.params],
        )

    def set_variational_parameters(self, param_list: list):
        """
        Set the internal pymc variational parameters.

        Parameters
        ----------
        param_list:
            List of parameter values, in the order returned by
            `get_variational_parameters`.
        """
        if len(param_list) != len(self.data.params):
            raise ValueError(
                "The number of parameters does not match the number of "
                "variational parameters."
            )
        for i, param in enumerate(param_list):
            self.data.params[i].set_value(param)

    def eval_variational_log_density(self, x: np.ndarray) -> np.ndarray:
        """
        Evaluate the log density of the variational approximation at the given points.

        Parameters
        ----------
        x:
            The points at which to evaluate the log density.
        """
        # TODO: add support for other methods
        logger.warning(
            "currently only supports the methods `advi` and `fullrank_advi`"
        )

        if x.ndim == 1:
            x = x.reshape(1, -1)
        # one log-density value per point (not per coordinate)
        log_density_at_points = np.zeros(x.shape[0])
        for i, point in enumerate(x):
            log_density_at_points[i] = stats.multivariate_normal.logpdf(
                point, mean=self.data.mean.eval(), cov=self.data.cov.eval()
            )
        return log_density_at_points
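For illustration, a minimal sketch of driving the class above directly, assuming `problem` is a pypesto Problem as in the usage sketch further up (the method choice and numbers are illustrative):

    import numpy as np
    from pypesto.variational import PymcVariational

    vi = PymcVariational()
    vi.initialize(problem=problem, x0=None)
    vi.fit(n_iterations=10_000, method="fullrank_advi")

    # draw from the fitted approximation; returned as an McmcPtResult
    sample_result = vi.sample(n_samples=1_000)

    # inspect the learned variational parameters (names and current values)
    names, values = vi.get_variational_parameters()

    # evaluate the Gaussian log density of the approximation at a point
    log_q = vi.eval_variational_log_density(np.array([0.1, -0.2]))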
@@ -0,0 +1,136 @@
"""Functions for variational inference accessible to the user. Currently only pymc is supported."""

import logging
from time import process_time
from typing import Callable, List, Optional, Union

import numpy as np

from ..problem import Problem
from ..result import Result
from ..sample.util import bound_n_samples_from_env
from ..store import autosave
from .pymc import PymcVariational

logger = logging.getLogger(__name__)


def variational_fit(
    problem: Problem,
    n_iterations: int,
    method: str = "advi",
    n_samples: Optional[int] = None,
    random_seed: Optional[int] = None,
    start_sigma: Optional[dict[str, np.ndarray]] = None,
    x0: Union[np.ndarray, List[np.ndarray]] = None,
    result: Result = None,
    filename: Union[str, Callable, None] = None,
    overwrite: bool = False,
    **kwargs,
) -> Result:
    """
    Fit a variational approximation to the problem and sample from it.

    Parameters
    ----------
    problem:
        The problem to be solved.
    n_iterations:
        Number of iterations for the optimization.
    method: str or :class:`Inference` of pymc (the only backend currently supported)
        The string name is case-insensitive and one of:

        - 'advi' for ADVI
        - 'fullrank_advi' for FullRankADVI
        - 'svgd' for Stein Variational Gradient Descent
        - 'asvgd' for Amortized Stein Variational Gradient Descent
    n_samples:
        Number of samples to generate after the optimization.
    random_seed: int
        Random seed for reproducibility.
    start_sigma: dict[str, np.ndarray]
        Starting standard deviation for inference; only available for
        method 'advi'.
    x0:
        Initial parameter for the variational optimization. If None, the best
        parameter found in optimization is used.
    result:
        A result to write to. If None is provided, one is created from the
        problem.
    filename:
        Name of the hdf5 file where the result will be saved. Default is
        None, which deactivates automatic saving. If set to "Auto", it will
        automatically generate a file named
        `year_month_day_sampling_result.hdf5`.
        Optionally a method, see docs for `pypesto.store.auto.autosave`.
    overwrite:
        Whether to overwrite `result/sampling` in the autosave file
        if it already exists.

    Returns
    -------
    result:
        A result with a filled-in sample_result part.
    """
    # prepare result object
    if result is None:
        result = Result(problem)

    # bound the number of iterations (e.g. via environment variable)
    if n_iterations is not None:
        n_iterations = bound_n_samples_from_env(n_iterations)

    # try to find initial parameters
    if x0 is None:
        result.optimize_result.sort()
        if len(result.optimize_result.list) > 0:
            x0 = problem.get_reduced_vector(
                result.optimize_result.list[0]["x"]
            )

    # set up variational inference
    # currently we only support pymc
    variational = PymcVariational()

    # initialize the sampler to the problem
    variational.initialize(problem=problem, x0=x0)

    # perform the fit and track time
    t_start = process_time()
    variational.fit(
        n_iterations=n_iterations,
        method=method,
        random_seed=random_seed,
        start_sigma=start_sigma,
        **kwargs,
    )
    t_elapsed = process_time() - t_start
    logger.info("Elapsed time: " + str(t_elapsed))

    # extract results and save samples to the pypesto result
    if n_samples is None or n_samples == 0:
        # construct an McmcPtResult object with a nearly empty trace_x
        n_samples = 1

    result.sample_result = variational.sample(n_samples)
    result.sample_result.time = t_elapsed

    autosave(
        filename=filename,
        result=result,
        store_type="sample",
        overwrite=overwrite,
    )

    # make the pymc object available in the result
    # TODO: if needed, we can add a result object for variational inference methods
    result.variational_result = variational
    (
        result.sample_result.variational_parameters_names,
        result.sample_result.variational_parameters,
    ) = variational.get_variational_parameters()
    if filename is not None:
        logger.warning(
            "Variational parameters are not saved in the hdf5 file. "
            "You have to save them manually."
        )

    return result
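Since the warning above notes that the variational parameters are not written to the hdf5 file, one possible way to persist them manually, e.g. with numpy (the file name is illustrative, not part of this commit):

    import numpy as np

    names = result.sample_result.variational_parameters_names
    values = result.sample_result.variational_parameters

    # save name/value pairs alongside the autosaved hdf5 file
    np.savez("variational_parameters.npz", **dict(zip(names, values)))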
@@ -0,0 +1 @@
"""Variational inference tests."""