From dd4f76ef9027a70e237f5d07b4c3583c0be4c2c4 Mon Sep 17 00:00:00 2001
From: ALi
Date: Sat, 13 Aug 2022 19:25:18 +0200
Subject: [PATCH] Implement DCGAN as pytorch lightning module

---
 pokemon_dcgan_pytorch.ipynb | 164 ++++++++++++++++++++++++++++++++----
 1 file changed, 148 insertions(+), 16 deletions(-)

diff --git a/pokemon_dcgan_pytorch.ipynb b/pokemon_dcgan_pytorch.ipynb
index fdcd97a..a3da94e 100644
--- a/pokemon_dcgan_pytorch.ipynb
+++ b/pokemon_dcgan_pytorch.ipynb
@@ -44,7 +44,7 @@
    },
    "outputs": [],
    "source": [
-    "%%script false --no- raise -error\n",
+    "%%script false --no-raise-error\n",
     "!pip3 install jupyter==1.0.0\n",
     "!pip3 install numpy==1.23.1\n",
     "!pip3 install pandas==1.4.3\n",
@@ -87,9 +87,8 @@
     "import random\n",
     "import torch\n",
     "import torch.nn as nn\n",
-    "import torch.nn.parallel\n",
-    "import torch.backends.cudnn as cudnn\n",
-    "import torch.optim as optim\n",
+    "import torch.nn.functional as F\n",
+    "import pytorch_lightning as pl\n",
     "import torch.utils.data\n",
     "import torchvision.datasets as dset\n",
     "import torchvision.transforms as transforms\n",
@@ -155,9 +154,9 @@
     "    # Size of feature maps in generator\n",
     "    generator_hidden_size=64,\n",
     "    # Generator network dimensions\n",
-    "    gen_dims = [8, 4, 2, 1],\n",
+    "    gen_dims=[8, 4, 2, 1],\n",
     "    # Discriminator network dimensions\n",
-    "    dis_dims = [1, 2, 4, 8],\n",
+    "    dis_dims=[1, 2, 4, 8],\n",
     "    # Size of feature maps in discriminator\n",
     "    discriminator_hidden_size=64,\n",
     "    # Number of training epochs\n",
@@ -166,11 +165,12 @@
     "    lr=0.0002,\n",
     "    # Beta1 hyperparam for Adam optimizers\n",
     "    beta1=0.5,\n",
+    "    beta2=0.999,\n",
     "    # Number of GPUs available. Use 0 for CPU mode.\n",
     "    ngpu=1,\n",
     ")\n",
     "\n",
-    "device = torch.device(\"cuda:0\" if (torch.cuda.is_available() and config['ngpu'] > 0 ) else \"cpu\")\n"
+    "device = torch.device(\"cuda:0\" if (torch.cuda.is_available() and config['ngpu'] > 0) else \"cpu\")\n"
    ]
   },
   {
@@ -194,7 +194,7 @@
    },
    "outputs": [],
    "source": [
-    "%%script false --no- raise -error\n",
+    "%%script false --no-raise-error\n",
     "!git clone https://github.com/PokeAPI/sprites.git datasets/pokemon/\n",
     "!mkdir -p {config['dataset_path']}/0\n",
     "!mv {config['dataset_path']}/*.png {config['dataset_path']}/0/\n"
@@ -218,9 +218,11 @@
     "        color = img[:, :, :3]\n",
     "        new_img = cv2.bitwise_not(cv2.bitwise_not(color, mask=mask))\n",
     "        return new_img\n",
+    "\n",
+    "\n",
     "dataset = dset.ImageFolder(\n",
     "    root=config['dataset_path'],\n",
-    "    loader = transparent_white_bg_loader,\n",
+    "    loader=transparent_white_bg_loader,\n",
     "    transform=transforms.Compose([\n",
     "        transforms.ToPILImage(),\n",
     "        transforms.Resize(config['image_size']),\n",
@@ -233,7 +235,7 @@
     ")\n",
     "\n",
     "dataloader = torch.utils.data.DataLoader(\n",
-    "    dataset, \n",
+    "    dataset,\n",
     "    batch_size=config['batch_size'],\n",
     "    shuffle=True,\n",
     "    num_workers=config['num_workers']\n",
@@ -482,9 +484,9 @@
     "# custom weights initialization called on netG and netD\n",
     "def weights_init(m):\n",
     "    classname = m.__class__.__name__\n",
-    "    if classname.find('Conv')!=-1:\n",
+    "    if classname.find('Conv') != -1:\n",
     "        nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
-    "    elif classname.find('BatchNorm')!=-1:\n",
+    "    elif classname.find('BatchNorm') != -1:\n",
     "        nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
     "        nn.init.constant_(m.bias.data, 0)"
    ]
@@ -530,8 +532,8 @@
     "                    self.hidden_dim*layer,\n",
     "                    self.kernel_size,\n",
     "                    bias=False,\n",
-    "                    stride=1 if i==0 else 2,\n",
-    "                    padding=0 if i==0 else 1,\n",
+    "                    stride=1 if i == 0 else 2,\n",
+    "                    padding=0 if i == 0 else 1,\n",
     "                ),\n",
     "                nn.BatchNorm2d(self.hidden_dim*layer),\n",
     "                nn.LeakyReLU(negative_slope=negative_slope),\n",
@@ -549,7 +551,7 @@
     "            nn.Tanh()\n",
     "        ]\n",
     "        return nn.Sequential(*layers)\n",
-    "    \n",
+    "\n",
     "    def forward(self, input):\n",
     "        return self.features(input)\n"
    ]
@@ -759,7 +761,137 @@
     }
    ],
    "source": [
-    "summary(D, (1, config['num_channel'], config['image_size'], config['image_size']), col_names=[\"input_size\", \"output_size\", \"kernel_size\", \"num_params\"])"
+    "summary(D, (1, config['num_channel'], config['image_size'], config['image_size']),\n",
+    "        col_names=[\"input_size\", \"output_size\", \"kernel_size\", \"num_params\"])"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## DCGAN"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "class DCGAN(pl.LightningModule):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        batch_size=128,\n",
+    "        image_size=64,\n",
+    "        num_channel=3,\n",
+    "        z_size=100,\n",
+    "        generator_hidden_size=64,\n",
+    "        gen_dims=(8, 4, 2, 1),\n",
+    "        dis_dims=(1, 2, 4, 8),\n",
+    "        discriminator_hidden_size=64,\n",
+    "        lr=0.0002,\n",
+    "        beta1=0.5,\n",
+    "        beta2=0.999,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super().__init__()\n",
+    "        self.save_hyperparameters()\n",
+    "        self.G = Generator(\n",
+    "            z_size,\n",
+    "            generator_hidden_size,\n",
+    "            num_channel,\n",
+    "            gen_dims,\n",
+    "        )\n",
+    "        self.G.apply(weights_init)\n",
+    "        self.D = Discriminator(\n",
+    "            num_channel,\n",
+    "            discriminator_hidden_size,\n",
+    "            1,\n",
+    "            dis_dims,\n",
+    "        )\n",
+    "        self.D.apply(weights_init)\n",
+    "\n",
+    "        self.validation_z = torch.empty(\n",
+    "            1,\n",
+    "            z_size,\n",
+    "            1,\n",
+    "            1,\n",
+    "        ).uniform_(-1, 1)\n",
+    "\n",
+    "    def forward(self, z):\n",
+    "        return self.G(z)\n",
+    "\n",
+    "    def loss_function(self, y_hat, y):\n",
+    "        return F.binary_cross_entropy(y_hat, y)\n",
+    "\n",
+    "    def training_step(self, batch, batch_idx, optimizer_idx):\n",
+    "        images, _ = batch\n",
+    "\n",
+    "        # input noise\n",
+    "        z = torch.empty(\n",
+    "            images.size(0),\n",
+    "            self.hparams.z_size,\n",
+    "            1,\n",
+    "            1,\n",
+    "            device=self.device\n",
+    "        ).uniform_(-1, 1)\n",
+    "\n",
+    "\n",
+    "        # ground truth with label smoothing\n",
+    "        real_labels = torch.ones(images.size(0), 1, device=self.device) * 0.9\n",
+    "\n",
+    "        # train generator\n",
+    "        if optimizer_idx == 0:\n",
+    "            self.generated_images = self(z)\n",
+    "\n",
+    "            # calculate loss for generator\n",
+    "            g_loss = self.loss_function(self.D(self.generated_images), real_labels)\n",
+    "            self.log('g_loss', g_loss, prog_bar=True)\n",
+    "            return g_loss\n",
+    "\n",
+    "        # train discriminator\n",
+    "        elif optimizer_idx == 1:\n",
+    "            real_loss = self.loss_function(self.D(images), real_labels)\n",
+    "\n",
+    "            fake_labels = torch.zeros(images.size(0), 1, device=self.device)\n",
+    "            fake_loss = self.loss_function(self.D(self(z).detach()), fake_labels)\n",
+    "\n",
+    "            # discriminator loss is the average loss of fake and real samples\n",
+    "            d_loss = (real_loss + fake_loss) / 2\n",
+    "            self.log('d_loss', d_loss, prog_bar=True)\n",
+    "            return d_loss\n",
+    "\n",
+    "    def configure_optimizers(self):\n",
+    "        g_opt = torch.optim.Adam(\n",
+    "            self.G.parameters(),\n",
+    "            lr=self.hparams.lr,\n",
+    "            betas=(self.hparams.beta1, self.hparams.beta2)\n",
+    "        )\n",
+    "\n",
+    "        d_opt = torch.optim.Adam(\n",
+    "            self.D.parameters(),\n",
+    "            lr=self.hparams.lr,\n",
+    "            betas=(self.hparams.beta1, self.hparams.beta2)\n",
+    "        )\n",
+    "\n",
+    "        return [g_opt, d_opt], []\n",
+    "\n",
+    "    def on_validation_epoch_end(self):\n",
+    "        # log sampled images\n",
+    "        sample_imgs = self(self.validation_z.to(self.device))\n",
+    "        grid = vutils.make_grid(sample_imgs)\n",
+    "        self.logger.experiment.add_image(\"generated_images\", grid, self.current_epoch)\n",
+    "\n"
    ],
    "metadata": {
     "collapsed": false,
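
The hunks above add the LightningModule but never construct a `Trainer`, so the training entry point sits outside the visible diff. What follows is a minimal sketch, not part of the patch, of how the module could be driven. It assumes the notebook's `config` dict and `dataloader` from the earlier cells, and PyTorch Lightning 1.x, where `training_step` still receives `optimizer_idx` under automatic optimization with multiple optimizers.

```python
# Sketch only -- this call does not appear in the patch. Assumes the
# notebook's `config` dict and `dataloader`, and PyTorch Lightning 1.x.
import torch
import pytorch_lightning as pl

# DCGAN.__init__ accepts **kwargs, so unrelated config keys
# (dataset_path, num_workers, ngpu, ...) are absorbed harmlessly.
model = DCGAN(**config)

trainer = pl.Trainer(
    max_epochs=5,  # placeholder; the epochs key is not visible in the diff
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)

# The two optimizers returned by configure_optimizers() make Lightning
# invoke training_step twice per batch: optimizer_idx 0 trains the
# generator, optimizer_idx 1 the discriminator.
trainer.fit(model, dataloader)
```

Note that `on_validation_epoch_end` only fires during a validation loop; without a validation dataloader the image logging never runs, so one would either pass `val_dataloaders` to `fit` or move the logging to `on_train_epoch_end`.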