diff --git a/mnist_autoencoder.ipynb b/mnist_autoencoder.ipynb index 524a9b4..269bb70 100644 --- a/mnist_autoencoder.ipynb +++ b/mnist_autoencoder.ipynb @@ -218,23 +218,23 @@ " self.decoder_lin = nn.Sequential(\n", " nn.Linear(latent_dims, 128),\n", " nn.ReLU(),\n", - " nn.Linear(128, 3 * 3 * 32),\n", + " nn.Linear(128, 3 * 3 * 32), # [batch_size, 288]\n", " nn.ReLU()\n", " )\n", "\n", " self.decoder_cnn = nn.Sequential(\n", - " nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),\n", + " nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0), # [batch_size, 16, 7, 7]\n", " nn.BatchNorm2d(16),\n", " nn.ReLU(),\n", - " nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),\n", + " nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1), # [batch_size, 8, 14, 14]\n", " nn.BatchNorm2d(8),\n", " nn.ReLU(),\n", - " nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)\n", + " nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1) # [batch_size, 1, 28, 28]\n", " )\n", "\n", " def forward(self, x):\n", " x = self.decoder_lin(x)\n", - " x = self.decoder_cnn(x.unflatten(dim=1, sizes=(32, 3, 3)))\n", + " x = self.decoder_cnn(x.unflatten(dim=1, sizes=(32, 3, 3))) # [batch_size, 288] -> [batch_size, 32, 3, 3] -> [batch_size, 1, 28, 28]\n", " return torch.sigmoid(x)" ], "metadata": {