-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy path utils.py
66 lines (51 loc) · 1.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gzip
from sklearn import svm
from sklearn.metrics import accuracy_score
import numpy as np
import torch
def load_data(data_file):
    """Load the train/valid/test splits from a gzip-compressed pickle file.

    The pickle is expected to hold a 3-item sequence ``(train, valid, test)``,
    each item an ``(x, y)`` pair. Every pair is passed through ``make_tensor``,
    which turns ``x`` into a ``torch.Tensor`` and ``y`` into an ``int32``
    NumPy array.

    Args:
        data_file: path (or file object) accepted by ``gzip.open``.

    Returns:
        ``[(train_x, train_y), (valid_x, valid_y), (test_x, test_y)]``.
    """
    print('loading data ...')
    # Context manager guarantees the file is closed even if unpickling fails.
    with gzip.open(data_file, 'rb') as f:
        train_set, valid_set, test_set = load_pickle(f)
    train_set_x, train_set_y = make_tensor(train_set)
    valid_set_x, valid_set_y = make_tensor(valid_set)
    test_set_x, test_set_y = make_tensor(test_set)
    return [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)]
def make_tensor(data_xy):
    """Split an ``(inputs, labels)`` pair and convert each half.

    The inputs are converted with ``torch.tensor`` and the labels with
    ``np.asarray(..., dtype='int32')``. Despite the function name, only the
    inputs end up as a tensor; the labels stay a NumPy array.

    Returns:
        A 2-tuple ``(inputs_tensor, labels_int32_array)``.
    """
    features, labels = data_xy
    return torch.tensor(features), np.asarray(labels, dtype='int32')
def svm_classify(data, C):
    """Fit a linear SVM on the training split; score it on test and validation.

    Args:
        data: indexable with at least three splits ``data[0]`` (train),
            ``data[1]`` (valid), ``data[2]`` (test). Each split is a 3-tuple
            whose first item is the feature matrix and whose last item is the
            label array; the middle item is ignored.
        C: penalty parameter forwarded to ``svm.LinearSVC``.

    Returns:
        ``[test_accuracy, valid_accuracy]`` -- test comes first.
    """
    train_x, _, train_y = data[0]
    valid_x, _, valid_y = data[1]
    test_x, _, test_y = data[2]

    print('training SVM...')
    model = svm.LinearSVC(C=C, dual=False)
    # Only the training labels are flattened before fitting.
    model.fit(train_x, train_y.ravel())

    def _score(features, labels):
        """Accuracy of the fitted model's predictions on one split."""
        return accuracy_score(labels, model.predict(features))

    test_acc = _score(test_x, test_y)
    valid_acc = _score(valid_x, valid_y)
    return [test_acc, valid_acc]
def load_pickle(f):
    """Load and return the object stored in an open pickle file.

    Handles the Python 2 / Python 3 split. On Python 2 the ``cPickle``
    module is used. On Python 3 the standard ``pickle`` module is used with
    ``encoding='latin1'``, so 8-bit strings written by Python 2 pickles
    decode without error.

    Args:
        f: readable binary file-like object positioned at the pickle start.

    Returns:
        The unpickled object.

    Security: unpickling can execute arbitrary code -- only call this on
    files from a trusted source.
    """
    try:
        import cPickle as thepickle  # Python 2
    except ImportError:
        # Python 3: use the public ``pickle`` module (``_pickle`` is a private
        # CPython detail, absent on e.g. PyPy). ``encoding`` is always accepted
        # here, so call it directly instead of retrying on TypeError -- a retry
        # would re-read a partially consumed stream and hide the real error.
        import pickle as thepickle
        return thepickle.load(f, encoding='latin1')
    return thepickle.load(f)