Skip to content

Commit

Permalink
Merge pull request #790 from ml5js/univeral-sentence-encoder
Browse files Browse the repository at this point in the history
[Univeral sentence encoder] Ports tfjs Univeral sentence encoder model
  • Loading branch information
joeyklee authored Apr 8, 2020
2 parents 2c71efb + e876c2d commit dad91c9
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/examples.json

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<html>
<head>
<meta charset="UTF-8" >
<title>Universal Sentence Encoder with Tokenizer</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.0.0/p5.min.js"></script>
<script src="http://localhost:8080/ml5.js" type="text/javascript"></script>
</head>

<body>
<h1>Universal Sentence Encoder with Tokenizer</h1>
<script src="sketch.js"></script>
</body>
</html>
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
let sentenceEncoder;
const sentences = [
'I love rainbows',
'I love rainbows too',
'I love cupcakes',
'I love bagels more'
]

function setup(){
createCanvas(512, 512);
// background(220);
colorMode(HSB, 360, 100, 100);
sentenceEncoder = ml5.universalSentenceEncoder(modelLoaded)
}

function modelLoaded(){

predict();
}

function predict(){

sentenceEncoder.predict(sentences, gotResults);
}

function gotResults(err, result){
if(err){
return err;
}
console.log(result);

result.forEach( (item, y) => {
// console.log(item);
item.forEach( (val, x) => {
const l = map(val, -1, 1, 0, 100);
noStroke();
fill(360, 100, l);
rect(x, y * (height/result.length) , 1, (height/result.length));
})
})

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<html>
<head>
<meta charset="UTF-8" >
<title>Universal Sentence Encoder</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.0.0/p5.min.js"></script>
<script src="http://localhost:8080/ml5.js" type="text/javascript"></script>
</head>

<body>
<h1>Universal Sentence Encoder</h1>
<script src="sketch.js"></script>
</body>
</html>
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
let sentenceEncoder;
const sentence = 'Monday, Tuesday, Wednesday, Thursday, and Friday are days of the Week';


function setup(){
createCanvas(512, 512);
// background(220);
colorMode(HSB, 360, 100, 100);
rectMode(CENTER);
textAlign(CENTER);
sentenceEncoder = ml5.universalSentenceEncoder({withTokenizer:true}, modelLoaded)
}

function modelLoaded(){
console.log('model ready')
predict();
}

function predict(){
console.log('predicting')
sentenceEncoder.encode(sentence, gotResults);
}

function gotResults(err, result){
if(err){
return err;
}
console.log(result);
translate(40, 0);
result.forEach( (item, idx) => {
const rectHeight = map(item, 0, 7999, 0, 100);
fill(180, 100, 100);
rect(idx * 20, height/2 , 20, rectHeight);
})


}
5 changes: 5 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@
"transform-object-rest-spread"
]
},
"peerDependencies": {
"@tensorflow/tfjs-core": "^1.2.9",
"@tensorflow/tfjs-converter": "^1.2.9"
},
"dependencies": {
"@magenta/sketch": "0.2.0",
"@tensorflow-models/body-pix": "1.1.2",
Expand All @@ -112,7 +116,8 @@
"@tensorflow-models/mobilenet": "2.0.3",
"@tensorflow-models/posenet": "2.1.3",
"@tensorflow-models/speech-commands": "0.3.9",
"@tensorflow/tfjs": "1.7.0",
"@tensorflow-models/universal-sentence-encoder": "^1.2.2",
"@tensorflow/tfjs": "^1.7.0",
"@tensorflow/tfjs-vis": "^1.1.0",
"events": "^3.0.0",
"face-api.js": "~0.22.2",
Expand Down
78 changes: 78 additions & 0 deletions src/UniversalSentenceEncoder/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// import * as tf from '@tensorflow/tfjs';
import * as USE from '@tensorflow-models/universal-sentence-encoder';
import callCallback from '../utils/callcallback';

const DEFAULTS = {
withTokenizer: false,
}

class UniversalSentenceEncoder {
constructor(options, callback){
this.model = null;
this.tokenizer = null;
this.config = {
withTokenizer: options.withTokenizer || DEFAULTS.withTokenizer
};

callCallback(this.loadModel(), callback);
}

/**
* load model
*/
async loadModel(){
if(this.config.withTokenizer === true){
this.tokenizer = await USE.loadTokenizer();
}
this.model = await USE.load();
return this;
}

/**
* Encodes a string or array based on the USE
* @param {*} textString
* @param {*} callback
*/
predict(textArray, callback){
return callCallback(this.predictInternal(textArray), callback);
}

async predictInternal(textArray){
try{
const embeddings = await this.model.embed(textArray);
const results = await embeddings.array();
embeddings.dispose();
return results;
} catch(err){
console.error(err);
return err;
}
}

/**
* Encodes a string based on the loaded tokenizer if the withTokenizer:true
* @param {*} textString
* @param {*} callback
*/
encode(textString, callback){
return callCallback(this.encodeInternal(textString), callback);
}

async encodeInternal(textString){
if(this.config.withTokenizer === true){
return this.tokenizer.encode(textString);
}
console.error('withTokenizer must be set to true - please pass "withTokenizer:true" as an option in the constructor');
return false;
}

}

const universalSentenceEncoder = (optionsOr, cb) => {
const options = (typeof optionsOr === 'object') ? optionsOr : {};
const callback = (typeof optionsOr === 'function') ? optionsOr : cb;

return new UniversalSentenceEncoder(options, callback);
};

export default universalSentenceEncoder;
Empty file.
2 changes: 2 additions & 0 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import bodyPix from './BodyPix';
import neuralNetwork from './NeuralNetwork';
import faceApi from './FaceApi';
import kmeans from './KMeans';
import universalSentenceEncoder from './UniversalSentenceEncoder';
import p5Utils from './utils/p5Utils';
import communityStatement from './utils/community';

Expand All @@ -52,6 +53,7 @@ const withPreload = {
sentiment,
bodyPix,
faceApi,
universalSentenceEncoder
};

// call community statement on load
Expand Down

0 comments on commit dad91c9

Please sign in to comment.