-
Notifications
You must be signed in to change notification settings - Fork 0
/
SVM_Implementation.m
65 lines (49 loc) · 1.93 KB
/
SVM_Implementation.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
%% Determine correlation of variables, and determine number exceeding
% specified threshold
%https://www.analyticsvidhya.com/blog/2016/12/introduction-to-feature-selection-methods-with-an-example-or-how-to-select-the-right-variables/
%https://www.mathworks.com/matlabcentral/fileexchange/50701-feature-selection-with-svm-rfe
%https://www.csie.ntu.edu.tw/~cjlin/libsvm/
% allVarsMat is indexed throughout this script as one observation per row,
% with the last column holding the class label and every other column a
% feature.
load allVarsMat.mat
addpath(genpath('SVM-RFE-CBR-v1.3'));
addpath(genpath('libsvm-3.23'));
% Pairwise feature-vs-feature correlation matrix (symmetric).
varsCorr = corr(allVarsMat(:,1:end-1));
% Correlation of each feature column with the label column.
varsLabelsCorr = corr(allVarsMat(:,1:end-1),allVarsMat(:,end));
threshold = 0.8;
% Number of distinct feature pairs whose |correlation| exceeds threshold.
% Only the strict upper triangle (k = 1) is counted: the diagonal holds each
% feature's correlation with itself, and the lower triangle mirrors the
% upper one, so including either would inflate the count. (The previous
% size(find(...)) also produced a [rows cols] size vector, not a count.)
numHighCorrelation = nnz(triu(abs(varsCorr) > threshold, 1));
%% Run the SVM-RFE Method on the variable matrix and obtain the variable
% ranking
featureVect = allVarsMat(:,1:end-1);   % every feature column
labelVect = allVarsMat(:,end);         % class label (last column)
% Options passed to ftSel_SVMRFECBR (SVM-RFE-CBR toolbox, added to the path
% at the top of the script). Options are read as struct fields, so build a
% struct explicitly rather than starting from a cell array ({}).
% NOTE(review): the field meanings below are not defined in this file --
% confirm against the toolbox documentation.
param = struct();
param.kerType = 2;                     % presumably RBF kernel (libsvm -t 2) -- confirm
param.rfeC = 1;                        % presumably the SVM cost C used during RFE -- confirm
param.rfeG = 1/size(featureVect,2);    % presumably RBF gamma; 1/nFeatures matches libsvm's default gamma
param.useCBR = 0;                      % presumably turns off correlation bias reduction -- confirm
param.Rth = 0.9;                       % presumably the correlation threshold used by CBR -- confirm
param.nstopChunk = Inf;
[ftRank,ftScore] = ftSel_SVMRFECBR(featureVect,labelVect, param);
%% Keep only the top-ranked variables from ftRank [new table: redTable]
% reduceFeatTable is defined outside this file. newMastTable is not created
% in this script either -- NOTE(review): presumably it is loaded from
% allVarsMat.mat or already in the workspace; confirm.
redTable = reduceFeatTable( ...
    newMastTable, ...   table whose feature columns are reduced
    0.75, ...           forwarded unchanged; presumably the fraction of ftRank to keep -- confirm
    ftRank);            % feature ranking produced by the SVM-RFE section
% Earlier manual approach, kept for reference: drop every column ranked at
% or below cutoffThreshold, deleting from the highest index downward.
% varsInd2Del = ftRank(cutoffThreshold:end);
% varsInd2Del = sort(varsInd2Del, 'Descend');
% redMastTable = newMastTable;
% for i = varsInd2Del
% redMastTable.(i) = [];
% end
%% Train/Cross-Validate SVM model using reduced Feature Vector
% Fit an RBF-kernel SVM on the reduced table, using its 'Labels' variable as
% the response. Predictors are standardized, and the kernel scale is picked
% by fitcsvm's 'auto' heuristic.
SVMModel = fitcsvm(redTable, 'Labels', ...
    'Standardize', true, ...
    'KernelFunction', 'RBF', ...
    'KernelScale', 'auto');
% Cross-validate the fitted model (crossval's default partitioning) and
% report the average out-of-fold loss. The assignment is deliberately left
% unsuppressed so the value prints.
CVSVMModel = crossval(SVMModel);
classLoss = kfoldLoss(CVSVMModel)
%% Find Classification error on your own with small subset
% Draw up to 5000 random rows of redTable, predict them, and report the
% fraction misclassified.
% NOTE: these rows were also used to train SVMModel, so this measures
% training-set (resubstitution) error, not held-out error; classLoss from
% the cross-validation section is the held-out estimate.
nSample = min(5000, height(redTable));   % never index past the end of a smaller table
randZ = randperm(height(redTable));
testSet = randZ(1:nSample);
% One vectorized predict() call instead of one call per row. As before, the
% predictors are every column except the last, and the label is the last
% column.
predLabels = predict(SVMModel, redTable(testSet, 1:end-1));
trueLabels = redTable{testSet, end};
% Assumes labels support elementwise ~= (numeric, logical or categorical),
% as the original per-row comparison did. For cellstr labels use ~strcmp.
numIncorrect = sum(predLabels ~= trueLabels);
sampleError = numIncorrect / nSample     % unsuppressed so the error rate prints
% Write the trained classifier to twitterSVMClassifier.mat as variable SVMModel.
save twitterSVMClassifier SVMModel