forked from BrainJS/brain.js
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cross-validate.ts
44 lines (37 loc) · 1.3 KB
/
cross-validate.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import * as assert from 'assert';
import * as brain from '../index';
// XOR truth table (output is 1 exactly when the two input bits differ).
// The four patterns are pushed twice, in the same order each time, to
// simulate having more data to train with.
const trainingData: Array<{ input: number[]; output: number[] }> = [];
for (let pass = 0; pass < 2; pass++) {
  trainingData.push(
    { input: [0, 1], output: [1] },
    { input: [0, 0], output: [0] },
    { input: [1, 1], output: [0] },
    { input: [1, 0], output: [1] }
  );
}
// Network construction options: one hidden layer with 3 neurons.
// Declared with a type annotation rather than `as`, so the compiler still
// reports misspelled or wrongly-typed option fields.
const netOptions: brain.INeuralNetworkOptions = {
  hiddenLayers: [3],
};

// Options passed to `crossValidate.train` below.
const trainingOptions: brain.INeuralNetworkTrainingOptions = {
  iterations: 20000,
  // Print whatever progress details the trainer hands to the log callback.
  log: details => console.log(details),
};
// Cross-validate a NeuralNetwork on the XOR samples and print the stats.
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, netOptions);
const stats = crossValidate.train(trainingData, trainingOptions);
console.log(stats);

// Turn the cross-validation result into a single network and probe it with
// every XOR input pattern.
const net = crossValidate.toNeuralNetwork();
const result01 = net.run([0, 1]);
const result00 = net.run([0, 0]);
const result11 = net.run([1, 1]);
const result10 = net.run([1, 0]);

// Print before asserting so a failing run still shows what the net produced.
// The trailing numbers are illustrative sample outputs; only the thresholds
// asserted below are checked.
console.log('0 XOR 1: ', result01); // 0.987
console.log('0 XOR 0: ', result00); // 0.058
console.log('1 XOR 1: ', result11); // 0.087
console.log('1 XOR 0: ', result10); // 0.934

// Use `assert.ok` instead of calling the `import * as assert` namespace
// directly: a module namespace object is not callable under esModuleInterop.
assert.ok(result01[0] > 0.9, `expected 0 XOR 1 > 0.9, got ${result01[0]}`);
assert.ok(result00[0] < 0.1, `expected 0 XOR 0 < 0.1, got ${result00[0]}`);
assert.ok(result11[0] < 0.1, `expected 1 XOR 1 < 0.1, got ${result11[0]}`);
assert.ok(result10[0] > 0.9, `expected 1 XOR 0 > 0.9, got ${result10[0]}`);