-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added base class for variational methods #1600
Changes from 25 commits
2d6fee8
4e5b9c2
0ebaacd
55c8ce6
9ab04da
9811220
fbd1d5b
208aa79
40d0146
fc0673b
ea82ebd
140a80c
168b113
c1211a6
9690562
34da7c8
07a248a
889b50e
0d486fb
125f6ad
1af91c0
69f07a1
9614bf9
32a2eb7
5e68b95
0f2c38f
63e57d7
4d4cb82
82c7996
2cd6bc5
4d810f2
16a226b
d8e9886
1bb349e
87e7e2d
9eb79a0
be1ca80
e8f6644
4add3bc
43a8638
6a88fde
a3bad35
7ed2cb5
163b1be
fad9410
e1a88e0
ae349e9
9e237ef
63c1285
02f5fa6
8cc9558
4e302e6
e5df6ee
23ed175
7a7cdc3
ac949d2
26adf3b
63000fb
5240260
7802a78
2407d78
fbf26d4
c394d5e
2162d4c
7609f72
d55d258
dca919c
cb2e219
d94e7e7
8d1f088
37843af
3dc6f1b
a16512e
7127c23
23b14ff
96cd5bb
633e4e9
0629adc
06099a2
75a4849
ff325d8
79ac934
57dbe47
f8bce58
244bf21
8d91fee
1a9fa3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,8 @@ | |
|
||
from .special import gammaln, multigammaln | ||
|
||
c = - 0.5 * np.log(2 * np.pi) | ||
|
||
|
||
def bound(logp, *conditions): | ||
""" | ||
|
@@ -77,3 +79,39 @@ def i1(x): | |
x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600, | ||
np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3) | ||
+ 14175 / (98304 * x**4))) | ||
|
||
|
||
def sd2rho(sd): | ||
"""sd -> rho | ||
theano converter | ||
mu + sd*e = mu + log(1+exp(rho))*e""" | ||
return tt.log(tt.exp(sd) - 1) | ||
|
||
|
||
def rho2sd(rho): | ||
"""rho -> sd | ||
theano converter | ||
mu + sd*e = mu + log(1+exp(rho))*e""" | ||
return tt.log1p(tt.exp(rho)) | ||
|
||
|
||
def kl_divergence_normal_pair(mu1, mu2, sd1, sd2): | ||
elemwise_kl = (tt.log(sd2/sd1) + | ||
(sd2**2 + (mu1-mu2)**2)/(2.*sd2**2) - | ||
0.5) | ||
return tt.sum(elemwise_kl) | ||
|
||
|
||
def kl_divergence_normal_pair3(mu1, mu2, rho1, rho2): | ||
sd1, sd2 = rho2sd(rho1), rho2sd(rho2) | ||
return kl_divergence_normal_pair(mu1, mu2, sd1, sd2) | ||
|
||
|
||
def log_normal(x, mean, std, eps=0.0): | ||
std += eps | ||
return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2) | ||
|
||
|
||
def log_normal3(x, mean, rho, eps=0.0): | ||
std = rho2sd(rho) | ||
return log_normal(x, mean, std, eps) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does "3" mean? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this notion was used in other library:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a more informative name for these functions that appending "3" to the name? Perhaps use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think docstring with cross references will be better. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import sys | ||
import theano | ||
import theano.tensor as tt | ||
from theano.tensor import (constant, flatten, zeros_like, ones_like, stack, concatenate, sum, prod, | ||
from theano.tensor import (constant, flatten as _flatten, zeros_like, ones_like, stack, concatenate, sum, prod, | ||
lt, gt, le, ge, eq, neq, switch, clip, where, and_, or_, abs_, exp, log, | ||
cos, sin, tan, cosh, sinh, tanh, sqr, sqrt, erf, erfinv, dot, maximum, | ||
minimum, sgn, ceil, floor) | ||
|
@@ -22,3 +22,10 @@ def invlogit(x, eps=sys.float_info.epsilon): | |
|
||
def logit(p): | ||
return tt.log(p / (1 - p)) | ||
|
||
|
||
def flatten(tensors, outdim=1): | ||
if not isinstance(tensors, list): | ||
return _flatten(tensors, outdim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm using a private method that's the same name as the other one. I'm not so sure about this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually you've done it quite a lot - so this is just a style thing - I don't see a huge problem with it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about it too. It should be better named as |
||
else: | ||
return tt.concatenate([var.ravel() for var in tensors]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this function used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is not used, I can delete it