# Custom PyTorch distribution classes to be registered in policy_util.py
# Mainly used by policy_util action distribution
import torch
import torch.nn.functional as F
from torch import distributions

class Argmax(distributions.Categorical):
    '''
    Special distribution class for argmax sampling, where probability is always 1 for the argmax.
    NOTE although argmax is not a sampling distribution, this implementation is for API consistency.
    '''

    def __init__(self, probs=None, logits=None, validate_args=None):
        if probs is not None:
            # put all probability mass on the argmax position(s)
            new_probs = torch.zeros_like(probs, dtype=torch.float)
            new_probs[probs == probs.max(dim=-1, keepdim=True)[0]] = 1.0
            probs = new_probs
        elif logits is not None:
            # push non-argmax logits to -1e8 so softmax assigns them ~0 probability
            new_logits = torch.full_like(logits, -1e8, dtype=torch.float)
            new_logits[logits == logits.max(dim=-1, keepdim=True)[0]] = 1.0
            logits = new_logits
        super().__init__(probs=probs, logits=logits, validate_args=validate_args)
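
# Usage sketch (illustrative, not from the original module): Argmax behaves like a
# deterministic Categorical, so sampling always returns the greedy action.
#   dist = Argmax(logits=torch.tensor([0.1, 2.0, 0.3]))
#   dist.sample()                   # tensor(1), always the argmax index
#   dist.log_prob(torch.tensor(1))  # ~0.0, since P(argmax) ~= 1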


class GumbelSoftmax(distributions.RelaxedOneHotCategorical):
    '''
    A differentiable Categorical distribution using the reparameterization trick with Gumbel-Softmax.
    Explanation: http://amid.fish/assets/gumbel.html
    NOTE: use this in place of PyTorch's RelaxedOneHotCategorical distribution, since its log_prob is not working right (it returns positive values).
    Papers:
    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al., 2017)
    [2] Categorical Reparameterization with Gumbel-Softmax (Jang et al., 2017)
    '''

    def sample(self, sample_shape=torch.Size()):
        '''Gumbel-Softmax sampling. Note rsample is inherited from RelaxedOneHotCategorical'''
        u = torch.empty(self.logits.size(), device=self.logits.device, dtype=self.logits.dtype).uniform_(0, 1)
        # Gumbel-max trick: add Gumbel(0, 1) noise to the logits, then take the argmax
        noisy_logits = self.logits - torch.log(-torch.log(u))
        return torch.argmax(noisy_logits, dim=-1)

    def rsample(self, sample_shape=torch.Size()):
        '''
        Gumbel-Softmax resampling using the Straight-Through trick.
        Credit to Ian Temple for bringing this to our attention. To see standalone code of how this works, refer to https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
        '''
        rout = super().rsample(sample_shape)  # differentiable relaxed sample
        out = F.one_hot(torch.argmax(rout, dim=-1), self.logits.shape[-1]).float()
        # straight-through: forward pass uses the hard one-hot, gradients flow through rout
        return (out - rout).detach() + rout

    def log_prob(self, value):
        '''value is one-hot or relaxed'''
        if value.shape != self.logits.shape:
            value = F.one_hot(value.long(), self.logits.shape[-1]).float()
            assert value.shape == self.logits.shape
        # weighted log-softmax sum; equals log p(class) when value is one-hot
        return torch.sum(value * F.log_softmax(self.logits, -1), -1)
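
# Usage sketch (illustrative, not from the original module; the temperature and logits
# below are arbitrary placeholders): rsample() returns a hard one-hot action in the
# forward pass while gradients flow through the relaxed sample.
#   logits = torch.tensor([0.5, 1.5, 0.2], requires_grad=True)
#   dist = GumbelSoftmax(temperature=torch.tensor(1.0), logits=logits)
#   action = dist.rsample()           # one-hot, e.g. tensor([0., 1., 0.]), differentiable
#   dist.log_prob(action).backward()  # gradients reach logits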


class MultiCategorical(distributions.Categorical):
    '''MultiCategorical as a collection of Categoricals'''

    def __init__(self, probs=None, logits=None, validate_args=None):
        self.categoricals = []
        if probs is None:
            probs = [None] * len(logits)
        elif logits is None:
            logits = [None] * len(probs)
        else:
            raise ValueError('Either probs or logits must be None')

        # build one sub-Categorical per element of probs/logits
        for sub_probs, sub_logits in zip(probs, logits):
            categorical = distributions.Categorical(probs=sub_probs, logits=sub_logits, validate_args=validate_args)
            self.categoricals.append(categorical)

    @property
    def logits(self):
        return [cat.logits for cat in self.categoricals]

    @property
    def probs(self):
        return [cat.probs for cat in self.categoricals]

    @property
    def param_shape(self):
        return [cat.param_shape for cat in self.categoricals]

    @property
    def mean(self):
        return torch.stack([cat.mean for cat in self.categoricals])

    @property
    def variance(self):
        return torch.stack([cat.variance for cat in self.categoricals])

    def sample(self, sample_shape=torch.Size()):
        return torch.stack([cat.sample(sample_shape=sample_shape) for cat in self.categoricals])

    def log_prob(self, value):
        # value has shape (batch, num_categoricals); transpose to index per sub-distribution
        value_t = value.transpose(0, 1)
        return torch.stack([cat.log_prob(value_t[idx]) for idx, cat in enumerate(self.categoricals)])

    def entropy(self):
        return torch.stack([cat.entropy() for cat in self.categoricals])

    def enumerate_support(self):
        return [cat.enumerate_support() for cat in self.categoricals]
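
# Usage sketch (illustrative, not from the original module): one sub-Categorical per
# action dimension, e.g. a 2D discrete action with 3 and 2 choices over a batch of 4.
#   logits = [torch.randn(4, 3), torch.randn(4, 2)]  # one logits tensor per dimension
#   dist = MultiCategorical(logits=logits)
#   actions = dist.sample()                 # shape (2, 4): one row per sub-distribution
#   dist.log_prob(actions.transpose(0, 1))  # expects (batch, num_categoricals) input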