-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simple stick breaking (Formerly #3620) #3638
Changes from all commits
83ab40a
731e6b5
61ba98e
6b78e24
8157760
8cba7e1
3a2dca2
6c556f4
c0cfbd5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
import warnings | ||
|
||
import theano | ||
import theano.tensor as tt | ||
|
||
|
@@ -14,6 +16,7 @@ | |
__all__ = [ | ||
"transform", | ||
"stick_breaking", | ||
"stick_breaking2", | ||
"logodds", | ||
"interval", | ||
"log_exp_m1", | ||
|
@@ -510,6 +513,62 @@ def t_stick_breaking(eps): | |
return StickBreaking(eps) | ||
|
||
|
||
class StickBreaking2(Transform): | ||
""" | ||
Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of real values. | ||
""" | ||
|
||
name = "stickbreaking" | ||
|
||
def __init__(self, eps=None): | ||
if eps is not None: | ||
warnings.warn("The argument `eps` is depricated and will not be used.", | ||
DeprecationWarning) | ||
Comment on lines
+523
to
+526
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did this because I replaced the original stick-breaking that took the argument and didn't want to break it. Maybe we don't need it anymore. |
||
|
||
def forward(self, x_): | ||
x = x_.T | ||
n = x.shape[0] | ||
lx = tt.log(x) | ||
shift = tt.sum(lx, 0, keepdims=True) / n | ||
y = lx[:-1] - shift | ||
return floatX(y.T) | ||
|
||
def forward_val(self, x_): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We probably have to keep the argument |
||
x = x_.T | ||
n = x.shape[0] | ||
lx = np.log(x) | ||
shift = np.sum(lx, 0, keepdims=True) / n | ||
y = lx[:-1] - shift | ||
return floatX(y.T) | ||
|
||
def backward(self, y_): | ||
y = y_.T | ||
y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)]) | ||
# "softmax" with vector support and no deprication warning: | ||
e_y = tt.exp(y - tt.max(y, 0, keepdims=True)) | ||
x = e_y / tt.sum(e_y, 0, keepdims=True) | ||
return floatX(x.T) | ||
|
||
def backward_val(self, y_): | ||
y = y_.T | ||
y = np.concatenate([y, -np.sum(y, 0, keepdims=True)]) | ||
x = np.exp(y)/np.sum(np.exp(y), 0, keepdims=True) | ||
return floatX(x.T) | ||
|
||
def jacobian_det(self, y_): | ||
y = y_.T | ||
Km1 = y.shape[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to write |
||
sy = tt.sum(y, 0, keepdims=True) | ||
r = tt.concatenate([y+sy, tt.zeros(sy.shape)]) | ||
# stable according to: http://deeplearning.net/software/theano_versions/0.9.X/NEWS.html | ||
sr = tt.log(tt.sum(tt.exp(r), 0, keepdims=True)) | ||
d = tt.log(Km1) + (Km1*sy) - (Km1*sr) | ||
return tt.sum(d, 0).T | ||
|
||
|
||
stick_breaking2 = StickBreaking2() | ||
|
||
|
||
class Circular(ElemwiseTransform): | ||
"""Transforms a linear space into a circular one. | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deprecated