distributions.py
import numpy as np
import torch
import torch.distributions as D

# NOTE: log_prob and entropy sum along the second dimension (dim=1) since the
# distribution has diagonal covariance.
def diagonal_gaussian_log_prob(x, mu, log_sigma, sigma_eps=0):
    # Per-dimension log-density of N(mu, sigma^2) at x, with sigma = exp(log_sigma);
    # sigma_eps guards against division by zero when sigma underflows.
    unreduced_log_prob = -np.log(2*np.pi)/2 - log_sigma - ((x - mu)/(torch.exp(log_sigma) + sigma_eps))**2/2
    return unreduced_log_prob.sum(dim=1, keepdim=True)
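
# Sanity-check sketch (not from the original file; shapes are illustrative):
# the result should agree with torch.distributions for a diagonal Gaussian.
#   x, mu, log_sigma = torch.randn(4, 3), torch.zeros(4, 3), torch.zeros(4, 3)
#   ref = D.Independent(D.Normal(mu, log_sigma.exp()), 1).log_prob(x)
#   assert torch.allclose(diagonal_gaussian_log_prob(x, mu, log_sigma).squeeze(1), ref)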

def diagonal_gaussian_entropy(log_sigma):
    # Per-dimension entropy of a Gaussian: log(2*pi*e)/2 + log_sigma.
    unreduced_entropy = np.log(2*np.pi)/2 + 1./2 + log_sigma
    return unreduced_entropy.sum(dim=1, keepdim=True)
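
# Sanity-check sketch (illustrative shapes, not part of the original file):
#   log_sigma = torch.zeros(4, 3)
#   ref = D.Independent(D.Normal(torch.zeros(4, 3), log_sigma.exp()), 1).entropy()
#   assert torch.allclose(diagonal_gaussian_entropy(log_sigma).squeeze(1), ref)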

def diagonal_gaussian_sample(mu, sigma, shape=None):
    # Reparameterized sample mu + sigma * eps with eps ~ N(0, I), so gradients
    # flow through mu and sigma; pass `shape` to override the sample shape.
    if shape is None:
        eps = torch.randn_like(mu)
    else:
        eps = torch.randn(shape).to(mu.device)
    return mu + sigma*eps
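
# Usage sketch (VAE-style reparameterized sampling; shapes are assumptions):
#   mu = torch.zeros(4, 3, requires_grad=True)
#   z = diagonal_gaussian_sample(mu, torch.ones(4, 3))
#   z.sum().backward()  # gradients flow back to mu through the sample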

def uniform_sample(template=None, shape=None, device=None, range=[0,1]):
    # Draw U(range[0], range[1]) samples shaped like `template`, or with an
    # explicit `shape` on `device`. Note: `range` shadows the Python builtin.
    if template is not None:
        x = torch.rand_like(template)
    elif shape is not None and device is not None:
        x = torch.rand(shape).to(device)
    else:
        raise ValueError("Must provide template or shape and device to uniform_sample")
    x = x * range[1] + (1-x) * range[0]
    return x
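
# Usage sketch (arguments are illustrative):
#   noise = uniform_sample(shape=(4, 3), device="cpu", range=[-1, 1])
#   jitter = uniform_sample(template=noise)  # same shape/device, default [0, 1]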

def get_gaussian_mixture(weights, mus, sigmas):
    # Mixture of diagonal Gaussians: `weights` are unnormalized logits over
    # components; `mus`/`sigmas` carry a trailing event dimension.
    mix = D.Categorical(logits=weights)
    comp = D.Independent(D.Normal(mus, sigmas), 1)
    return D.MixtureSameFamily(mix, comp)
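
# Usage sketch (component count and dimensionality are assumptions):
#   weights = torch.zeros(5)                   # 5 equally weighted components
#   gmm = get_gaussian_mixture(weights, torch.randn(5, 2), torch.ones(5, 2))
#   samples = gmm.sample((100,))               # -> shape (100, 2)
#   log_probs = gmm.log_prob(samples)          # -> shape (100,)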