forked from tommyngx/hand_words
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.js
143 lines (121 loc) · 4.2 KB
/
model.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/**
 * class Model
 * Loads the TensorFlow.js model, preprocesses canvas images and predicts the
 * character they contain.
 *
 * Relies on a global `tf` (TensorFlow.js) and, at construction time, on a DOM
 * element with id "input-canvas".
 */
class Model {
  /**
   * Initializes the Model class and starts loading and warming up the model.
   *
   * `isWarmedUp` first holds the Promise for the load + warm-up chain; `warmUp()`
   * overwrites it with `true` once the warm-up prediction has run.
   * NOTE(review): a load failure leaves `isWarmedUp` as a rejected Promise that
   * nothing in this class handles.
   */
  constructor() {
    this.alphabet = "abcdefghijklmnopqrstuvwxyz";
    // Output index -> character: digits, then upper-case, then lower-case letters.
    this.characters = `0123456789${this.alphabet.toUpperCase()}${this.alphabet}`;
    this.inputCanvas = document.getElementById("input-canvas");
    this.isWarmedUp = this.loadModel()
      .then(this.warmUp.bind(this))
      .then(() => console.info("Backend running on:", tf.getBackend()));
  }

  /**
   * Loads the layers model from "model/model.json" and stores it on `_model`.
   * @returns {Promise<void>} resolves once `_model` is set
   */
  loadModel() {
    console.time("Load model");
    return tf.loadLayersModel("model/model.json").then((model) => {
      this._model = model;
      console.timeEnd("Load model");
    });
  }

  /**
   * Runs one prediction on random data to warm up the backend (e.g. the GPU)
   * before real predictions, then marks the model as warmed up.
   */
  warmUp() {
    console.time("Warmup");
    // tf.tidy releases the random input and every intermediate output tensor.
    tf.tidy(() => {
      this._model.predict(tf.randomNormal([1, 28, 28, 1])).as1D().dataSync();
    });
    this.isWarmedUp = true;
    console.timeEnd("Warmup");
  }

  /**
   * Converts an ImageData object into the model's input tensor.
   *
   * The image is reduced to one channel and padded (with 255) to a square. It is
   * then scaled down and given a 255-valued border. Finally it is inverted and
   * normalized, so that 255 maps to 0 and 0 maps to 1.
   * @param {ImageData} pixelData
   * @returns {tf.Tensor} tensor of shape [1, 28, 28, 1]
   */
  preprocessImage(pixelData) {
    const targetDim = 28;
    const edgeSize = 2;
    const resizeDim = targetDim - edgeSize * 2;
    const { width, height } = pixelData;
    // Pad the shorter side. Splitting floor/ceil keeps the result exactly
    // square even when the size difference is odd.
    const padTotal = Math.abs(width - height);
    const padBefore = Math.floor(padTotal / 2);
    const padAfter = padTotal - padBefore;
    const padSquare = width > height
      ? [[padBefore, padAfter], [0, 0], [0, 0]]
      : [[0, 0], [padBefore, padAfter], [0, 0]];

    // Release the preview tensor kept by the previous call. It is stored on the
    // instance because it must outlive the tf.tidy scope below.
    if (this._tempImg) {
      this._tempImg.dispose();
      this._tempImg = null;
    }

    return tf.tidy(() => {
      // convert the pixel data into a tensor with 1 data channel per pixel,
      // i.e. from [h, w, 4] to [h, w, 1], padded such that w = h = max(w, h)
      let tensor = tf.browser.fromPixels(pixelData, 1).pad(padSquare, 255.0);
      // scale it down, then pad it with blank pixels along the edges
      // (to better match the training data)
      tensor = tf.image
        .resizeBilinear(tensor, [resizeDim, resizeDim])
        .pad([[edgeSize, edgeSize], [edgeSize, edgeSize], [0, 0]], 255.0);
      // invert and normalize to match training data
      tensor = tf.scalar(1.0).sub(tensor.cast("float32").div(tf.scalar(255.0)));
      // display what the model will see; tf.keep lets this copy survive tf.tidy
      this._tempImg = tf.keep(tf.clone(tensor));
      this.showInput(this._tempImg);
      // add the batch dimension expected by the model: [N, 28, 28, 1], N = 1
      return tensor.expandDims(0);
    });
  }

  /**
   * Predicts the character drawn in an ImageData object.
   * @param {ImageData} pixelData
   * @returns {[string, number]|undefined} the most probable character and the
   *   model's maximum output value, or undefined (after a warning) when the
   *   model has not been loaded yet
   */
  predict(pixelData) {
    if (!this._model) {
      console.warn("Model not loaded yet!");
      return undefined;
    }
    console.time("Prediction");
    // All tensors created here are disposed by tf.tidy; only plain numbers leave.
    const [argMax, probability] = tf.tidy(() => {
      const prediction = this._model.predict(this.preprocessImage(pixelData)).as1D();
      return [prediction.argMax().dataSync()[0], prediction.max().dataSync()[0]];
    });
    // get the character at the most probable index
    const character = this.characters[argMax];
    console.log("Predicted", character, "Probability", probability);
    console.timeEnd("Prediction");
    return [character, probability];
  }

  /**
   * Removes every <img> in the canvas' parent element (the earlier renderings
   * inserted by showInput) and clears the input canvas.
   */
  clearInput() {
    const canvas = this.inputCanvas;
    for (const img of [...canvas.parentElement.getElementsByTagName("img")]) {
      img.remove();
    }
    canvas.getContext("2d").clearRect(0, 0, canvas.width, canvas.height);
  }

  /**
   * Saves the current canvas rendering as an <img> inserted before the canvas,
   * then draws the given tensor onto the canvas.
   * @param {tf.Tensor} tempImg tensor to draw
   * @returns {Promise} the promise returned by tf.browser.toPixels
   */
  showInput(tempImg) {
    const legacyImg = new Image();
    legacyImg.src = this.inputCanvas.toDataURL("image/png");
    this.inputCanvas.parentElement.insertBefore(legacyImg, this.inputCanvas);
    return tf.browser.toPixels(tempImg, this.inputCanvas);
  }

  /**
   * Debug helper: logs a tensor's data, then one line per row where each value
   * is shown as "0" (exactly zero) or "1" (anything else), padded to 2 chars.
   * @param {string} name
   * @param {tf.Tensor} tensor
   * @param {number} width
   * @param {number} height
   */
  static log(name, tensor, width = 28, height = 28) {
    const values = tensor.dataSync();
    console.log("Tensor name", name, values);
    for (let start = 0; start < width * height; start += width) {
      const row = values.slice(start, start + width);
      // Map every value (including the first) to its padded flag.
      console.log(Array.from(row, (v) => (v === 0 ? "0" : "1").padStart(2)).join(""));
    }
  }
}