# encoder.py
import torch.nn as nn
import torch.distributions as dists
from torch import log as tlog, sum as tsum

class Encoder(nn.Module):
    # in_dim  = dimension of the input
    # out_dim = dimension of the latent space
    # A useful preprocessing step is to standardize inputs to roughly N(0, 1),
    # which makes reconstruction easier.
    def __init__(self, in_dim, out_dim):
        super().__init__()
        nn_hidden = 45  # width of the hidden layers
        self.out = out_dim
        self.latentize = nn.Sequential(
            nn.Linear(in_dim, nn_hidden),
            nn.LeakyReLU(),
            nn.Linear(nn_hidden, nn_hidden),
            nn.LeakyReLU()
        )
        self.latent_mean = nn.Sequential(
            nn.Linear(nn_hidden, out_dim),
        )
        # A diagonal covariance is assumed; the sigmoid keeps each variance
        # positive (and bounded in (0, 1)).
        self.latent_var = nn.Sequential(
            nn.Linear(nn_hidden, out_dim),
            nn.Sigmoid()
        )
        # Standard normal for the reparameterization noise.
        self.normal = dists.Normal(0, 1)
        # Accumulated KL divergence; reset it externally (e.g. once per batch).
        self.kl = 0

    def forward(self, x):
        n = x.shape[0]
        latents = self.latentize(x)
        mu = self.latent_mean(latents)
        var = self.latent_var(latents)  # per-dimension variance of q(z|x)
        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I).
        eps = self.normal.sample(sample_shape=(n, self.out))
        z = mu + var.sqrt() * eps
        # Closed-form KL(N(mu, var) || N(0, I)), summed over batch and dimensions.
        self.kl += -0.5 * tsum(1 + tlog(var) - mu**2 - var)
        return z
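

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration under stated assumptions, not part of
# the original module): the small decoder below is a hypothetical stand-in,
# the batch is random, and the inputs are assumed to already be standardized
# to roughly N(0, 1). It shows how `enc.kl` might be reset each batch and
# folded into an ELBO-style loss alongside a reconstruction term.
if __name__ == "__main__":
    import torch

    in_dim, out_dim = 20, 3
    enc = Encoder(in_dim, out_dim)
    # Hypothetical decoder mirroring the encoder back to input space.
    dec = nn.Sequential(
        nn.Linear(out_dim, 45),
        nn.LeakyReLU(),
        nn.Linear(45, in_dim),
    )
    opt = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=1e-3)

    x = torch.randn(128, in_dim)   # stand-in batch, assumed ~N(0, 1) per feature
    enc.kl = 0                     # reset the accumulated KL term for this batch
    z = enc(x)                     # encode and reparameterize
    x_hat = dec(z)                 # reconstruct
    loss = ((x - x_hat) ** 2).sum() + enc.kl  # reconstruction + KL
    opt.zero_grad()
    loss.backward()
    opt.step()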