Skip to content

Commit

Permalink
small cosmetic updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Trenton Bricken committed Jul 3, 2020
1 parent 227857f commit 567e163
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 89 deletions.
110 changes: 24 additions & 86 deletions Figure2Replication.ipynb

Large diffs are not rendered by default.

Binary file modified discrete_flows/__pycache__/disc_utils.cpython-37.pyc
Binary file not shown.
4 changes: 1 addition & 3 deletions discrete_flows/disc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def reverse(self, inputs, **kwargs):
loc, scale = torch.split(layer_outs, self.vocab_size, dim=-1)
loc = disc_utils.one_hot_argmax(loc, self.temperature).type(inputs.dtype)
scale = disc_utils.one_hot_argmax(scale, self.temperature).type(inputs.dtype)
print('the scale', scale.argmax(-1))
#print('the scale', scale.argmax(-1))
inverse_scale = disc_utils.multiplicative_inverse(scale, self.vocab_size)
shifted_inputs = disc_utils.one_hot_minus(z1, loc)
x1 = disc_utils.one_hot_multiply(shifted_inputs, inverse_scale)
Expand All @@ -311,7 +311,6 @@ def reverse(self, inputs, **kwargs):

def forward(self, inputs, **kwargs):
"""Reverse pass for the inverse bipartite transformation. From data to latent. """
#print(inputs.shape)
assert len(inputs.shape) ==2, 'need to flatten the inputs first!'
x0, x1 = inputs[:,:self.dim//2], inputs[:,self.dim//2:]
if self.parity:
Expand All @@ -338,7 +337,6 @@ def forward(self, inputs, **kwargs):
if self.parity:
z0, z1 = z1, z0
z = torch.cat([z0, z1], dim=1)
#print('returned z', z.shape)
return z

def log_det_jacobian(self, inputs):
Expand Down

0 comments on commit 567e163

Please sign in to comment.