forked from ryanchankh/mcr2
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate.py
89 lines (72 loc) · 3.56 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import glob
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluate import svm
from loss import MaximalCodingRateReduction
import train_func as tf
import utils
def gen_testloss(args):
    """Compute the MCR^2 test loss for every saved checkpoint.

    Every ``*.pt`` file in ``<args.model_dir>/checkpoints`` is treated as a
    checkpoint. For each one, the network is loaded (``eval_=True``) and run
    over the test split batch by batch. The loss terms for each batch are
    written through ``utils.save_state`` into ``losses_test.csv`` inside
    ``args.model_dir``.

    Args:
        args: parsed command-line namespace; only ``args.model_dir`` is read.
    """
    # load run parameters and the list of checkpoint epochs
    params = utils.load_params(args.model_dir)
    ckpt_dir = os.path.join(args.model_dir, 'checkpoints')
    # The epoch is parsed as the text between an 11-character prefix and the
    # ".pt" suffix (presumably 'model-epoch<N>.pt' -- confirm against the code
    # that saves checkpoints).
    ckpt_epochs = np.sort(
        [int(e[11:-3]) for e in os.listdir(ckpt_dir) if e.endswith(".pt")])

    # csv
    headers = ["epoch", "step", "loss", "discrimn_loss_e", "compress_loss_e",
               "discrimn_loss_t", "compress_loss_t"]
    csv_path = utils.create_csv(args.model_dir, 'losses_test.csv', headers)
    print('writing to:', csv_path)

    # load data (independent of the checkpoint, so built once)
    test_transforms = tf.load_transforms('test')
    testset = tf.load_trainset(params['data'], test_transforms, train=False)
    testloader = DataLoader(testset, batch_size=params['bs'], shuffle=False, num_workers=4)
    num_classes = len(testset.classes)

    # save loss
    criterion = MaximalCodingRateReduction(gam1=params['gam1'], gam2=params['gam2'], eps=params['eps'])
    for ckpt_epoch in ckpt_epochs:
        # Load by the epoch parsed from the filename, not by its position in
        # the sorted list. The two differ whenever checkpoints are not saved
        # for every consecutive epoch starting at 0.
        net, epoch = tf.load_checkpoint(args.model_dir, epoch=int(ckpt_epoch), eval_=True)
        # Pure evaluation: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            for step, (batch_imgs, batch_lbls) in enumerate(testloader):
                features = net(batch_imgs.cuda())
                loss, loss_empi, loss_theo = criterion(features, batch_lbls,
                                                       num_classes=num_classes)
                utils.save_state(args.model_dir, epoch, step, loss.item(),
                                 *loss_empi, *loss_theo, filename='losses_test.csv')
    print("Finished generating test loss.")
def gen_accuracy(args):
    """Compute SVM train/test accuracy for every 5th checkpoint epoch.

    Every ``*.pt`` file in ``<args.model_dir>/checkpoints`` is treated as a
    checkpoint. For each checkpoint whose epoch is a multiple of 5:
    - the network is loaded (``eval_=True``);
    - features are extracted for the train and test splits with
      ``tf.get_features``;
    - ``evaluate.svm`` scores those features.

    The resulting accuracies are written through ``utils.save_state`` into
    ``accuracy.csv`` inside ``args.model_dir``.

    Args:
        args: parsed command-line namespace; ``args.model_dir`` is read here
            and ``args`` is also forwarded to ``svm``.
    """
    # load run parameters and the list of checkpoint epochs
    params = utils.load_params(args.model_dir)
    ckpt_dir = os.path.join(args.model_dir, 'checkpoints')
    # The epoch is parsed as the text between an 11-character prefix and the
    # ".pt" suffix (presumably 'model-epoch<N>.pt' -- confirm against the code
    # that saves checkpoints).
    ckpt_epochs = np.sort(
        [int(e[11:-3]) for e in os.listdir(ckpt_dir) if e.endswith(".pt")])

    # csv
    headers = ["epoch", "acc_train", "acc_test"]
    csv_path = utils.create_csv(args.model_dir, 'accuracy.csv', headers)
    print('writing to:', csv_path)

    # load data once: the datasets and loaders do not depend on the checkpoint
    train_transforms = tf.load_transforms('test')
    trainset = tf.load_trainset(params['data'], train_transforms, train=True)
    trainloader = DataLoader(trainset, batch_size=500, num_workers=4)
    test_transforms = tf.load_transforms('test')
    testset = tf.load_trainset(params['data'], test_transforms, train=False)
    testloader = DataLoader(testset, batch_size=500, num_workers=4)

    for ckpt_epoch in ckpt_epochs:
        # Only every 5th epoch is evaluated; SVM fitting is comparatively slow.
        if ckpt_epoch % 5 != 0:
            continue
        # Load by the epoch parsed from the filename rather than by the
        # position in the sorted list.
        net, epoch = tf.load_checkpoint(args.model_dir, epoch=int(ckpt_epoch), eval_=True)
        train_features, train_labels = tf.get_features(net, trainloader, verbose=False)
        test_features, test_labels = tf.get_features(net, testloader, verbose=False)
        acc_train, acc_test = svm(args, train_features, train_labels, test_features, test_labels)
        utils.save_state(args.model_dir, epoch, acc_train, acc_test, filename='accuracy.csv')
    print("Finished generating accuracy.")
if __name__ == "__main__":
    # Command-line entry point: choose which CSV files to generate for a run.
    parser = argparse.ArgumentParser(description='Generating files')
    # Required: every generator builds paths from it, and a missing value would
    # otherwise fail later inside os.path.join(None, ...).
    parser.add_argument('--model_dir', type=str, required=True,
                        help='base directory for saving PyTorch model.')
    parser.add_argument('--test', help='create losses_test.csv', action='store_true')
    parser.add_argument('--acc', help='create accuracy.csv', action='store_true')
    args = parser.parse_args()
    # Without an action flag the script would silently do nothing.
    if not (args.test or args.acc):
        parser.error('nothing to do: pass --test and/or --acc')
    if args.test:
        gen_testloss(args)
    if args.acc:
        gen_accuracy(args)