-
Notifications
You must be signed in to change notification settings - Fork 918
/
black_box_svi.py
84 lines (65 loc) · 2.97 KB
/
black_box_svi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import matplotlib.pyplot as plt
import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.multivariate_normal as mvn
import autograd.scipy.stats.norm as norm
from autograd import grad
from autograd.misc.optimizers import adam
def black_box_variational_inference(logprob, D, num_samples):
    """Build a stochastic variational objective for a diagonal Gaussian.

    Follows http://arxiv.org/abs/1401.0118 and draws its Monte Carlo samples
    as ``noise * std + mean`` (the reparameterization referenced from
    http://arxiv.org/abs/1506.02557), so the estimate is differentiable with
    respect to the variational parameters.

    Args:
        logprob: callable ``logprob(samples, t)``; it is handed an array of
            shape ``(num_samples, D)`` plus the iteration argument ``t`` and
            its result is averaged with ``np.mean``.
        D: dimensionality of the variational Gaussian. The flat parameter
            vector holds ``D`` means followed by the log standard deviations.
        num_samples: number of samples drawn per objective evaluation.

    Returns:
        The tuple ``(variational_objective, gradient, unpack_params)``.
    """
    # A single generator, seeded once, is shared by every evaluation of the
    # objective, so successive calls consume successive draws.
    noise_source = npr.RandomState(0)

    def unpack_params(params):
        """Split the flat parameter vector into ``(mean, log_std)``."""
        return params[:D], params[D:]

    def gaussian_entropy(log_std):
        """Entropy of a diagonal Gaussian with the given log std devs."""
        constant_term = 0.5 * D * (1.0 + np.log(2 * np.pi))
        return constant_term + np.sum(log_std)

    def variational_objective(params, t):
        """Provides a stochastic estimate of the negated variational lower bound."""
        mean, log_std = unpack_params(params)
        noise = noise_source.randn(num_samples, D)
        samples = noise * np.exp(log_std) + mean
        expected_logprob = np.mean(logprob(samples, t))
        return -(gaussian_entropy(log_std) + expected_logprob)

    return variational_objective, grad(variational_objective), unpack_params
if __name__ == "__main__":
    # Specify an inference problem by its unnormalized log-density.
    D = 2

    def log_density(x, t):
        """Unnormalized log-density of the target, evaluated row-wise.

        Column 0 of ``x`` is ``mu`` and column 1 is ``log_sigma``.
        ``log_sigma`` gets a zero-mean normal log-density with scale 1.35,
        and ``mu`` a zero-mean normal log-density with scale
        ``exp(log_sigma)``. ``t`` is accepted to match the
        ``logprob(samples, t)`` interface and is not used.
        """
        mu, log_sigma = x[:, 0], x[:, 1]
        sigma_density = norm.logpdf(log_sigma, 0, 1.35)
        mu_density = norm.logpdf(mu, 0, np.exp(log_sigma))
        return sigma_density + mu_density

    # Build variational objective.
    objective, gradient, unpack_params = black_box_variational_inference(
        log_density, D, num_samples=2000
    )

    # Set up plotting code.
    def plot_isocontours(ax, func, xlimits=(-2, 2), ylimits=(-4, 2), numticks=101):
        """Draw contour lines of ``func`` on ``ax`` over a rectangular grid.

        ``func`` is called once with an array of shape
        ``(numticks * numticks, 2)`` holding the grid points as ``(x, y)``
        rows; its output is reshaped back onto the grid. The limits default
        to tuples rather than lists so the defaults cannot be mutated
        between calls.
        """
        x = np.linspace(*xlimits, num=numticks)
        y = np.linspace(*ylimits, num=numticks)
        X, Y = np.meshgrid(x, y)
        grid_points = np.concatenate(
            [np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]
        ).T
        zs = func(grid_points)
        Z = zs.reshape(X.shape)
        # Draw on the axes that was passed in rather than on pyplot's
        # implicit "current" axes.
        ax.contour(X, Y, Z)
        ax.set_yticks([])
        ax.set_xticks([])

    # Set up figure.
    fig = plt.figure(figsize=(8, 8), facecolor="white")
    ax = fig.add_subplot(111, frameon=False)
    plt.ion()
    plt.show(block=False)

    def callback(params, t, g):
        """Report the current bound and redraw both densities.

        Passed to ``adam`` as its ``callback`` and invoked as
        ``callback(params, t, g)``; ``g`` is not used. Note that evaluating
        ``objective`` here draws a fresh batch of samples from the shared
        random state.
        """
        print(f"Iteration {t} lower bound {-objective(params, t)}")

        ax.cla()
        target_distribution = lambda x: np.exp(log_density(x, t))
        plot_isocontours(ax, target_distribution)

        mean, log_std = unpack_params(params)
        # exp(2 * log_std) is the variance, i.e. the diagonal of the
        # covariance matrix expected by mvn.pdf.
        variational_contour = lambda x: mvn.pdf(x, mean, np.diag(np.exp(2 * log_std)))
        plot_isocontours(ax, variational_contour)
        plt.draw()
        plt.pause(1.0 / 30.0)

    print("Optimizing variational parameters...")
    init_mean = -1 * np.ones(D)
    init_log_std = -5 * np.ones(D)
    init_var_params = np.concatenate([init_mean, init_log_std])
    variational_params = adam(
        gradient, init_var_params, step_size=0.1, num_iters=2000, callback=callback
    )