-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtest.js
127 lines (109 loc) · 3.04 KB
/
test.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import { getNumbers, getClasses } from 'ml-dataset-iris';
import KNN from '..';
/**
 * Unit tests for the KNN classifier: predictions on a small hand-built
 * dataset, input validation, JSON serialization round-trips through
 * `KNN.load`, a sample of the iris dataset, and the default `k`.
 */
describe('knn', () => {
  // Small 3-D training set: the first three points are labelled 0,
  // the last three (with larger coordinates) are labelled 1.
  const cases = [
    [0, 0, 0],
    [0, 1, 1],
    [1, 1, 0],
    [2, 2, 2],
    [1, 2, 2],
    [2, 1, 2],
  ];
  const labels = [0, 0, 0, 1, 1, 1];
  const knn = new KNN(cases, labels, {
    k: 3,
  });

  it('predictions', () => {
    const result = knn.predict([
      [1.81, 1.81, 1.81],
      [0.5, 0.5, 0.5],
    ]);
    expect(result[0]).toBe(1);
    expect(result[1]).toBe(0);
    // A single sample (not wrapped in an outer array) yields a single label.
    expect(knn.predict([1.81, 1.81, 1.81])).toBe(1);
  });

  it('type error', () => {
    const throwMessage = 'dataset to predict must be an array or a matrix';
    expect(() => knn.predict()).toThrow(throwMessage);
    expect(() => knn.predict([])).toThrow(throwMessage);
    expect(() => knn.predict(['a'])).toThrow(throwMessage);
    expect(() => knn.predict([[]])).toThrow(throwMessage);
    expect(() => knn.predict([['a']])).toThrow(throwMessage);
  });

  it('load', () => {
    // Round-trip through JSON to mimic persisting and restoring the model.
    const model = JSON.parse(JSON.stringify(knn));
    const newKnn = KNN.load(model);
    const result = newKnn.predict([
      [1.81, 1.81, 1.81],
      [0.5, 0.5, 0.5],
    ]);
    expect(result[0]).toBe(1);
    expect(result[1]).toBe(0);
    // Exercise the single-sample path on the *restored* model, not the original.
    expect(newKnn.predict([1.81, 1.81, 1.81])).toBe(1);
  });

  it('load errors', () => {
    // An object with no `name` is rejected as an invalid model.
    expect(() => KNN.load({})).toThrow('invalid model: undefined');
    // The distance function passed to `load` must match how the model was built.
    expect(() => KNN.load({ name: 'KNN', isEuclidean: true }, () => 1)).toThrow(
      'the model was created with the default distance function. Do not load it with another one',
    );
    expect(() => KNN.load({ name: 'KNN', isEuclidean: false })).toThrow(
      'a custom distance function was used to create the model. Please provide it again',
    );
  });

  it('Test with iris dataset', () => {
    const data = getNumbers();
    const classes = getClasses();
    const trained = new KNN(data, classes, { k: 5 });
    // Predict with a model restored from its JSON form so the serialization
    // path is also exercised on a real dataset.
    const irisKnn = KNN.load(JSON.parse(JSON.stringify(trained)));
    const test = [
      [5.1, 3.5, 1.4, 0.2],
      [4.9, 3.0, 1.4, 0.2],
      [4.7, 3.2, 1.3, 0.2],
      [4.6, 3.1, 1.5, 0.2],
      [5.0, 3.6, 1.4, 0.2],
      [6.1, 2.8, 4.7, 1.2],
      [6.4, 2.9, 4.3, 1.3],
      [6.6, 3.0, 4.4, 1.4],
      [6.8, 2.8, 4.8, 1.4],
      [6.7, 3.0, 5.0, 1.7],
      [6.8, 3.2, 5.9, 2.3],
      [6.7, 3.3, 5.7, 2.5],
      [6.7, 3.0, 5.2, 2.3],
      [6.3, 2.5, 5.0, 1.9],
      [6.5, 3.0, 5.2, 2.0],
    ];
    const expected = [
      'setosa',
      'setosa',
      'setosa',
      'setosa',
      'setosa',
      'versicolor',
      'versicolor',
      'versicolor',
      'versicolor',
      'versicolor',
      'virginica',
      'virginica',
      'virginica',
      'virginica',
      'virginica',
    ];
    expect(irisKnn.predict(test)).toStrictEqual(expected);
  });

  it('default k', () => {
    // Kept as a local literal rather than reusing `cases`, so this test does
    // not depend on whether an earlier `new KNN(...)` touched that array.
    const dataset = [
      [0, 0, 0],
      [0, 1, 1],
      [1, 1, 0],
      [2, 2, 2],
      [1, 2, 2],
      [2, 1, 2],
    ];
    const datasetLabels = [0, 0, 0, 1, 1, 1];
    // No options object: the model is expected to pick k = 3 for this data.
    const defaultKnn = new KNN(dataset, datasetLabels);
    expect(defaultKnn.k).toBe(3);
    const ans = defaultKnn.predict([[0, 0, 0]]);
    expect(ans).toStrictEqual([0]);
  });
});