Skip to content

Commit

Permalink
initial structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Vytautas Jancauskas committed Aug 8, 2024
1 parent 81bf62e commit 096dad6
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/geo_veritas/datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)

class MNISTDataModule(L.LightningDataModule):
def __init__(
self,
data_dir: str = PATH_DATASETS,
batch_size: int = BATCH_SIZE,
num_workers: int = NUM_WORKERS,
):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers

self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)

self.dims = (1, 28, 28)
self.num_classes = 10

def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)

def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

def train_dataloader(self):
return DataLoader(
self.mnist_train,
batch_size=self.batch_size,
num_workers=self.num_workers,
)

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
30 changes: 30 additions & 0 deletions src/geo_veritas/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

class Discriminator(nn.Module):
def __init__(self, img_shape):
super().__init__()

self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)

def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)

return validity
116 changes: 116 additions & 0 deletions src/geo_veritas/gan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

class GAN(L.LightningModule):
def __init__(
self,
channels,
width,
height,
latent_dim: int = 100,
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
batch_size: int = BATCH_SIZE,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.automatic_optimization = False

# networks
data_shape = (channels, width, height)
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
self.discriminator = Discriminator(img_shape=data_shape)

self.validation_z = torch.randn(8, self.hparams.latent_dim)

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

def forward(self, z):
return self.generator(z)

def adversarial_loss(self, y_hat, y):
return F.binary_cross_entropy(y_hat, y)

def training_step(self, batch):
imgs, _ = batch

optimizer_g, optimizer_d = self.optimizers()

# sample noise
z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
z = z.type_as(imgs)

# train generator
# generate images
self.toggle_optimizer(optimizer_g)
self.generated_imgs = self(z)

# log sampled images
sample_imgs = self.generated_imgs[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, 0)

# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
self.log("g_loss", g_loss, prog_bar=True)
self.manual_backward(g_loss)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)

# train discriminator
# Measure discriminator's ability to classify real from generated samples
self.toggle_optimizer(optimizer_d)

# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(imgs)

fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
self.log("d_loss", d_loss, prog_bar=True)
self.manual_backward(d_loss)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)

def configure_optimizers(self):
lr = self.hparams.lr
b1 = self.hparams.b1
b2 = self.hparams.b2

opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
return [opt_g, opt_d], []

def on_validation_epoch_end(self):
z = self.validation_z.type_as(self.generator.model[0].weight)

# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
37 changes: 37 additions & 0 deletions src/geo_veritas/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

class Generator(nn.Module):
def __init__(self, latent_dim, img_shape):
super().__init__()
self.img_shape = img_shape

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh(),
)

def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_shape)
return img

0 comments on commit 096dad6

Please sign in to comment.