- vqvae-fashion_mnist.pth
  - Trained VQ-VAE for 47 epochs. (Validation loss: 0.145)
    `dataset="fashion_mnist" batch_size=128 lr=0.0002 n_embeds=128 hidden_dim=256 n_pixelcnn_res_blocks=2 n_pixelcnn_conv_blocks=2`
  - Then trained PixelCNN for 14 epochs. (Validation loss: 1.279)
    `dataset="fashion_mnist" batch_size=128 lr=0.0002 n_embeds=128 hidden_dim=256 n_pixelcnn_res_blocks=2 n_pixelcnn_conv_blocks=2`
- vqvae-cifar10.pth
  - Trained VQ-VAE for 40 epochs. (Validation loss: 0.139)
    `dataset="cifar10" batch_size=128 lr=0.0003 n_embeds=128 hidden_dim=64 n_pixelcnn_res_blocks=2 n_pixelcnn_conv_blocks=2`
  - Then trained PixelCNN for 96 epochs. (Validation loss: 2.226)
    `dataset="cifar10" batch_size=128 lr=0.0003 n_embeds=128 hidden_dim=64 n_pixelcnn_res_blocks=2 n_pixelcnn_conv_blocks=2`
Fashion MNIST |
---|
CIFAR-10 |
---|
- When computing the loss during VQ-VAE training, I observed that adding `z_q = z_e + (z_q - z_e).detach()` made training faster, but I did not work out exactly what this line does.
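
The line `z_q = z_e + (z_q - z_e).detach()` is commonly known as the straight-through estimator. In the forward pass the result is numerically equal to the quantized vector `z_q`, but because the `(z_q - z_e)` term is detached, the backward pass treats the result as if it were `z_e` plus a constant, so the gradient reaching the quantized latent is copied unchanged to the encoder output. Without it, the non-differentiable nearest-code lookup blocks that gradient from reaching the encoder, which is one plausible reason training speeds up. The sketch below is a toy check of this behavior, not the repository's code: the tensor shapes, the random `codebook`, and the squared-sum loss are assumptions chosen only for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins (not the repository's tensors):
# z_e is the encoder output, codebook is the embedding table.
z_e = torch.randn(4, 8, requires_grad=True)        # 4 latent vectors, dim 8
codebook = torch.randn(16, 8, requires_grad=True)  # 16 codes, dim 8

# Nearest-code lookup. argmin is not differentiable with respect to z_e.
distances = torch.cdist(z_e, codebook)             # shape (4, 16)
indices = distances.argmin(dim=1)
z_q = codebook[indices]

# Straight-through trick: forward value is z_q, gradient flows to z_e.
z_q_st = z_e + (z_q - z_e).detach()

# Forward pass: the value matches the quantized vectors.
same_forward = torch.allclose(z_q_st, z_q, atol=1e-6)

# Backward pass: a toy loss on the quantized latent (assumed for illustration).
loss = (z_q_st ** 2).sum()
loss.backward()

# d(loss)/d(z_q_st) = 2 * z_q_st, and it is copied as-is onto z_e.
grad_copied = torch.allclose(z_e.grad, 2 * z_q.detach(), atol=1e-5)

# The detach cuts the path to the codebook, so it gets no gradient here.
codebook_untouched = codebook.grad is None

print(same_forward, grad_copied, codebook_untouched)
```

Because the codebook receives no gradient through this path, VQ-VAE implementations typically train it with a separate codebook loss term (or with exponential moving average updates) computed from the un-detached `z_q`.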