-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathdecision-tree.js
394 lines (337 loc) · 12.3 KB
/
decision-tree.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
var dt = (function () {
/**
 * Creates an instance of DecisionTree.
 *
 * @constructor
 * @param builder - object holding the training set and optional
 *                  configuration:
 *   - trainingSet:       array of items (objects) to learn from
 *   - categoryAttr:      name of the attribute holding the class label
 *                        (default 'category')
 *   - ignoredAttributes: array of attribute names never used for splits
 *   - minItemsCount:     a node with this many items or fewer becomes a
 *                        leaf (default 1)
 *   - entropyThrehold:   a node whose category entropy is at or below this
 *                        value becomes a leaf (default 0.01); the correctly
 *                        spelled `entropyThreshold` is accepted as an alias
 *   - maxTreeDepth:      maximum number of splits from the root to a leaf
 *                        (default 70)
 *
 * Only null/undefined options fall back to their defaults, so explicit
 * zeros (e.g. maxTreeDepth: 0 for a single-leaf tree, or
 * entropyThrehold: 0) are honoured.
 */
function DecisionTree(builder) {
    this.root = buildDecisionTree({
        trainingSet: builder.trainingSet,
        ignoredAttributes: arrayToHashSet(builder.ignoredAttributes),
        categoryAttr: optionOrDefault(builder.categoryAttr, 'category'),
        minItemsCount: optionOrDefault(builder.minItemsCount, 1),
        entropyThrehold: optionOrDefault(
            builder.entropyThrehold,
            optionOrDefault(builder.entropyThreshold, 0.01)),
        maxTreeDepth: optionOrDefault(builder.maxTreeDepth, 70)
    });
}

/**
 * Returns `value` unless it is null or undefined, in which case `fallback`.
 * Unlike `value || fallback`, this keeps legitimate values such as 0.
 */
function optionOrDefault(value, fallback) {
    return (value === undefined || value === null) ? fallback : value;
}

/**
 * Predicts the category of `item` by walking this tree from its root.
 */
DecisionTree.prototype.predict = function (item) {
    return predict(this.root, item);
};
/**
 * A RandomForest is a set of decision trees that vote on the category
 * of an item.
 *
 * @constructor
 * @param builder     - training set plus the configuration options used
 *                      for every underlying decision tree
 * @param treesNumber - how many decision trees the forest holds
 */
function RandomForest(builder, treesNumber) {
    var forest = buildRandomForest(builder, treesNumber);
    this.trees = forest;
}

/**
 * Asks every tree of the forest to classify `item` and returns an object
 * mapping each predicted category to the number of trees that chose it.
 */
RandomForest.prototype.predict = function (item) {
    var trees = this.trees;
    return predictRandomForest(trees, item);
};
/**
 * Converts an array into an object whose keys are the array's elements
 * and whose values are all `true`, so membership can afterwards be tested
 * with `hashSet[x]` (the object is used as a HashSet).
 *
 * Object.keys is used instead of for...in so that enumerable properties
 * inherited from Array.prototype (e.g. added by a polyfill) are not
 * mistaken for elements.
 *
 * @param array - array of values (attribute names); null/undefined
 *                yields an empty set
 * @returns object mapping each element to `true`
 */
function arrayToHashSet(array) {
    var hashSet = {};
    if (array) {
        Object.keys(array).forEach(function (key) {
            hashSet[array[key]] = true;
        });
    }
    return hashSet;
}
/**
 * Counts how many items share each distinct value of a specific attribute.
 *
 * @param items - array of objects
 * @param attr  - name of the attribute read from each object
 * @returns a prototype-less object mapping each distinct value (as a
 *          string key) to its number of occurrences
 */
function countUniqueValues(items, attr) {
    // Object.create(null) has no inherited keys: values such as
    // '__proto__' or 'constructor' are counted like any other value, and
    // for...in over the result only sees real counts.
    var counter = Object.create(null);
    // Scanning from the last item keeps the key insertion order (which
    // decides ties in mostFrequentValue) unchanged.
    for (var i = items.length - 1; i >= 0; i--) {
        var value = items[i][attr];
        counter[value] = (counter[value] || 0) + 1;
    }
    return counter;
}
/**
 * Shannon entropy, using the natural logarithm, of the distribution of
 * `attr` values across `items`: H = -sum(p * ln p), where p is the share
 * of items holding a given value.
 *
 * @param items - array of objects
 * @param attr  - name of the attribute whose value distribution is measured
 * @returns 0 for an empty or homogeneous array, larger for mixed ones
 */
function entropy(items, attr) {
    var frequencies = countUniqueValues(items, attr);
    var total = items.length;
    var result = 0;
    for (var value in frequencies) {
        var share = frequencies[value] / total;
        result -= share * Math.log(share);
    }
    return result;
}
/**
 * Partitions `items` into two new arrays according to
 * `predicate(item[attr], pivot)`:
 *   - match:    items for which the predicate returned a truthy value
 *   - notMatch: all remaining items
 * Items are visited from last to first, so both arrays list them in the
 * reverse of their order in `items`. The input array is not modified.
 *
 * @param items     - array of objects
 * @param attr      - name of the attribute passed to the predicate
 * @param predicate - function(value, pivot) returning true or false
 * @param pivot     - second argument of every predicate call
 * @returns {match: Array, notMatch: Array}
 */
function split(items, attr, predicate, pivot) {
    var result = { match: [], notMatch: [] };
    var index = items.length;
    while (index--) {
        var candidate = items[index];
        var bucket = predicate(candidate[attr], pivot) ? result.match : result.notMatch;
        bucket.push(candidate);
    }
    return result;
}
/**
 * Finds the most common value of `attr` among `items`.
 *
 * The result is a key of the object returned by countUniqueValues, i.e.
 * the value converted to a string (a numeric value 1 comes back as '1').
 * On a tie, the value enumerated first wins; for an empty array the
 * result is undefined.
 *
 * @param items - array of objects
 * @param attr  - name of the attribute read from each object
 */
function mostFrequentValue(items, attr) {
    var occurrences = countUniqueValues(items, attr);
    var winner;
    var winnerCount = 0;
    for (var candidate in occurrences) {
        var count = occurrences[candidate];
        if (count > winnerCount) {
            winnerCount = count;
            winner = candidate;
        }
    }
    return winner;
}
/**
 * Comparison predicates used to split the training set, keyed by the
 * name stored in tree nodes as `predicateName`:
 *   '==' - loose equality, used for non-numeric attribute values
 *   '>=' - threshold comparison, used for numeric attribute values
 */
var predicates = {};
predicates['=='] = function (value, pivot) {
    return value == pivot;
};
predicates['>='] = function (value, pivot) {
    return value >= pivot;
};
/**
 * Builds a decision tree for builder.trainingSet by greedily choosing, at
 * every node, the binary split with the largest information gain.
 *
 * A node becomes a leaf `{category: <most frequent category>}` when
 *   - its depth budget (maxTreeDepth) has reached 0, or
 *   - it holds minItemsCount items or fewer, or
 *   - its category entropy is at or below entropyThrehold, or
 *   - no candidate split gives a positive information gain.
 * Otherwise it becomes an internal node
 *   {attribute, predicate, predicateName, pivot, match, notMatch,
 *    matchedCount, notMatchedCount}
 * where `match` / `notMatch` are the subtrees built from the items that
 * do / do not satisfy predicate(item[attribute], pivot).
 *
 * @param builder - {trainingSet, categoryAttr, minItemsCount,
 *                  entropyThrehold, maxTreeDepth, ignoredAttributes};
 *                  ignoredAttributes is a hash-set object such as the one
 *                  returned by arrayToHashSet. The builder is not modified.
 */
function buildDecisionTree(builder) {
    var trainingSet = builder.trainingSet;
    var categoryAttr = builder.categoryAttr;

    if ((builder.maxTreeDepth == 0) || (trainingSet.length <= builder.minItemsCount)) {
        // depth limit reached or too few items left to split
        return { category: mostFrequentValue(trainingSet, categoryAttr) };
    }

    var initialEntropy = entropy(trainingSet, categoryAttr);
    if (initialEntropy <= builder.entropyThrehold) {
        // the set is (almost) homogeneous, nothing worth splitting
        return { category: mostFrequentValue(trainingSet, categoryAttr) };
    }

    var bestSplit = findBestSplit(trainingSet, categoryAttr,
        builder.ignoredAttributes, initialEntropy);
    if (!bestSplit.gain) {
        // no rule produces any information gain
        return { category: mostFrequentValue(trainingSet, categoryAttr) };
    }

    // Each subtree gets its own builder with one level less of depth.
    // Sharing and mutating a single builder would let the depth left over
    // by the deepest 'match' descendant leak into the 'notMatch' subtree
    // and cut it short.
    var matchSubTree = buildDecisionTree(subtreeBuilder(builder, bestSplit.match));
    var notMatchSubTree = buildDecisionTree(subtreeBuilder(builder, bestSplit.notMatch));

    return {
        attribute: bestSplit.attribute,
        predicate: bestSplit.predicate,
        predicateName: bestSplit.predicateName,
        pivot: bestSplit.pivot,
        match: matchSubTree,
        notMatch: notMatchSubTree,
        matchedCount: bestSplit.match.length,
        notMatchedCount: bestSplit.notMatch.length
    };
}

/**
 * Returns a new builder for a child subtree trained on `subset`, with the
 * same settings as `builder` and one level less of depth available.
 */
function subtreeBuilder(builder, subset) {
    return {
        trainingSet: subset,
        ignoredAttributes: builder.ignoredAttributes,
        categoryAttr: builder.categoryAttr,
        minItemsCount: builder.minItemsCount,
        entropyThrehold: builder.entropyThrehold,
        maxTreeDepth: builder.maxTreeDepth - 1
    };
}

/**
 * Evaluates every distinct 'attribute-predicate-pivot' rule whose pivot is
 * an attribute value of some item in `trainingSet`. It returns the rule
 * with the highest information gain, as the object returned by split()
 * extended with predicateName, predicate, attribute, pivot and gain.
 * Returns {gain: 0} when no rule gives a positive gain.
 */
function findBestSplit(trainingSet, categoryAttr, ignoredAttributes, initialEntropy) {
    // hash-set of rules already evaluated, keyed by attr + predicate + pivot
    var alreadyChecked = {};
    var bestSplit = { gain: 0 };
    for (var i = trainingSet.length - 1; i >= 0; i--) {
        var item = trainingSet[i];
        for (var attr in item) {
            // Own-property check: a plain lookup would treat attributes
            // named like Object.prototype members ('constructor',
            // 'toString', ...) as ignored.
            var isIgnored = Object.prototype.hasOwnProperty.call(ignoredAttributes, attr) &&
                ignoredAttributes[attr];
            if ((attr == categoryAttr) || isIgnored) {
                continue;
            }
            // the value of the current attribute becomes the pivot;
            // numbers are split by threshold, other values only by equality
            var pivot = item[attr];
            var predicateName = (typeof pivot == 'number') ? '>=' : '==';
            var attrPredPivot = attr + predicateName + pivot;
            if (alreadyChecked[attrPredPivot]) {
                continue;
            }
            alreadyChecked[attrPredPivot] = true;

            var predicate = predicates[predicateName];
            var currSplit = split(trainingSet, attr, predicate, pivot);
            // information gain = entropy before the split minus the
            // size-weighted entropy of the two resulting subsets
            var matchEntropy = entropy(currSplit.match, categoryAttr);
            var notMatchEntropy = entropy(currSplit.notMatch, categoryAttr);
            var newEntropy = 0;
            newEntropy += matchEntropy * currSplit.match.length;
            newEntropy += notMatchEntropy * currSplit.notMatch.length;
            newEntropy /= trainingSet.length;
            var currGain = initialEntropy - newEntropy;
            if (currGain > bestSplit.gain) {
                bestSplit = currSplit;
                bestSplit.predicateName = predicateName;
                bestSplit.predicate = predicate;
                bestSplit.attribute = attr;
                bestSplit.pivot = pivot;
                bestSplit.gain = currGain;
            }
        }
    }
    return bestSplit;
}
/**
 * Classifies `item` by walking `tree` from the root down to a leaf.
 *
 * Internal nodes carry attribute/predicate/pivot and two children
 * (match, notMatch); leaves carry only `category`. A node counts as a
 * leaf when it has its own `category` property. Leaves with a falsy
 * category ('' or undefined, e.g. built from an empty training set) are
 * therefore returned instead of being mistaken for internal nodes.
 *
 * If a node's `predicate` is not a function (for example because the tree
 * was serialized to JSON, which drops functions), the predicate is looked
 * up by the node's `predicateName`.
 *
 * @param tree - root node produced by buildDecisionTree
 * @param item - object whose attributes are tested
 * @returns the category stored in the leaf that is reached
 */
function predict(tree, item) {
    var node = tree;
    while (!Object.prototype.hasOwnProperty.call(node, 'category')) {
        var predicate = node.predicate;
        if (typeof predicate !== 'function') {
            predicate = predicates[node.predicateName];
        }
        // move to one of the subtrees
        node = predicate(item[node.attribute], node.pivot) ? node.match : node.notMatch;
    }
    return node.category;
}
/**
 * Builds `treesNumber` decision trees, each trained on its own slice of
 * builder.trainingSet. Item i goes to tree (i % treesNumber), so the
 * slices are disjoint and filled round-robin (there is no random
 * sampling).
 *
 * The caller's builder is left untouched. Each tree receives an object
 * that inherits all of the builder's options and overrides only
 * `trainingSet` with that tree's slice.
 *
 * @param builder     - training set plus DecisionTree configuration
 * @param treesNumber - number of trees to build
 * @returns array of DecisionTree instances
 */
function buildRandomForest(builder, treesNumber) {
    var items = builder.trainingSet;

    var trainingSets = [];
    for (var t = 0; t < treesNumber; t++) {
        trainingSets[t] = [];
    }
    // Scanning from the last item fills each slice in reverse input order.
    for (var i = items.length - 1; i >= 0; i--) {
        trainingSets[i % treesNumber].push(items[i]);
    }

    var forest = [];
    for (var k = 0; k < treesNumber; k++) {
        // Object.create keeps every option readable through the prototype
        // chain without writing to the caller's object.
        var treeBuilder = Object.create(builder);
        treeBuilder.trainingSet = trainingSets[k];
        forest.push(new DecisionTree(treeBuilder));
    }
    return forest;
}
/**
 * Lets every decision tree of the forest classify `item`, i.e. 'vote'
 * for a category.
 *
 * @param forest - array of objects with a predict(item) method
 * @param item   - object to classify
 * @returns object mapping each predicted category to the number of trees
 *          that voted for it
 */
function predictRandomForest(forest, item) {
    var result = {};
    for (var i = 0; i < forest.length; i++) {
        var prediction = forest[i].predict(item);
        // An own-property check instead of a truthiness test: a category
        // named e.g. 'constructor' or 'toString' would otherwise start
        // from the inherited Object.prototype member instead of 0.
        var votes = Object.prototype.hasOwnProperty.call(result, prediction) ?
            result[prediction] : 0;
        result[prediction] = votes + 1;
    }
    return result;
}
var exports = {};
exports.DecisionTree = DecisionTree;
exports.RandomForest = RandomForest;
return exports;
})();