Skip to content

Commit

Permalink
feat #70: saving the plots in disk so we can see them
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioQuijanoRey committed Apr 20, 2024
1 parent 618447a commit c973080
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions src/MNIST.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import datetime
import os
from dataclasses import dataclass
from typing import Optional, Dict
from typing import Dict, Optional


@dataclass
Expand Down Expand Up @@ -50,6 +50,9 @@ def __init_path_params(self):
# Path where we store training / test data
self.data_path = os.path.join(self.base_path, "data")

# Path where we can store figures
self.plots_path = os.path.join(self.base_path, "plots")

# Dir with all cached models
# These cached models can be loaded from disk when training is skipped
self.model_cache_folder = os.path.join(self.base_path, "cached_models")
Expand Down Expand Up @@ -201,11 +204,9 @@ def dict(self) -> Dict:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets

# For using pre-trained ResNets
import torchvision.models as models
import torchvision.transforms as transforms

# All concrete pieces we're using from sklearn
from sklearn.metrics import accuracy_score, roc_auc_score, silhouette_score
from torch.utils.data import DataLoader, Dataset
Expand All @@ -227,35 +228,18 @@ def dict(self) -> Dict:
import wandb
from lib.data_augmentation import AugmentatedDataset, LazyAugmentatedDataset
from lib.embedding_to_classifier import EmbeddingToClassifier
from lib.loss_functions import (
AddSmallEmbeddingPenalization,
BatchAllTripletLoss,
BatchHardTripletLoss,
MeanTripletBatchTripletLoss,
)
from lib.loss_functions import (AddSmallEmbeddingPenalization,
BatchAllTripletLoss, BatchHardTripletLoss,
MeanTripletBatchTripletLoss)
from lib.models import *
from lib.models import (
CACDResnet18,
CACDResnet50,
FGLigthModel,
LFWLightModel,
LFWResNet18,
NormalizedNet,
ResNet18,
RetrievalAdapter,
)
from lib.models import (CACDResnet18, CACDResnet50, FGLigthModel,
LFWLightModel, LFWResNet18, NormalizedNet, ResNet18,
RetrievalAdapter)
from lib.sampler import CustomSampler
from lib.train_loggers import (
CompoundLogger,
InterClusterLogger,
IntraClusterLogger,
LocalRankAtKLogger,
RankAtKLogger,
SilentLogger,
TrainLogger,
TripletLoggerOffline,
TripletLoggerOnline,
)
from lib.train_loggers import (CompoundLogger, InterClusterLogger,
IntraClusterLogger, LocalRankAtKLogger,
RankAtKLogger, SilentLogger, TrainLogger,
TripletLoggerOffline, TripletLoggerOnline)
from lib.trainers import train_model_online
from lib.visualizations import *

Expand Down Expand Up @@ -316,20 +300,20 @@ def try_to_clean_memory():
# ==============================================================================


# TODO -- values from ADAM's script
mean, std = 0.1307, 0.3081

# TODO -- ADAM -- those base paths must change
print("=> Downloading the MNIST dataset")
train_dataset = torchvision.datasets.MNIST(
"../data/MNIST",
GLOBALS.data_path,
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((mean,), (std,))]
),
)
test_dataset = torchvision.datasets.MNIST(
"../data/MNIST",
GLOBALS.data_path,
train=False,
download=True,
transform=transforms.Compose(
Expand Down Expand Up @@ -430,7 +414,7 @@ def try_to_clean_memory():
]


def plot_embeddings(embeddings, targets, xlim=None, ylim=None):
def plot_embeddings(embeddings, targets, title: str, xlim=None, ylim=None):
plt.figure(figsize=(10, 10))
for i in range(10):
inds = np.where(targets == i)[0]
Expand All @@ -442,6 +426,11 @@ def plot_embeddings(embeddings, targets, xlim=None, ylim=None):
if ylim:
plt.ylim(ylim[0], ylim[1])
plt.legend(mnist_classes)
try:
plt.savefig(os.path.join(GLOBALS.plots_path, title))
except Exception as e:
print("Could not save figure in disk")
print(f"Reason was: {e=}")


def extract_embeddings(dataloader, model):
Expand All @@ -462,9 +451,9 @@ def extract_embeddings(dataloader, model):


train_embeddings_otl, train_labels_otl = extract_embeddings(online_train_loader, net)
plot_embeddings(train_embeddings_otl, train_labels_otl)
plot_embeddings(train_embeddings_otl, train_labels_otl, title="Train embeddings")
val_embeddings_otl, val_labels_otl = extract_embeddings(online_test_loader, net)
plot_embeddings(val_embeddings_otl, val_labels_otl)
plot_embeddings(val_embeddings_otl, val_labels_otl, title="Validation embeddings")

# TODO -- ADAM -- use our loggers in the training
# ## Defining the loggers we want to use
Expand Down

0 comments on commit c973080

Please sign in to comment.