-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathmain.py
86 lines (51 loc) · 2.76 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Zero-shot CLIP evaluation script.

Compares two classifiers on the same dataset:
  * standard CLIP:  image vs. label-prompt encodings, and
  * description-based: image vs. per-class description encodings,
    aggregated per class via `aggregate_similarity`.
Top-1 and top-5 accuracy for both are accumulated with torchmetrics
and printed at the end.
"""
from load import *  # provides hparams, dataset, DataLoader, clip, F, torch, n_classes, helpers
import torchmetrics
from tqdm import tqdm

seed_everything(hparams['seed'])

bs = hparams['batch_size']
# num_workers is hard-coded; tune to the host if data loading is the bottleneck.
dataloader = DataLoader(dataset, bs, shuffle=True, num_workers=16, pin_memory=True)

print("Loading model...")

device = torch.device(hparams['device'])
# load model (jit=False so the model is a regular nn.Module)
model, preprocess = clip.load(hparams['model_size'], device=device, jit=False)
model.eval()
model.requires_grad_(False)

print("Encoding descriptions...")

description_encodings = compute_description_encodings(model)
label_encodings = compute_label_encodings(model)

print("Evaluating...")
# NOTE(review): torchmetrics >= 0.11 requires Accuracy(task="multiclass",
# num_classes=...); this bare call form assumes an older torchmetrics — confirm
# the pinned version before upgrading.
lang_accuracy_metric = torchmetrics.Accuracy().to(device)
lang_accuracy_metric_top5 = torchmetrics.Accuracy(top_k=5).to(device)

clip_accuracy_metric = torchmetrics.Accuracy().to(device)
clip_accuracy_metric_top5 = torchmetrics.Accuracy(top_k=5).to(device)

# Inference only: no_grad avoids autograd bookkeeping (params are already frozen).
with torch.no_grad():
    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)

        image_encodings = model.encode_image(images)
        image_encodings = F.normalize(image_encodings)

        # Standard CLIP: similarity of each image to each label-prompt encoding.
        # Calling a metric updates its running state; the per-batch return value
        # is not needed here.
        image_labels_similarity = image_encodings @ label_encodings.T
        clip_accuracy_metric(image_labels_similarity, labels)
        clip_accuracy_metric_top5(image_labels_similarity, labels)

        # Description-based: one aggregated similarity score per class.
        # Dict iteration order is assumed to match the label index order.
        # (You can also vectorize this; it wasn't much faster for me.)
        image_description_similarity_cumulative = [
            aggregate_similarity(image_encodings @ v.T)
            for v in description_encodings.values()
        ]

        # stack per-class scores into a (batch, n_classes) tensor of similarity means
        cumulative_tensor = torch.stack(image_description_similarity_cumulative, dim=1)
        lang_accuracy_metric(cumulative_tensor.softmax(dim=-1), labels)
        lang_accuracy_metric_top5(cumulative_tensor.softmax(dim=-1), labels)

print("\n")

accuracy_logs = {}
accuracy_logs["Total Description-based Top-1 Accuracy: "] = 100*lang_accuracy_metric.compute().item()
accuracy_logs["Total Description-based Top-5 Accuracy: "] = 100*lang_accuracy_metric_top5.compute().item()

accuracy_logs["Total CLIP-Standard Top-1 Accuracy: "] = 100*clip_accuracy_metric.compute().item()
accuracy_logs["Total CLIP-Standard Top-5 Accuracy: "] = 100*clip_accuracy_metric_top5.compute().item()

# print the dictionary
print("\n")
for key, value in accuracy_logs.items():
    print(key, value)