Skip to content

Commit

Permalink
Implement Generator
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Jul 16, 2022
1 parent 928c530 commit 69b4e74
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions mnist_gan_pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,67 @@
" for layer in disc_dims:\n",
" layers += [\n",
" nn.Linear(in_channels, self.hidden_dim * layer),\n",
" nn.LeakyReLU(inplace=True, negative_slope=negative_slope),\n",
" nn.Dropout(p=dropout_probability, inplace=True),\n",
" nn.LeakyReLU(inplace=True, negative_slope=negative_slope)\n",
" ]\n",
" in_channels = self.hidden_dim * layer\n",
" return nn.Sequential(*layers)"
" return nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" x = self.features(x)\n",
" return self.classifier(x)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Generator"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"gen_dims = [4, 2, 1]\n",
"class Generator(nn.Module):\n",
" def __init__(self, input_size, hidden_dim, output_size):\n",
" super(Generator, self).__init__()\n",
" self.hidden_dim = hidden_dim\n",
" self.output_size = output_size\n",
" self.input_size = input_size\n",
" self.features = self._make_layers()\n",
"\n",
" def _make_layers(self, negative_slope=0.2, dropout_probability=0.3):\n",
" layers = []\n",
" in_channels = self.input_size\n",
" for layer in disc_dims:\n",
" layers += [\n",
" nn.Linear(in_channels, self.hidden_dim * layer),\n",
" nn.LeakyReLU(inplace=True, negative_slope=negative_slope),\n",
" nn.Dropout(p=dropout_probability, inplace=True),\n",
" ]\n",
" in_channels = self.hidden_dim * layer\n",
" layers += [\n",
" nn.Linear(self.hidden_dim, self.output_size),\n",
" nn.Tanh()\n",
" ]\n",
" return nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" return self.features(x)"
],
"metadata": {
"collapsed": false,
Expand Down

0 comments on commit 69b4e74

Please sign in to comment.