Skip to content

Commit

Permalink
fix, feat #70: small mistakes done in the prev. refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioQuijanoRey committed Apr 20, 2024
1 parent f7e89ec commit 618447a
Showing 1 changed file with 42 additions and 14 deletions.
56 changes: 42 additions & 14 deletions src/MNIST.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import datetime
import os
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Dict


@dataclass
Expand Down Expand Up @@ -144,6 +144,15 @@ def __init_wandb_params(self):
self.wandb_project_name = "MNIST dataset"
self.wandb_run_name = str(datetime.datetime.now())

def dict(self) -> Dict:
"""
Wandb need a dictionary representation of the class for saving values in
their system. So this method tries its best to create a dict repr for
this data class
"""

return self.__dict__


GLOBALS = GlobalParameters()

Expand All @@ -160,8 +169,8 @@ def __init_wandb_params(self):

# TODO -- remove this
sys.path.append(GLOBALS.base_path)
sys.path.append(GLOBALS.base_path, "src")
sys.path.append(GLOBALS.base_path, "src/lib")
sys.path.append(os.path.join(GLOBALS.base_path, "src"))
sys.path.append(os.path.join(GLOBALS.base_path, "src/lib"))


# Importing the modules we are going to use
Expand Down Expand Up @@ -192,9 +201,11 @@ def __init_wandb_params(self):
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets

# For using pre-trained ResNets
import torchvision.models as models
import torchvision.transforms as transforms

# All concrete pieces we're using form sklearn
from sklearn.metrics import accuracy_score, roc_auc_score, silhouette_score
from torch.utils.data import DataLoader, Dataset
Expand All @@ -216,18 +227,35 @@ def __init_wandb_params(self):
import wandb
from lib.data_augmentation import AugmentatedDataset, LazyAugmentatedDataset
from lib.embedding_to_classifier import EmbeddingToClassifier
from lib.loss_functions import (AddSmallEmbeddingPenalization,
BatchAllTripletLoss, BatchHardTripletLoss,
MeanTripletBatchTripletLoss)
from lib.loss_functions import (
AddSmallEmbeddingPenalization,
BatchAllTripletLoss,
BatchHardTripletLoss,
MeanTripletBatchTripletLoss,
)
from lib.models import *
from lib.models import (CACDResnet18, CACDResnet50, FGLigthModel,
LFWLightModel, LFWResNet18, NormalizedNet, ResNet18,
RetrievalAdapter)
from lib.models import (
CACDResnet18,
CACDResnet50,
FGLigthModel,
LFWLightModel,
LFWResNet18,
NormalizedNet,
ResNet18,
RetrievalAdapter,
)
from lib.sampler import CustomSampler
from lib.train_loggers import (CompoundLogger, InterClusterLogger,
IntraClusterLogger, LocalRankAtKLogger,
RankAtKLogger, SilentLogger, TrainLogger,
TripletLoggerOffline, TripletLoggerOnline)
from lib.train_loggers import (
CompoundLogger,
InterClusterLogger,
IntraClusterLogger,
LocalRankAtKLogger,
RankAtKLogger,
SilentLogger,
TrainLogger,
TripletLoggerOffline,
TripletLoggerOnline,
)
from lib.trainers import train_model_online
from lib.visualizations import *

Expand Down Expand Up @@ -272,7 +300,7 @@ def __init_wandb_params(self):
wandb.init(
project=GLOBALS.wandb_project_name,
name=GLOBALS.wandb_run_name,
config=str(GLOBALS),
config=GLOBALS.dict(),
)

# Functions that we are going to use
Expand Down

0 comments on commit 618447a

Please sign in to comment.