
Variational inference with PyMC #1306

Merged · 42 commits · May 21, 2024
Changes from 30 commits

Commits (42)
9eaf164
variational inference prototype
arrjon Feb 14, 2024
abb17a0
variational inference fit
arrjon Feb 14, 2024
7a4ba7a
added kwargs
arrjon Feb 14, 2024
df3922f
remove variational from sample
arrjon Mar 12, 2024
66bb105
save pymc object
arrjon Mar 12, 2024
6b8e9ba
save in sample result
arrjon Mar 12, 2024
2e62e0e
add time to result
arrjon Mar 12, 2024
b08f5a7
make pymc object accessible
arrjon Mar 12, 2024
b4ad7e3
save as McmcPtResult
arrjon Mar 12, 2024
d3de43f
save as McmcPtResult
arrjon Mar 12, 2024
0bc8fa8
Merge remote-tracking branch 'origin/variational_inference' into vari…
arrjon Mar 12, 2024
f580145
fix dim of trace
arrjon Mar 12, 2024
e851977
fix dim of trace
arrjon Mar 12, 2024
4fee2e9
Merge remote-tracking branch 'origin/variational_inference' into vari…
arrjon Mar 12, 2024
023a4e6
fix dim of trace
arrjon Mar 12, 2024
18d5115
add variational parameters to result
arrjon Mar 12, 2024
ff89a94
fix import
arrjon Mar 12, 2024
b89f0a8
save results
arrjon Mar 12, 2024
6b983ce
Merge branch 'develop' into variational_inference
arrjon Mar 14, 2024
3eee507
Merge branch 'develop' into variational_inference
arrjon Mar 14, 2024
1cfc03d
Merge branch 'develop' into variational_inference
arrjon Mar 15, 2024
7cd991d
Merge branch 'develop' into variational_inference
arrjon Mar 15, 2024
2d925e4
Merge branch 'develop' into variational_inference
arrjon Mar 21, 2024
b97c80b
Merge branch 'develop' into variational_inference
arrjon Apr 4, 2024
fef7cd3
tests added
arrjon Apr 5, 2024
fc97b00
Merge branch 'develop' into variational_inference
arrjon Apr 5, 2024
a08e6da
Merge branch 'develop' into variational_inference
arrjon Apr 8, 2024
417148f
run pre-commit hooks
arrjon Apr 8, 2024
8bcbca8
Merge branch 'develop' into variational_inference
arrjon Apr 8, 2024
87a611f
Merge branch 'develop' into variational_inference
arrjon Apr 11, 2024
8378236
add docstrings
arrjon Apr 12, 2024
db46ba0
add warning in write_result()
arrjon Apr 12, 2024
4251931
Merge remote-tracking branch 'origin/variational_inference' into vari…
arrjon Apr 12, 2024
c8f6a2a
Merge branch 'develop' into variational_inference
arrjon Apr 12, 2024
2aed748
Merge branch 'develop' into variational_inference
arrjon Apr 17, 2024
169b808
Merge branch 'develop' into variational_inference
arrjon Apr 18, 2024
2ebad53
Merge branch 'develop' into variational_inference
arrjon Apr 23, 2024
2b0c462
Merge branch 'develop' into variational_inference
arrjon May 6, 2024
7e7f0db
Merge branch 'develop' into variational_inference
arrjon May 15, 2024
98f7258
remove variational problem
arrjon May 15, 2024
a9efc14
Merge branch 'develop' into variational_inference
arrjon May 16, 2024
bdd8873
Merge branch 'develop' into variational_inference
vwiela May 16, 2024
9 changes: 9 additions & 0 deletions pypesto/variational/__init__.py
@@ -0,0 +1,9 @@
"""
Variational inference
======

Find the best variational approximation in a given family to a distribution from which we can sample.
"""

from .pymc import PymcVariational
from .variational_inference import variational_fit
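As background (standard variational inference theory, not text from this PR): "best" here means maximizing the evidence lower bound (ELBO), which is equivalent to minimizing the KL divergence from the approximation $q$ to the posterior:

$$\mathcal{L}(q) = \mathbb{E}_{q(\theta)}\left[\log p(y, \theta) - \log q(\theta)\right] = \log p(y) - \mathrm{KL}\left(q(\theta) \,\|\, p(\theta \mid y)\right).$$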
190 changes: 190 additions & 0 deletions pypesto/variational/pymc.py
@@ -0,0 +1,190 @@
import logging
from typing import Optional

import numpy as np
import pytensor.tensor as pt
from scipy import stats

from ..objective import FD
from ..result import McmcPtResult
from ..sample.pymc import PymcObjectiveOp, PymcSampler
from ..sample.sampler import SamplerImportError

logger = logging.getLogger(__name__)


# implementation based on the pymc sampler code in pypesto and:
# https://www.pymc.io/projects/examples/en/latest/variational_inference/variational_api_quickstart.html
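# For orientation, the raw PyMC pattern that this wrapper automates looks
# roughly like the following (a hedged sketch after the quickstart linked
# above, not code from this PR):
#
#     with pymc.Model():
#         x = pymc.Normal("x", 0.0, 1.0)
#         approx = pymc.fit(n=10_000, method="advi")  # returns an Approximation
#         idata = approx.sample(1000)  # draw samples from the fitted approximation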


class PymcVariational(PymcSampler):
    """Wrapper around PyMC v4 variational inference.

    Parameters
    ----------
    step_function:
        A PyMC step function, e.g. NUTS or Slice. If not specified, PyMC
        determines one automatically (preferred).
    **kwargs:
        Options are directly passed on to `pymc.fit`.
    """

    def fit(
        self,
        n_iterations: int,
        method: str = "advi",
        random_seed: Optional[int] = None,
        start_sigma: Optional[dict[str, np.ndarray]] = None,
        inf_kwargs: Optional[dict] = None,
        beta: float = 1.0,
        **kwargs,
    ):
        """
        Fit the variational approximation to the problem.

        Parameters
        ----------
        n_iterations:
            Number of iterations.
        method: str or :class:`Inference` of pymc
            String name, case-insensitive, one of:
            - 'advi' for ADVI
            - 'fullrank_advi' for FullRankADVI
            - 'svgd' for Stein Variational Gradient Descent
            - 'asvgd' for Amortized Stein Variational Gradient Descent
        random_seed: int
            Random seed for reproducibility.
        start_sigma: dict[str, np.ndarray]
            Starting standard deviation for inference; only available for
            method 'advi'.
        inf_kwargs: dict
            Additional kwargs passed to `pymc.Inference`.
        beta:
            Inverse temperature (e.g. in parallel tempering).
        """
        try:
            import pymc
        except ImportError:
            raise SamplerImportError("pymc") from None

        problem = self.problem
        if not problem.objective.has_grad:
            logger.info(
                "The objective function does not provide gradients. "
                "Finite differences will be used."
            )
            problem.objective = FD(obj=problem.objective)
        log_post = PymcObjectiveOp.create_instance(problem.objective, beta)

        x0 = None
        x_names_free = problem.get_reduced_vector(problem.x_names)
        if self.x0 is not None:
            x0 = {
                x_name: val
                for x_name, val in zip(problem.x_names, self.x0)
                if x_name in x_names_free
            }

        # create model context
        with pymc.Model():
            # parameter bounds as uniform prior
            _k = [
                pymc.Uniform(x_name, lower=lb, upper=ub)
                for x_name, lb, ub in zip(
                    x_names_free,
                    problem.lb,
                    problem.ub,
                )
            ]

            # convert parameters to a PyTensor tensor variable
            theta = pt.as_tensor_variable(_k)

            # define distribution with log-posterior as density
            pymc.Potential("potential", log_post(theta))

            # record function values
            pymc.Deterministic("loggyposty", log_post(theta))

            # perform the actual fit
            data = pymc.fit(
                n=int(n_iterations),
                method=method,
                random_seed=random_seed,
                start=x0,
                start_sigma=start_sigma,
                inf_kwargs=inf_kwargs,
                **kwargs,
            )

        self.data = data

    def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult:
        """
        Sample from the variational approximation and return a McmcPtResult.

        Parameters
        ----------
        n_samples:
            Number of samples to be computed.
        beta:
            Inverse temperature (e.g. in parallel tempering).
        """
        # get InferenceData object
        pymc_data = self.data.sample(n_samples)
        x_names_free = self.problem.get_reduced_vector(self.problem.x_names)
        post_samples = np.concatenate(
            [pymc_data.posterior[name].values for name in x_names_free]
        ).T
        return McmcPtResult(
            trace_x=post_samples[np.newaxis, :],
            trace_neglogpost=pymc_data.posterior.loggyposty.values,
            trace_neglogprior=np.full(
                pymc_data.posterior.loggyposty.values.shape, np.nan
            ),
            betas=np.array([1.0] * post_samples.shape[0]),
            burn_in=0,
            auto_correlation=0,
            effective_sample_size=n_samples,
            message="variational inference results",
        )

    def get_variational_parameters(self) -> tuple[list, list]:
        """Get the internal pymc variational parameter names and values."""
        return (
            [param.name for param in self.data.params],
            [param.eval() for param in self.data.params],
        )

    def set_variational_parameters(self, param_list: list):
        """
        Set the internal pymc variational parameters.

        Parameters
        ----------
        param_list:
            List of parameter values, in the order returned by
            :meth:`get_variational_parameters`.
        """
        for i, param in enumerate(param_list):
            self.data.params[i].set_value(param)

    def eval_variational_log_density(self, x: np.ndarray) -> np.ndarray:
        """
        Evaluate the log density of the variational approximation at x.

        Parameters
        ----------
        x:
            The points at which to evaluate the log density.
        """
        # TODO: add support for other methods
        logger.warning(
            "currently only supports the methods `advi` and `fullrank_advi`"
        )

        if x.ndim == 1:
            x = x.reshape(1, -1)
        # for (full-rank) ADVI the approximation is a multivariate normal;
        # evaluate its joint log density at each point (one value per row)
        vi_log_density = stats.multivariate_normal.logpdf(
            x, mean=self.data.mean.eval(), cov=self.data.cov.eval()
        )
        return vi_log_density
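For reference, a minimal sketch of driving `PymcVariational` directly; the toy objective, bounds, and iteration counts are illustrative assumptions, not part of the PR (`initialize` is inherited from `PymcSampler`):

    import numpy as np
    import pypesto
    from pypesto.variational import PymcVariational

    # toy negative log-posterior (illustrative); gradients are added via FD in fit()
    objective = pypesto.Objective(fun=lambda x: 0.5 * np.sum(x**2))
    problem = pypesto.Problem(
        objective=objective, lb=-5 * np.ones(2), ub=5 * np.ones(2)
    )

    variational = PymcVariational()
    variational.initialize(problem=problem, x0=np.zeros(2))
    variational.fit(n_iterations=10_000, method="advi")

    mcmc_result = variational.sample(n_samples=1000)  # McmcPtResult
    names, values = variational.get_variational_parameters()
    log_q = variational.eval_variational_log_density(np.zeros((1, 2)))

For `advi` and `fullrank_advi` the fitted approximation is Gaussian, which is why `eval_variational_log_density` can evaluate it as a multivariate normal with the approximation's mean and covariance.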
134 changes: 134 additions & 0 deletions pypesto/variational/variational_inference.py
@@ -0,0 +1,134 @@
import logging
from time import process_time
from typing import Callable, List, Optional, Union

import numpy as np

from ..problem import Problem
from ..result import Result
from ..sample.util import bound_n_samples_from_env
from ..store import autosave
from .pymc import PymcVariational

logger = logging.getLogger(__name__)


def variational_fit(
    problem: Problem,
    n_iterations: int,
    method: str = "advi",
    n_samples: Optional[int] = None,
    random_seed: Optional[int] = None,
    start_sigma: Optional[dict[str, np.ndarray]] = None,
    x0: Union[np.ndarray, List[np.ndarray]] = None,
    result: Result = None,
    filename: Union[str, Callable, None] = None,
    overwrite: bool = False,
    **kwargs,
) -> Result:
"""
Call to do parameter sampling.

Parameters
----------
problem:
The problem to be solved. If None is provided, a
:class:`pypesto.AdaptiveMetropolisSampler` is used.
n_iterations:
Number of iterations for the optimization.
method: str or :class:`Inference` of pymc (only interface currently supported)
string name is case-insensitive in:
- 'advi' for ADVI
- 'fullrank_advi' for FullRankADVI
- 'svgd' for Stein Variational Gradient Descent
- 'asvgd' for Amortized Stein Variational Gradient Descent
n_samples:
Number of samples to generate after optimization.
random_seed: int
random seed for reproducibility
start_sigma: `dict[str, np.ndarray]`
starting standard deviation for inference, only available for method 'advi'
x0:
Initial parameter for the variational optimization. If None, the best parameter
found in optimization is used.
result:
A result to write to. If None provided, one is created from the
problem.
filename:
Name of the hdf5 file, where the result will be saved. Default is
None, which deactivates automatic saving. If set to
"Auto" it will automatically generate a file named
`year_month_day_profiling_result.hdf5`.
Optionally a method, see docs for `pypesto.store.auto.autosave`.
overwrite:
Whether to overwrite `result/sampling` in the autosave file
if it already exists.

Returns
-------
result:
A result with filled in sample_options part.
"""
    # prepare result object
    if result is None:
        result = Result(problem)

    # bound number of iterations from the environment
    if n_iterations is not None:
        n_iterations = bound_n_samples_from_env(n_iterations)

    # try to find initial parameters
    if x0 is None:
        result.optimize_result.sort()
        if len(result.optimize_result.list) > 0:
            x0 = problem.get_reduced_vector(
                result.optimize_result.list[0]["x"]
            )

    # set up variational inference
    # currently we only support pymc
    variational = PymcVariational()

    # initialize sampler to problem
    variational.initialize(problem=problem, x0=x0)

    # perform the fit and track time
    t_start = process_time()
    variational.fit(
        n_iterations=n_iterations,
        method=method,
        random_seed=random_seed,
        start_sigma=start_sigma,
        **kwargs,
    )
    t_elapsed = process_time() - t_start
    logger.info(f"Elapsed time: {t_elapsed}")

    # extract results and save samples to pypesto result
    if n_samples is None or n_samples == 0:
        # construct a McmcPtResult object with nearly empty trace_x
        n_samples = 1

    result.sample_result = variational.sample(n_samples)
    result.sample_result.time = t_elapsed

    autosave(
        filename=filename,
        result=result,
        store_type="sample",
        overwrite=overwrite,
    )

    # make pymc object available in result
    # TODO: if needed, we can add a result object for variational inference methods
    result.variational_result = variational
    (
        result.sample_result.variational_parameters_names,
        result.sample_result.variational_parameters,
    ) = variational.get_variational_parameters()
    if filename is not None:
        logger.warning(
            "Variational parameters are not saved in the hdf5 file. "
            "You have to save them manually."
        )

    return result
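A hedged end-to-end sketch of the intended workflow (assumes an existing pypesto `problem`; the optimizer settings and counts are illustrative): run a multistart optimization first, so its best parameter vector seeds the variational fit, then call `variational_fit`:

    import pypesto.optimize as optimize
    from pypesto.variational import variational_fit

    # multistart optimization; x0 defaults to the best found parameter vector
    result = optimize.minimize(problem=problem, n_starts=10)

    result = variational_fit(
        problem=problem,
        n_iterations=50_000,
        method="advi",
        n_samples=1000,
        result=result,
    )

    samples = result.sample_result.trace_x  # shape (1, n_samples, n_par)
    pymc_vi = result.variational_result     # the PymcVariational wrapper

Note that if `n_samples` is None or 0, a single sample is still drawn so that the resulting McmcPtResult has a nearly empty but well-formed trace.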
1 change: 1 addition & 0 deletions test/variational/__init__.py
@@ -0,0 +1 @@
"""Variational inference tests."""