Skip to content

Commit

Permalink
Implement DCGAN as pytorch lightning module
Browse files Browse the repository at this point in the history
  • Loading branch information
alimoezzi committed Aug 13, 2022
1 parent 2c28da1 commit dd4f76e
Showing 1 changed file with 148 additions and 16 deletions.
164 changes: 148 additions & 16 deletions pokemon_dcgan_pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
},
"outputs": [],
"source": [
"%%script false --no-raise-error\n",
"%%script false --no- raise -error\n",
"!pip3 install jupyter==1.0.0\n",
"!pip3 install numpy==1.23.1\n",
"!pip3 install pandas==1.4.3\n",
Expand Down Expand Up @@ -87,9 +87,8 @@
"import random\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.parallel\n",
"import torch.backends.cudnn as cudnn\n",
"import torch.optim as optim\n",
"from torch.functional import F\n",
"import pytorch_lightning as pl\n",
"import torch.utils.data\n",
"import torchvision.datasets as dset\n",
"import torchvision.transforms as transforms\n",
Expand Down Expand Up @@ -155,9 +154,9 @@
" # Size of feature maps in generator\n",
" generator_hidden_size=64,\n",
" # Generator network dimensions\n",
" gen_dims = [8, 4, 2, 1],\n",
" gen_dims=[8, 4, 2, 1],\n",
" # Discriminator network dimensions\n",
" dis_dims = [1, 2, 4, 8],\n",
" dis_dims=[1, 2, 4, 8],\n",
" # Size of feature maps in discriminator\n",
" discriminator_hidden_size=64,\n",
" # Number of training epochs\n",
Expand All @@ -166,11 +165,12 @@
" lr=0.0002,\n",
" # Beta1 hyperparam for Adam optimizers\n",
" beta1=0.5,\n",
" beta2=0.999,\n",
" # Number of GPUs available. Use 0 for CPU mode.\n",
" ngpu=1,\n",
")\n",
"\n",
"device = torch.device(\"cuda:0\" if (torch.cuda.is_available() and config['ngpu'] > 0 ) else \"cpu\")\n"
"device = torch.device(\"cuda:0\" if (torch.cuda.is_available() and config['ngpu'] > 0) else \"cpu\")\n"
]
},
{
Expand All @@ -194,7 +194,7 @@
},
"outputs": [],
"source": [
"%%script false --no-raise-error\n",
"%%script false --no- raise -error\n",
"!git clone https://github.com/PokeAPI/sprites.git datasets/pokemon/\n",
"!mkdir -p {config['dataset_path']}/0\n",
"!mv {config['dataset_path']}/*.png {config['dataset_path']}/0/\n"
Expand All @@ -218,9 +218,11 @@
" color = img[:, :, :3]\n",
" new_img = cv2.bitwise_not(cv2.bitwise_not(color, mask=mask))\n",
" return new_img\n",
"\n",
"\n",
"dataset = dset.ImageFolder(\n",
" root=config['dataset_path'],\n",
" loader = transparent_white_bg_loader,\n",
" loader=transparent_white_bg_loader,\n",
" transform=transforms.Compose([\n",
" transforms.ToPILImage(),\n",
" transforms.Resize(config['image_size']),\n",
Expand All @@ -233,7 +235,7 @@
")\n",
"\n",
"dataloader = torch.utils.data.DataLoader(\n",
" dataset, \n",
" dataset,\n",
" batch_size=config['batch_size'],\n",
" shuffle=True,\n",
" num_workers=config['num_workers']\n",
Expand Down Expand Up @@ -482,9 +484,9 @@
"# custom weights initialization called on netG and netD\n",
"def weights_init(m):\n",
" classname = m.__class__.__name__\n",
" if classname.find('Conv') != -1:\n",
" if classname.find('Conv')!=-1:\n",
" nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
" elif classname.find('BatchNorm') != -1:\n",
" elif classname.find('BatchNorm')!=-1:\n",
" nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
" nn.init.constant_(m.bias.data, 0)"
]
Expand Down Expand Up @@ -530,8 +532,8 @@
" self.hidden_dim*layer,\n",
" self.kernel_size,\n",
" bias=False,\n",
" stride=1 if i == 0 else 2,\n",
" padding=0 if i == 0 else 1,\n",
" stride=1 if i==0 else 2,\n",
" padding=0 if i==0 else 1,\n",
" ),\n",
" nn.BatchNorm2d(self.hidden_dim*layer),\n",
" nn.LeakyReLU(negative_slope=negative_slope),\n",
Expand All @@ -549,7 +551,7 @@
" nn.Tanh()\n",
" ]\n",
" return nn.Sequential(*layers)\n",
" \n",
"\n",
" def forward(self, input):\n",
" return self.features(input)\n"
]
Expand Down Expand Up @@ -759,7 +761,137 @@
}
],
"source": [
"summary(D, (1, config['num_channel'], config['image_size'], config['image_size']), col_names=[\"input_size\", \"output_size\", \"kernel_size\", \"num_params\"])"
"summary(D, (1, config['num_channel'], config['image_size'], config['image_size']),\n",
" col_names=[\"input_size\", \"output_size\", \"kernel_size\", \"num_params\"])"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## DCGAN"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"class DCGAN(pl.LightningModule):\n",
" def __int__(\n",
" self,\n",
" batch_size=128,\n",
" image_size=64,\n",
" num_channel=3,\n",
" z_size=100,\n",
" generator_hidden_size=64,\n",
" gen_dims=(8, 4, 2, 1),\n",
" dis_dims=(1, 2, 4, 8),\n",
" discriminator_hidden_size=64,\n",
" lr=0.0002,\n",
" beta1=0.5,\n",
" beta2=0.999,\n",
" **kwargs\n",
" ):\n",
" super().__init()\n",
" self.save_hyperparameters()\n",
" self.G = Generator(\n",
" z_size,\n",
" generator_hidden_size,\n",
" num_channel,\n",
" gen_dims,\n",
" )\n",
" G.apply(weights_init)\n",
" self.D = Discriminator(\n",
" num_channel,\n",
" discriminator_hidden_size,\n",
" 1,\n",
" dis_dims,\n",
" )\n",
" D.apply(weights_init)\n",
"\n",
" self.validation_z = torch.empty(\n",
" 1,\n",
" z_size,\n",
" 1,\n",
" 1,\n",
" device=config['device']\n",
" ).uniform_(-1, 1)\n",
" def forward(self, z):\n",
" return self.G(z)\n",
"\n",
" def loss_function(self, y_hat, y):\n",
" return F.binary_cross_entropy(y_hat, y)\n",
"\n",
" def training_step(self, batch, batch_idx, optimizer_idx):\n",
" images, _ = batch\n",
"\n",
" # input noise\n",
" z = torch.empty(\n",
" self.hparams.batch_size,\n",
" self.hparams.z_size,\n",
" 1,\n",
" 1,\n",
" device=config['device']\n",
" ).uniform_(-1, 1)\n",
"\n",
"\n",
" # ground truth with label smoothing\n",
" real_labels = torch.ones(self.hparams.batch_size, 1, device=config['device']) * 0.9\n",
"\n",
" # train generator\n",
" if optimizer_idx == 0:\n",
" self.generated_images = self(z)\n",
"\n",
" # calculate loss for generator\n",
" g_loss = self.loss_function(self.D(self.generated_images), real_labels)\n",
" self.log('g_loss', g_loss, prog_bar=True)\n",
" return g_loss\n",
"\n",
" # train discriminator\n",
" elif optimizer_idx == 1:\n",
" real_loss = self.loss_function(self.D(images), real_labels)\n",
"\n",
" fake_labels = torch.zeros(self.hparams.batch_size, 1)\n",
" fake_loss = self.loss_function(self.D(self(z).detach()), fake_labels)\n",
"\n",
" # discriminator loss is the average loss of fake and real samples\n",
" d_loss = (real_loss + fake_loss)/2\n",
" self.log('d_loss', d_loss, prog_bar=True)\n",
" return d_loss\n",
"\n",
" def configure_optimizers(self):\n",
" g_opt = torch.optim.Adam(\n",
" self.G.parameters(),\n",
" lr=self.hparams.lr,\n",
" betas=(self.hparams.beta1, self.hparams.beta2)\n",
" )\n",
"\n",
" d_opt = torch.optim.Adam(\n",
" self.D.parameters(),\n",
" lr=self.hparams.lr,\n",
" betas=(self.hparams.beta1, self.hparams.beta2)\n",
" )\n",
"\n",
" return [g_opt, d_opt], []\n",
"\n",
" def on_validation_epoch_end(self):\n",
" # log sampled images\n",
" sample_imgs = self(self.validation_z)\n",
" grid = vutils.make_grid(sample_imgs)\n",
" self.logger.experiment.add_image(\"generated_images\", grid, self.current_epoch)\n",
"\n"
],
"metadata": {
"collapsed": false,
Expand Down

0 comments on commit dd4f76e

Please sign in to comment.