Commit 096dad6
Vytautas Jancauskas committed Aug 8, 2024 (1 parent: 81bf62e)
Showing 4 changed files with 248 additions and 0 deletions.
@@ -0,0 +1,65 @@
import os

import lightning as L
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)


class MNISTDataModule(L.LightningDataModule):
    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
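
For a quick sanity check, the data module can be exercised on its own. The snippet below is a hypothetical smoke test, not part of this commit; it assumes MNISTDataModule above is importable in the current session:

# Hypothetical smoke test; not part of the diff.
dm = MNISTDataModule(batch_size=32)
dm.prepare_data()          # downloads MNIST into PATH_DATASETS
dm.setup("fit")            # builds the 55000/5000 train/val split
xb, yb = next(iter(dm.train_dataloader()))
print(xb.shape, yb.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])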
@@ -0,0 +1,30 @@
import numpy as np
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        # MLP that maps a flattened image to a single real/fake probability.
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten (N, C, H, W) to (N, C*H*W) before the MLP.
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
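
A short shape check makes the discriminator's contract explicit. This is a hypothetical snippet, not part of the commit; it assumes Discriminator above is in scope:

# Hypothetical usage; any image shape works as long as img_shape matches.
import torch
d = Discriminator(img_shape=(1, 28, 28))
fake_batch = torch.randn(16, 1, 28, 28)
print(d(fake_batch).shape)  # torch.Size([16, 1]); values in (0, 1) due to the Sigmoid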
@@ -0,0 +1,116 @@
import lightning as L
import torch
import torch.nn.functional as F
import torchvision

# Generator and Discriminator are defined in the other files added by this
# commit; the diff does not show file names, so the exact import path is
# left to the reader.

# Mirrors the default in the data module; needed for the batch_size default below.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64


class GAN(L.LightningModule):
    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Two optimizers (G and D) are stepped by hand in training_step.
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed noise, so logged samples are comparable across epochs.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        imgs, _ = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # log sampled images
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, 0)

        # ground-truth labels for the generator step: all ones, because the
        # generator wants the discriminator to classify its fakes as real.
        # type_as puts the tensor on the right device, since it was created
        # inside training_step.
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label real images as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label generated images as fake? (detach so this
        # step does not backpropagate into the generator)
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_validation_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
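
Taken together, the four files form a complete Lightning GAN setup. The wiring sketch below is hypothetical and not part of the commit; it assumes GAN and MNISTDataModule are importable from the files above:

# Hypothetical training script for the modules added in this commit.
import lightning as L

dm = MNISTDataModule()
model = GAN(*dm.dims)  # dims == (1, 28, 28), unpacked as channels, width, height
trainer = L.Trainer(accelerator="auto", devices=1, max_epochs=5)
trainer.fit(model, dm)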
@@ -0,0 +1,37 @@
import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Note: 0.8 is BatchNorm1d's second positional argument, eps.
                # It is possible momentum=0.8 was intended, a common quirk in
                # GAN example code; the original value is kept here.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        # Reshape the flat output back to (N, C, H, W).
        img = img.view(img.size(0), *self.img_shape)
        return img
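
As with the discriminator, a short hypothetical shape check (not part of the commit; it assumes Generator above is in scope) documents the contract:

# Hypothetical usage; batch size > 1 is needed in train mode because of BatchNorm1d.
import torch
g = Generator(latent_dim=100, img_shape=(1, 28, 28))
z = torch.randn(4, 100)
print(g(z).shape)  # torch.Size([4, 1, 28, 28]); Tanh bounds values to (-1, 1)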