-
Notifications
You must be signed in to change notification settings - Fork 35
/
vqgan.py
56 lines (45 loc) · 2.25 KB
/
vqgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
from encoder import Encoder
from decoder import Decoder
from codebook import Codebook
class VQGAN(nn.Module):
    """Vector-quantized autoencoder.

    Pipeline: encoder -> 1x1 ``quant_conv`` -> codebook -> 1x1
    ``post_quant_conv`` -> decoder.

    ``args`` must provide at least ``device`` and ``latent_dim``. It is also
    passed whole to ``Encoder``, ``Decoder`` and ``Codebook``; which other
    fields those read is not visible here.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = Encoder(args).to(device=args.device)
        self.decoder = Decoder(args).to(device=args.device)
        self.codebook = Codebook(args).to(device=args.device)
        # 1x1 convolutions that keep the channel count at latent_dim, applied
        # right before (quant_conv) and right after (post_quant_conv) the
        # codebook lookup.
        self.quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim, 1).to(device=args.device)
        self.post_quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim, 1).to(device=args.device)

    def forward(self, imgs):
        """Run the full encode -> quantize -> decode pipeline on ``imgs``.

        Returns:
            ``(decoded_images, codebook_indices, q_loss)``: the decoder output
            and the indices/loss produced by the codebook.
        """
        # Reuse encode/decode so the training path and the standalone
        # encode/decode paths cannot drift apart.
        codebook_mapping, codebook_indices, q_loss = self.encode(imgs)
        decoded_images = self.decode(codebook_mapping)
        return decoded_images, codebook_indices, q_loss

    def encode(self, x):
        """Encode ``x`` and quantize it with the codebook.

        Returns:
            ``(codebook_mapping, codebook_indices, q_loss)`` exactly as
            returned by ``self.codebook``.
        """
        encoded_images = self.encoder(x)
        quantized_encoded_images = self.quant_conv(encoded_images)
        codebook_mapping, codebook_indices, q_loss = self.codebook(quantized_encoded_images)
        return codebook_mapping, codebook_indices, q_loss

    def decode(self, z):
        """Decode a quantized latent ``z`` (e.g. ``codebook_mapping``) back to images."""
        quantized_codebook_mapping = self.post_quant_conv(z)
        decoded_images = self.decoder(quantized_codebook_mapping)
        return decoded_images

    def calculate_lambda(self, nll_loss, g_loss, disc_weight: float = 0.8):
        """Return the adaptive weight for the generator (GAN) loss.

        The weight is computed as follows:

        1. Take the gradient norms of ``nll_loss`` and ``g_loss`` with respect
           to the weight of the decoder's last layer (``self.decoder.model[-1]``).
        2. Divide the first norm by the second.
        3. Clamp the ratio to [0, 1e4] and detach it from the graph.
        4. Scale it by ``disc_weight``.

        Args:
            nll_loss: Reconstruction loss tensor. It must depend on the
                decoder's last-layer weight.
            g_loss: Generator loss tensor. It must depend on the same weight.
            disc_weight: Constant multiplier. The default of 0.8 reproduces the
                previously hard-coded value.

        Returns:
            A detached 0-dim tensor.
        """
        last_layer_weight = self.decoder.model[-1].weight
        # retain_graph=True keeps the autograd graph alive so that the second
        # grad call (and any later backward on these losses) still works.
        nll_grads = torch.autograd.grad(nll_loss, last_layer_weight, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer_weight, retain_graph=True)[0]
        # The 1e-4 term avoids division by zero when g_grads vanishes.
        lam = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        lam = torch.clamp(lam, 0, 1e4).detach()
        return disc_weight * lam

    @staticmethod
    def adopt_weight(disc_factor, i, threshold, value=0.):
        """Return ``value`` while ``i < threshold``, otherwise ``disc_factor``.

        Typically this delays the discriminator loss until step ``threshold``.
        """
        return value if i < threshold else disc_factor

    def load_checkpoint(self, path):
        """Load a state dict saved with ``torch.save`` at ``path`` into this model.

        Tensors are mapped onto the device of this model's parameters. A
        checkpoint written on another device (e.g. a GPU) can therefore be
        loaded on a host where that device is unavailable. If the model has no
        parameters, torch's default placement is used.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        first_param = next(self.parameters(), None)
        map_location = first_param.device if first_param is not None else None
        self.load_state_dict(torch.load(path, map_location=map_location))
        print("Loaded Checkpoint for VQGAN....")