feat #70: plotting the embedding that we've learned
SergioQuijanoRey committed Apr 21, 2024
1 parent 4ebeb77 commit 21b29b8
Showing 1 changed file with 10 additions and 9 deletions.

src/MNIST.py: 10 additions, 9 deletions
@@ -648,16 +648,17 @@ def extract_embeddings(dataloader, model):

 accuracy = correct / max_iterations
 print(f"=> 📈 Metrics on {max_iterations} test images")
-print(f"Accuracy: {(accuracy * 100):.3f}% ")
+print(f"Accuracy: {(accuracy * 100):.3f}%")
 print("")

 net.set_permute(True)


-# TODO -- ADAM -- run our plot of the embedding
-# # Plot of the embedding
-# # ==============================================================================
-# #
-# # - If the dimension of the embedding is 2, then we can plot how the transformation to a classificator works:
-# # - That logic is encoded in the `scatter_plot` method
-# with torch.no_grad():
-#     classifier.scatter_plot()
+# Plot of the embedding
+# ==============================================================================
+#
+# - If the dimension of the embedding is 2, then we can plot how the transformation to a classifier works
+# - That logic is encoded in the `scatter_plot` method
+with torch.no_grad():
+    print("=> Plotting the embedding that we've learned")
+    classifier.scatter_plot(os.path.join(GLOBALS.plots_path, "embedding.png"))
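
The diff only shows the call site; the `scatter_plot` implementation itself lives elsewhere in the repository. For context, here is a minimal sketch of what such a plotting routine could look like, written as a free function. The `embedder` and `dataloader` parameters and the coloring-by-digit logic are illustrative assumptions, not the repository's actual API:

import matplotlib

matplotlib.use("Agg")  # render to a file; no display required
import matplotlib.pyplot as plt
import torch


def scatter_plot(embedder: torch.nn.Module, dataloader, output_path: str) -> None:
    """Hypothetical sketch: plot a learned 2D embedding, one color per digit.

    Assumes `embedder` maps image batches to 2D vectors; in the commit above,
    this logic is a `scatter_plot` method on the classifier instead.
    """
    points, labels = [], []

    with torch.no_grad():
        for images, targets in dataloader:
            # Embed every batch into the 2D space we want to visualize
            points.append(embedder(images))
            labels.append(targets)

    points = torch.cat(points).cpu().numpy()
    labels = torch.cat(labels).cpu().numpy()

    # One color per class (MNIST digits 0-9)
    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(points[:, 0], points[:, 1], c=labels, cmap="tab10", s=5)
    ax.legend(*scatter.legend_elements(), title="Digit")
    ax.set_title("Learned 2D embedding")

    fig.savefig(output_path)
    plt.close(fig)

Under this sketch, the call in the commit would correspond to something like scatter_plot(net, test_loader, os.path.join(GLOBALS.plots_path, "embedding.png")), where `net` and `test_loader` are likewise assumed names.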
