Skip to content

Commit

Permalink
Comment sizing of cnn decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Aug 3, 2021
1 parent 9e1f4e5 commit 415169a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mnist_autoencoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -218,23 +218,23 @@
" self.decoder_lin = nn.Sequential(\n",
" nn.Linear(latent_dims, 128),\n",
" nn.ReLU(),\n",
" nn.Linear(128, 3 * 3 * 32),\n",
" nn.Linear(128, 3 * 3 * 32), # [ batch_size, 288 ]\n",
" nn.ReLU()\n",
" )\n",
"\n",
" self.decoder_cnn = nn.Sequential(\n",
" nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),\n",
" nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0), # [ batch_size, 16, 7, 7]\n",
" nn.BatchNorm2d(16),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),\n",
" nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1), # [ batch_size, 8, 14, 14]\n",
" nn.BatchNorm2d(8),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)\n",
" nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1) # [ batch_size, 1, 28, 28]\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.decoder_lin(x)\n",
" x = self.decoder_cnn(x.unflatten(dim=1, sizes=(32, 3, 3)))\n",
" x = self.decoder_cnn(x.unflatten(dim=1, sizes=(32, 3, 3))) # [ batch_size, 288 ] -> [ batch_size, 32, 3, 3] -> [ batch_size, 1, 28, 28]\n",
" return torch.sigmoid(x)"
],
"metadata": {
Expand Down

0 comments on commit 415169a

Please sign in to comment.