Skip to content

Commit

Permalink
Support for json graphdef objects
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilk committed Jan 10, 2018
1 parent 5b2a165 commit ecd28d7
Show file tree
Hide file tree
Showing 10 changed files with 827 additions and 20 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"main": "src/index.js",
"dependencies": {
"ffi": "^2.2.0",
"protocol-buffers": "^3.2.1",
"pbf": "^3.1.0",
"ref": "^1.3.5",
"ref-array": "^1.2.0",
"ref-struct": "^1.1.0"
Expand Down
65 changes: 65 additions & 0 deletions samples/graphs/json/main.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
const tf = require('tensorflow');

var const1 = {
name: 'c1',
op: 'Const',
attr: {
value: {
value: 'tensor',
tensor: {
dtype: 3,
tensor_shape: { dim: [] },
int_val: [1]
}
},
dtype: {
value: 'type',
type: 3
}
}
};

var const2 = {
name: 'c2',
op: 'Const',
attr: {
value: {
value: 'tensor',
tensor: {
dtype: 3,
tensor_shape: { dim: [] },
int_val: [41]
}
},
dtype: {
value: 'type',
type: 3
}
}
};

var add = {
name: 'sum',
op: 'Add',
input: [
'c1',
'c2'
],
attr: {
T: {
value: 'type',
type: 3
}
}
};

var graphDef = {
node: [ const1, const2, add ]
}

let session = tf.Session.fromGraphDef(graphDef, { sum: 'sum' });
let results = session.run(null, ['sum'], null);

console.log(results.sum.toValue());
results.sum.delete();
session.delete();
8 changes: 8 additions & 0 deletions samples/graphs/json/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"dependencies": {
"tensorflow": "^0.6.5"
},
"scripts": {
"sample": "node main.js"
}
}
28 changes: 18 additions & 10 deletions src/graph.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

'use strict';

const api = require('./api'),
const api = require('./interop/api'),
fs = require('fs');

function loadGraph(protobuf) {
Expand Down Expand Up @@ -61,8 +61,23 @@ class Graph extends api.Reference {
return unresolvedOps;
}

static fromGraphDef(graphDefPath, operations) {
let protobuf = fs.readFileSync(graphDefPath);
static fromGraphDef(graphDef, operations) {
let protobuf = null;
if (graphDef.constructor == String) {
protobuf = fs.readFileSync(graphDef);
}
else if (Buffer.isBuffer(graphDef)) {
protobuf = graphDef;
}
else {
let ProtobufWriter = require('pbf');

let writer = new ProtobufWriter();
api.Protos.GraphDef.write(graphDef, writer);

protobuf = writer.finish();
}

let graph = loadGraph(protobuf);

if (operations) {
Expand All @@ -71,13 +86,6 @@ class Graph extends api.Reference {

return graph;
}

// TODO: Implement loading a Graph from an in-memory JSON object representatio of a GraphDef.
// However, this doesn't yet work. The resulting protobuf/GraphDef is invalid.
// static fromGraphDefObject(graphDefObject) {
// let protobuf = api.Protos.GraphDef.encode(graphDefObject);
// return loadGraph(protobuf);
// }
}


Expand Down
2 changes: 1 addition & 1 deletion src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Defines the TensorFlow module.
//

const api = require('./api');
const api = require('./interop/api');

module.exports = {
Types: api.Types,
Expand Down
6 changes: 3 additions & 3 deletions src/api.js → src/interop/api.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// api.js
// tf.js
// Interface for TensorFlow C API
//
// This defines the TensorFlow library matching a subset of the C API methods as defined in
Expand Down Expand Up @@ -91,7 +91,7 @@ const statusCodes = {

let libPath = process.env['TENSORFLOW_LIB_PATH'];
if (!libPath) {
libPath = path.join(__dirname, '..', 'lib');
libPath = path.join(__dirname, '..', '..', 'lib');
}
if (!fs.existsSync(path.join(libPath, 'libtensorflow.so'))) {
throw new Error(`libtensorflow.so was not found at "${libPath}"`);
Expand Down Expand Up @@ -191,7 +191,7 @@ const libApi = {
};

const library = ffi.Library(path.join(libPath, 'libtensorflow'), libApi);
library.Protos = protobuf(fs.readFileSync(path.join(__dirname, 'api.proto')));
library.Protos = require('./messages');
library.ApiTypes = types;
library.StatusCodes = statusCodes;
library.Types = tensorTypes;
Expand Down
Loading

0 comments on commit ecd28d7

Please sign in to comment.