-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinearClassifier.py
84 lines (64 loc) · 2.46 KB
/
linearClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Author Alvaro Esperanca
import os
import sys

import numpy as np

from SVM import SVM
from Validator import Validator
from LinearKernel import LinearKernel
def _load_split(filename):
    """Load a tab-separated integer dataset and split it into features and labels.

    The last column of every row is the class label: 0 is mapped to -1 and
    any other value to +1 (the +/-1 encoding passed to the SVM).

    :param filename: path to a tab-delimited file of integers.
    :returns: ``(features, labels)``. ``features`` is a 2-D int array holding
        every column except the last. ``labels`` is a 1-D int array of -1/+1.
    :raises ValueError: if the file contains no data.
    """
    data = np.genfromtxt(filename, dtype=int, delimiter='\t')
    if data.size == 0:
        raise ValueError("No data found in %s" % filename)
    # genfromtxt returns a 1-D array for a single-row file. Force 2-D so the
    # column slicing below works for any number of rows.
    data = np.atleast_2d(data)
    features = np.array(data[:, :-1])  # contiguous copy, not a view
    labels = np.where(data[:, -1] == 0, -1, 1)
    return features, labels


def _output_base(path):
    """Return *path* without its final extension, keeping any directories.

    Unlike ``path.split('.')[0]``, this handles relative paths such as
    ``../data/test.txt`` and file names containing several dots.
    """
    return os.path.splitext(path)[0]


def main(args):
    """Train (or load) a linear-kernel SVM, evaluate it on a test set, and write results.

    :param args: command-line arguments, exactly ``[mode, train_file, test_file]``.
        If ``mode`` is ``"-l"``, the model is restored with ``clf.loadModel()``
        and ``train_file`` is not read. Otherwise the model is fitted on
        ``train_file`` and saved with ``clf.saveModel()``.

    Writes two files next to ``test_file``, named after its extension-less path:

    * ``<base>_linear_results.txt``: predicted and actual label for each test row.
    * ``<base>_linear_results_stats.txt``: confusion counts and accuracy.

    If the wrong number of arguments is given, prints a message and returns
    without doing anything else.
    """
    if len(args) != 3:
        print("Expected 3 arguments: <mode> <train_file> <test_file>, got %d"
              % len(args))
        print("Terminating...")
        return

    loadModel, trainFilename, testFilename = args
    base = _output_base(testFilename)
    resultsFilename = base + '_linear_results.txt'
    resultStatsFilename = base + '_linear_results_stats.txt'

    testing_data, test_labels = _load_split(testFilename)
    # Actual labels are handed to the Validator as floats ('d'), as before.
    validationLabels = test_labels.astype('d')

    clf = SVM(kernel=LinearKernel(), C=1.0)
    if loadModel == "-l":
        clf.loadModel()
    else:
        # Training data is only needed when actually fitting a new model.
        training_data, training_labels = _load_split(trainFilename)
        clf.fit(training_data, training_labels)
        clf.saveModel()

    predictions = clf.predict(testing_data)
    val = Validator()
    val.validate(validationLabels, predictions)

    # Context managers guarantee both files are closed even if a write fails.
    with open(resultsFilename, "w") as predFile:
        predFile.write("Predicted\tActual\n")
        for predicted, actual in zip(predictions, validationLabels):
            predFile.write("%d\t%d\n" % (predicted, actual))

    with open(resultStatsFilename, "w") as statFile:
        statFile.write("%-20s %-5d\n" % ("True Positives:", val.truePositives()))
        statFile.write("%-20s %-5d\n" % ("True Negatives:", val.trueNegatives()))
        statFile.write("%-20s %-5d\n" % ("False Positives:", val.falsePositives()))
        statFile.write("%-20s %-5d\n\n" % ("False Negatives:", val.falseNegatives()))
        statFile.write("%-20s %-2.2f\n" % ("Accuracy:", val.accuracy()))
if __name__ == "__main__":
    # Strip the script name; main() receives only the user-supplied arguments.
    cli_args = sys.argv[1:]
    main(cli_args)