Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Vytautas Jancauskas committed Aug 8, 2024
1 parent 096dad6 commit 997e55e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/geo_veritas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .datamodule import MNISTDataModule
from .generator import Generator
from .discriminator import Discriminator
from .gan import GAN
13 changes: 12 additions & 1 deletion src/geo_veritas/__main__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
from geo_veritas import MNISTDataModule
from geo_veritas import GAN
import lightning as L

if __name__ == '__main__':
print("Hello, World!")
dm = MNISTDataModule()
model = GAN(*dm.dims)
trainer = L.Trainer(
accelerator="auto",
devices=1,
max_epochs=5,
)
trainer.fit(model, dm)
7 changes: 5 additions & 2 deletions src/geo_veritas/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from geo_veritas import Generator, Discriminator

BATCH_SIZE = 256 if torch.cuda.is_available() else 64

class GAN(L.LightningModule):
def __init__(
Expand Down Expand Up @@ -59,7 +62,7 @@ def training_step(self, batch):
# log sampled images
sample_imgs = self.generated_imgs[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, 0)
#self.logger.experiment.add_image("generated_images", grid, 0)

# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
Expand Down Expand Up @@ -113,4 +116,4 @@ def on_validation_epoch_end(self):
# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
#self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

0 comments on commit 997e55e

Please sign in to comment.