From 618447a7e8ded49e8e793597a739658219922d3a Mon Sep 17 00:00:00 2001 From: Sergio Quijano Date: Sat, 20 Apr 2024 20:47:31 +0200 Subject: [PATCH] fix, feat #70: small mistakes done in the prev. refactor --- src/MNIST.py | 56 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/src/MNIST.py b/src/MNIST.py index 303c6fc..5abdd21 100644 --- a/src/MNIST.py +++ b/src/MNIST.py @@ -7,7 +7,7 @@ import datetime import os from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict @dataclass @@ -144,6 +144,15 @@ def __init_wandb_params(self): self.wandb_project_name = "MNIST dataset" self.wandb_run_name = str(datetime.datetime.now()) + def dict(self) -> Dict: + """ + Wandb need a dictionary representation of the class for saving values in + their system. So this method tries its best to create a dict repr for + this data class + """ + + return self.__dict__ + GLOBALS = GlobalParameters() @@ -160,8 +169,8 @@ def __init_wandb_params(self): # TODO -- remove this sys.path.append(GLOBALS.base_path) -sys.path.append(GLOBALS.base_path, "src") -sys.path.append(GLOBALS.base_path, "src/lib") +sys.path.append(os.path.join(GLOBALS.base_path, "src")) +sys.path.append(os.path.join(GLOBALS.base_path, "src/lib")) # Importing the modules we are going to use @@ -192,9 +201,11 @@ def __init_wandb_params(self): import torch.optim as optim import torchvision import torchvision.datasets as datasets + # For using pre-trained ResNets import torchvision.models as models import torchvision.transforms as transforms + # All concrete pieces we're using form sklearn from sklearn.metrics import accuracy_score, roc_auc_score, silhouette_score from torch.utils.data import DataLoader, Dataset @@ -216,18 +227,35 @@ def __init_wandb_params(self): import wandb from lib.data_augmentation import AugmentatedDataset, LazyAugmentatedDataset from lib.embedding_to_classifier import EmbeddingToClassifier -from lib.loss_functions import (AddSmallEmbeddingPenalization, - BatchAllTripletLoss, BatchHardTripletLoss, - MeanTripletBatchTripletLoss) +from lib.loss_functions import ( + AddSmallEmbeddingPenalization, + BatchAllTripletLoss, + BatchHardTripletLoss, + MeanTripletBatchTripletLoss, +) from lib.models import * -from lib.models import (CACDResnet18, CACDResnet50, FGLigthModel, - LFWLightModel, LFWResNet18, NormalizedNet, ResNet18, - RetrievalAdapter) +from lib.models import ( + CACDResnet18, + CACDResnet50, + FGLigthModel, + LFWLightModel, + LFWResNet18, + NormalizedNet, + ResNet18, + RetrievalAdapter, +) from lib.sampler import CustomSampler -from lib.train_loggers import (CompoundLogger, InterClusterLogger, - IntraClusterLogger, LocalRankAtKLogger, - RankAtKLogger, SilentLogger, TrainLogger, - TripletLoggerOffline, TripletLoggerOnline) +from lib.train_loggers import ( + CompoundLogger, + InterClusterLogger, + IntraClusterLogger, + LocalRankAtKLogger, + RankAtKLogger, + SilentLogger, + TrainLogger, + TripletLoggerOffline, + TripletLoggerOnline, +) from lib.trainers import train_model_online from lib.visualizations import * @@ -272,7 +300,7 @@ def __init_wandb_params(self): wandb.init( project=GLOBALS.wandb_project_name, name=GLOBALS.wandb_run_name, - config=str(GLOBALS), + config=GLOBALS.dict(), ) # Functions that we are going to use