diff --git a/tfjs-converter/src/executor/graph_executor.ts b/tfjs-converter/src/executor/graph_executor.ts index 845b653ee99..cf998a38af5 100644 --- a/tfjs-converter/src/executor/graph_executor.ts +++ b/tfjs-converter/src/executor/graph_executor.ts @@ -645,9 +645,8 @@ export class GraphExecutor implements FunctionExecutor { private mapInputs(inputs: NamedTensorMap) { const result: NamedTensorMap = {}; for (const inputName in inputs) { - if (this._signature != null && this._signature.inputs != null && - this._signature.inputs[inputName] != null) { - const tensor = this._signature.inputs[inputName]; + const tensor = this._signature?.inputs?.[inputName]; + if (tensor != null) { result[tensor.name] = inputs[inputName]; } else { result[inputName] = inputs[inputName]; @@ -670,9 +669,8 @@ export class GraphExecutor implements FunctionExecutor { private mapOutputs(outputs: string[]) { return outputs.map(name => { - if (this._signature != null && this._signature.outputs != null && - this._signature.outputs[name] != null) { - const tensor = this._signature.outputs[name]; + const tensor = this._signature?.outputs?.[name]; + if (tensor != null) { return tensor.name; } return name; diff --git a/tfjs-converter/src/executor/graph_model.ts b/tfjs-converter/src/executor/graph_model.ts index 666caf07685..ef9c8ad57a1 100644 --- a/tfjs-converter/src/executor/graph_model.ts +++ b/tfjs-converter/src/executor/graph_model.ts @@ -387,9 +387,10 @@ export class GraphModel implements NamedTensorMap): NamedTensorMap { if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) { // The input is already a NamedTensorMap. - if (this.signature != null && this.signature.inputs != null) { - for (const input in this.signature.inputs) { - const tensor = this.signature.inputs[input]; + const signatureInputs = this.signature?.inputs; + if (signatureInputs != null) { + for (const input in signatureInputs) { + const tensor = signatureInputs[input]; if (tensor.resourceId != null) { inputs[input] = this.resourceIdToCapturedInput[tensor.resourceId]; } @@ -410,10 +411,9 @@ export class GraphModel implements let inputIndex = 0; return this.inputNodes.reduce((map, inputName) => { - const signature = - this.signature ? this.signature.inputs[inputName] : null; - if (signature != null && signature.resourceId != null) { - map[inputName] = this.resourceIdToCapturedInput[signature.resourceId]; + const resourceId = this.signature?.inputs?.[inputName]?.resourceId; + if (resourceId != null) { + map[inputName] = this.resourceIdToCapturedInput[resourceId]; } else { map[inputName] = (inputs as Tensor[])[inputIndex++]; } @@ -454,10 +454,11 @@ export class GraphModel implements this.resourceIdToCapturedInput = {}; if (this.initializerSignature) { - const outputNames = Object.keys(this.initializerSignature.outputs); + const signatureOutputs = this.initializerSignature.outputs; + const outputNames = Object.keys(signatureOutputs); for (let i = 0; i < outputNames.length; i++) { const outputName = outputNames[i]; - const tensorInfo = this.initializerSignature.outputs[outputName]; + const tensorInfo = signatureOutputs[outputName]; this.resourceIdToCapturedInput[tensorInfo.resourceId] = outputs[i]; } } diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 9800f4b62ac..7d3a3ade14f 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -125,6 +125,20 @@ const SIMPLE_HTTP_MODEL_LOADER = { } }; +const NO_INPUT_SIGNATURE_MODEL_LOADER = { + load: async () => { + return { + modelTopology: SIMPLE_MODEL, + weightSpecs: weightsManifest, + weightData: bias.dataSync(), + format: 'tfjs-graph-model', + generatedBy: '1.15', + convertedBy: '1.3.1', + userDefinedMetadata: {signature: {outputs: SIGNATURE.outputs}} + }; + } +}; + const CUSTOM_OP_MODEL: tensorflow.IGraphDef = { node: [ { @@ -479,10 +493,11 @@ describe('loadGraphModelSync', () => { weightsManifest: [{paths: [], weights: weightsManifest}], }; expect(() => { - return loadGraphModelSync([modelJson] as unknown as [io.ModelJSON, - ArrayBuffer]); - }).toThrowMatching(err => - err.message.includes('weights must be the second element')); + return loadGraphModelSync( + [modelJson] as unknown as [io.ModelJSON, ArrayBuffer]); + }) + .toThrowMatching( + err => err.message.includes('weights must be the second element')); }); it('Throws an error if modelJSON is missing \'modelTopology\'', () => { @@ -492,8 +507,9 @@ describe('loadGraphModelSync', () => { const weights = new Int32Array([5]).buffer; expect(() => { return loadGraphModelSync([badInput as io.ModelJSON, weights]); - }).toThrowMatching(err => - err.message.includes('missing \'modelTopology\'')); + }) + .toThrowMatching( + err => err.message.includes('missing \'modelTopology\'')); }); it('Throws an error if modelJSON is missing \'weightsManifest\'', () => { @@ -503,16 +519,16 @@ describe('loadGraphModelSync', () => { const weights = new Int32Array([5]).buffer; expect(() => { return loadGraphModelSync([badInput as io.ModelJSON, weights]); - }).toThrowMatching(err => - err.message.includes('missing \'weightsManifest\'')); + }) + .toThrowMatching( + err => err.message.includes('missing \'weightsManifest\'')); }); it('Throws an error if modelSource is an unknown format', () => { const badInput = {foo: 'bar'}; expect(() => { return loadGraphModelSync(badInput as io.ModelArtifacts); - }).toThrowMatching(err => - err.message.includes('Unknown model format')); + }).toThrowMatching(err => err.message.includes('Unknown model format')); }); it('Expect an error when moderUrl is null', () => { @@ -774,6 +790,36 @@ describe('Model', () => { }); }); + describe('no signature input model', () => { + beforeEach(() => { + spyIo.getLoadHandlers.and.returnValue([NO_INPUT_SIGNATURE_MODEL_LOADER]); + spyIo.browserHTTPRequest.and.returnValue(NO_INPUT_SIGNATURE_MODEL_LOADER); + }); + + it('load', async () => { + const loaded = await model.load(); + expect(loaded).toBe(true); + }); + + describe('predict', () => { + it('should generate default output', async () => { + await model.load(); + const input = tfc.tensor2d([1, 1], [2, 1], 'int32'); + const output = model.execute({'Input': input}); + expect((output as tfc.Tensor).dataSync()[0]).toEqual(3); + }); + }); + + describe('execute', () => { + it('should generate default output', async () => { + await model.load(); + const input = tfc.tensor2d([1, 1], [2, 1], 'int32'); + const output = model.execute(input); + expect((output as tfc.Tensor).dataSync()[0]).toEqual(3); + }); + }); + }); + describe('structured outputs model', () => { beforeEach(() => { spyIo.getLoadHandlers.and.returnValue([STRUCTURED_OUTPUTS_MODEL_LOADER]);