feat #70: computing and displaying metrics after model training
SergioQuijanoRey committed Apr 21, 2024
1 parent 32ce7d9 commit 4ebeb77
Showing 1 changed file with 105 additions and 109 deletions.
214 changes: 105 additions & 109 deletions src/MNIST.py
@@ -17,7 +17,7 @@
import torchvision.transforms as transforms

import wandb
from lib import utils
from lib import embedding_to_classifier, metrics, utils


@dataclass
@@ -527,137 +527,133 @@ def extract_embeddings(dataloader, model):
# net.eval()


# TODO -- ADAM -- Run our model evaluation
# # Model evaluation
# # ==============================================================================


# # Use the network to perform a retrieval task and compute rank@1 and rank@5 accuracy
# with torch.no_grad():
# net.set_permute(False)

# train_rank_at_one = metrics.rank_accuracy(
# k=1,
# data_loader=train_loader_augmented,
# network=net,
# max_examples=len(train_loader_augmented),
# fast_implementation=False,
# )
# test_rank_at_one = metrics.rank_accuracy(
# k=1,
# data_loader=test_loader,
# network=net,
# max_examples=len(test_loader),
# fast_implementation=False,
# )
# train_rank_at_five = metrics.rank_accuracy(
# k=5,
# data_loader=train_loader_augmented,
# network=net,
# max_examples=len(train_loader_augmented),
# fast_implementation=False,
# )
# test_rank_at_five = metrics.rank_accuracy(
# k=5,
# data_loader=test_loader,
# network=net,
# max_examples=len(test_loader),
# fast_implementation=False,
# )
# Model evaluation
# ==============================================================================

# print(f"Train Rank@1 Accuracy: {train_rank_at_one}")
# print(f"Test Rank@1 Accuracy: {test_rank_at_one}")
# print(f"Train Rank@5 Accuracy: {train_rank_at_five}")
# print(f"Test Rank@5 Accuracy: {test_rank_at_five}")

# # Put this info in wandb
# wandb.log(
# {
# "Final Train Rank@1 Accuracy": train_rank_at_one,
# "Final Test Rank@1 Accuracy": test_rank_at_one,
# "Final Train Rank@5 Accuracy": train_rank_at_five,
# "Final Test Rank@5 Accuracy": test_rank_at_five,
# }
# )
# Use the network to perform a retrieval task and compute rank@1 and rank@5 accuracy
with torch.no_grad():
    net.set_permute(False)

    train_rank_at_one = metrics.rank_accuracy(
        k=1,
        data_loader=online_train_loader,
        network=net,
        max_examples=len(online_train_loader),
        fast_implementation=False,
    )
    test_rank_at_one = metrics.rank_accuracy(
        k=1,
        data_loader=online_test_loader,
        network=net,
        max_examples=len(online_test_loader),
        fast_implementation=False,
    )
    train_rank_at_five = metrics.rank_accuracy(
        k=5,
        data_loader=online_train_loader,
        network=net,
        max_examples=len(online_train_loader),
        fast_implementation=False,
    )
    test_rank_at_five = metrics.rank_accuracy(
        k=5,
        data_loader=online_test_loader,
        network=net,
        max_examples=len(online_test_loader),
        fast_implementation=False,
    )

    print("=> 📈 Final Metrics")
    print(f"Train Rank@1 Accuracy: {train_rank_at_one}")
    print(f"Test Rank@1 Accuracy: {test_rank_at_one}")
    print(f"Train Rank@5 Accuracy: {train_rank_at_five}")
    print(f"Test Rank@5 Accuracy: {test_rank_at_five}")
    print("")

# net.set_permute(True)
    # Put this info in wandb
    wandb.log(
        {
            "Final Train Rank@1 Accuracy": train_rank_at_one,
            "Final Test Rank@1 Accuracy": test_rank_at_one,
            "Final Train Rank@5 Accuracy": train_rank_at_five,
            "Final Test Rank@5 Accuracy": test_rank_at_five,
        }
    )

    net.set_permute(True)
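
# For reference, a minimal sketch of how a rank@k metric over the embedding
# space could be computed. This is an illustrative assumption, not the actual
# implementation behind `metrics.rank_accuracy`: it takes pre-computed
# embeddings, then checks whether any of the k nearest neighbours of each
# query (excluding the query itself) shares its label.
def sketch_rank_at_k(embeddings: torch.Tensor, labels: torch.Tensor, k: int) -> float:
    # Pairwise euclidean distances between all embeddings
    distances = torch.cdist(embeddings, embeddings)

    # Ignore each query's distance to itself
    distances.fill_diagonal_(float("inf"))

    # Indices of the k nearest neighbours of every query
    _, neighbour_indices = distances.topk(k, largest=False)

    # A query counts as a hit if any of its k neighbours shares its label
    hits = (labels[neighbour_indices] == labels.unsqueeze(1)).any(dim=1)
    return hits.float().mean().item()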

# # Compute the *silhouette* metric for the produced embedding, on
# # train, validation and test sets:
# with torch.no_grad():
# net.set_permute(False)

# # Try to free memory: we can easily run out of it, which would crash
# # the notebook and lose all in-memory objects
# try_to_clean_memory()
# Compute the *silhouette* metric for the produced embedding, on the
# train and test sets:
with torch.no_grad():
    net.set_permute(False)

# train_silh = metrics.silhouette(train_loader_augmented, net)
# print(f"Silhouette in training loader: {train_silh}")
    # Try to free memory: we can easily run out of it, which would crash
    # the notebook and lose all in-memory objects
    try_to_clean_memory()

# validation_silh = metrics.silhouette(validation_loader_augmented, net)
# print(f"Silhouette in validation loader: {validation_silh}")
print("=> 📈 Silhouette metrics")
train_silh = metrics.silhouette(online_train_loader, net)
print(f"Silhouette in training loader: {train_silh}")

# test_silh = metrics.silhouette(test_loader, net)
# print(f"Silhouette in test loader: {test_silh}")
    test_silh = metrics.silhouette(online_test_loader, net)
    print(f"Silhouette in test loader: {test_silh}")
    print("")

# # Put this info in wandb
# wandb.log(
# {
# "Final Training silh": train_silh,
# "Final Validation silh": validation_silh,
# "Final Test silh": test_silh,
# }
# )
    # Put this info in wandb
    wandb.log(
        {
            "Final Training silh": train_silh,
            "Final Test silh": test_silh,
        }
    )

# net.set_permute(True)
    net.set_permute(True)
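
# For reference, a hedged sketch of what a silhouette computation over the
# embedding could look like: extract embeddings with the network, then score
# them with scikit-learn. The helper below is an illustrative assumption
# (including the batch format of the loader), not the actual code behind
# `metrics.silhouette`.
def sketch_silhouette(data_loader, network) -> float:
    from sklearn.metrics import silhouette_score

    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in data_loader:
            all_embeddings.append(network(imgs))
            all_labels.append(labels)

    embeddings = torch.cat(all_embeddings).cpu().numpy()
    labels = torch.cat(all_labels).cpu().numpy()

    # Silhouette lies in [-1, 1]: higher values mean the classes form
    # tighter, better-separated clusters in the embedding space
    return silhouette_score(embeddings, labels)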


# # Show the "criterion" metric on test set
# with torch.no_grad():
# net.set_permute(False)
# Now build a classifier from the embedding and use it to compute some classification metrics:
with torch.no_grad():
    # Try to free memory: we can easily run out of it, which would crash
    # the notebook and lose all in-memory objects
    try_to_clean_memory()

# core.test_model_online(net, test_loader, parameters["criterion"], online=True)
    # With hopefully enough memory, try to convert the embedding into a classifier
    number_neighbours = 3
    classifier = embedding_to_classifier.EmbeddingToClassifier(
        net,
        k=number_neighbours,
        data_loader=online_train_loader,
        embedding_dimension=2,
    )

# net.set_permute(True)
# See how it works on a small test set
with torch.no_grad():
    net.set_permute(False)

    # Evaluate at most `max_iterations` classifications
    counter = 0
    max_iterations = len(test_dataset)

# # Now build a classifier from the embedding and use it to compute some classification metrics:
# with torch.no_grad():
# # Try to free memory: we can easily run out of it, which would crash
# # the notebook and lose all in-memory objects
# try_to_clean_memory()

# # With hopefully enough memory, try to convert the embedding into a classifier
# classifier = EmbeddingToClassifier(
# net,
# k=GLOBALS["NUMBER_NEIGHBOURS"],
# data_loader=train_loader_augmented,
# embedding_dimension=GLOBALS["EMBEDDING_DIMENSION"],
# )
    correct = 0

# # See how it works on a small test set
# with torch.no_grad():
# net.set_permute(False)
    for img, img_class in test_dataset:
        predicted_class = classifier.predict(img)

# # Show only `max_iterations` classifications
# counter = 0
# max_iterations = 20
        if img_class == predicted_class[0]:
            correct += 1

# for img, img_class in test_dataset:
# predicted_class = classifier.predict(img)
# print(
# f"True label: {img_class}, predicted label: {predicted_class[0]}, correct: {img_class == predicted_class[0]}"
# )
        counter += 1
        if counter == max_iterations:
            break

# counter += 1
# if counter == max_iterations:
# break
    accuracy = correct / max_iterations
    print(f"=> 📈 Metrics on {max_iterations} test images")
    print(f"Accuracy: {(accuracy * 100):.3f}%")

# net.set_permute(True)
    net.set_permute(True)
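
# For reference, the k-NN classification that `EmbeddingToClassifier` performs
# can be sketched directly with scikit-learn: fit a KNeighborsClassifier on the
# training embeddings, then predict labels for the test embeddings. This is an
# illustrative assumption about the approach, not the class's actual internals.
def sketch_knn_on_embeddings(train_embeddings, train_labels, test_embeddings, k: int = 3):
    from sklearn.neighbors import KNeighborsClassifier

    # Fit the neighbour index on the training embedding space
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(train_embeddings, train_labels)

    # Each test embedding gets the majority label of its k nearest neighbours
    return knn.predict(test_embeddings)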


# TODO -- ADAM -- run our plot of the embedding
# # Plot of the embedding
# # ==============================================================================
# #
