
PyTorch implementation of VQ-VAE (Oord et al., 2017) and PixelCNN (Oord et al., 2016), trained on Fashion MNIST and CIFAR-10


KimRass/VQ-VAE-PixelCNN


1. Pre-trained Models

1) On Fashion MNIST

  • vqvae-fashion_mnist.pth
  • Trained VQ-VAE for 47 epochs. (Validation loss: 0.145)
    dataset="fashion_mnist"
    batch_size=128
    lr=0.0002
    n_embeds=128
    hidden_dim=256
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2
  • Then trained PixelCNN for 14 epochs. (Validation loss: 1.279)
    dataset="fashion_mnist"
    batch_size=128
    lr=0.0002
    n_embeds=128
    hidden_dim=256
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2

2) On CIFAR-10

  • vqvae-cifar10.pth
  • Trained VQ-VAE for 40 epochs. (Validation loss: 0.139)
    dataset="cifar10"
    batch_size=128
    lr=0.0003
    n_embeds=128
    hidden_dim=64
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2
  • Then trained PixelCNN for 96 epochs. (Validation loss: 2.226)
    dataset="cifar10"
    batch_size=128
    lr=0.0003
    n_embeds=128
    hidden_dim=64
    n_pixelcnn_res_blocks=2
    n_pixelcnn_conv_blocks=2

2. Samples

Fashion MNIST
CIFAR-10

3. Implementation Details

1) detach()

  • When computing the loss during VQ-VAE training, I found that adding z_q = z_e + (z_q - z_e).detach() made training faster, but I did not work out exactly what it does.
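
The expression `z_q = z_e + (z_q - z_e).detach()` looks like the straight-through estimator that the VQ-VAE paper uses to pass gradients through the non-differentiable codebook lookup. The sketch below is not taken from this repository: the tensor sizes, the `codebook` name, and the squared-sum stand-in loss are assumptions chosen only to show the effect. In the forward pass the result has the same values as the quantized vectors, while in the backward pass the gradient flows to `z_e` as if quantization were the identity.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: 8 code vectors of dimension 4, 5 encoder outputs.
codebook = torch.randn(8, 4, requires_grad=True)
z_e = torch.randn(5, 4, requires_grad=True)  # stand-in for the encoder output

# Nearest-neighbour lookup. argmin returns integer indices, so no gradient
# can flow back to z_e through this step.
dists = torch.cdist(z_e, codebook)  # (5, 8) pairwise Euclidean distances
indices = dists.argmin(dim=1)
z_q = codebook[indices]

# Straight-through trick: the forward value equals z_q, but the only
# differentiable path leads to z_e because the difference is detached.
z_q_st = z_e + (z_q - z_e).detach()

# Stand-in for a reconstruction loss computed from the decoder input.
loss = (z_q_st ** 2).sum()
loss.backward()

print(torch.allclose(z_q_st, z_q))  # forward values match the quantized vectors
print(z_e.grad is not None)         # the encoder side receives a gradient
print(codebook.grad is None)        # the codebook gets none from this loss alone
```

If `z_q` were fed to the loss directly, `z_e` would receive no gradient at all from this term, which may explain the slower training observed without the trick. The codebook receives no gradient from this loss in the sketch, which is consistent with the paper's use of separate codebook and commitment terms.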
