Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add null check for signature inputs outputs #6978

Merged
merged 2 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions tfjs-converter/src/executor/graph_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,8 @@ export class GraphExecutor implements FunctionExecutor {
private mapInputs(inputs: NamedTensorMap) {
const result: NamedTensorMap = {};
for (const inputName in inputs) {
if (this._signature != null && this._signature.inputs != null &&
this._signature.inputs[inputName] != null) {
const tensor = this._signature.inputs[inputName];
const tensor = this._signature?.inputs?.[inputName];
if (tensor != null) {
result[tensor.name] = inputs[inputName];
} else {
result[inputName] = inputs[inputName];
Expand All @@ -670,9 +669,8 @@ export class GraphExecutor implements FunctionExecutor {

private mapOutputs(outputs: string[]) {
return outputs.map(name => {
if (this._signature != null && this._signature.outputs != null &&
this._signature.outputs[name] != null) {
const tensor = this._signature.outputs[name];
const tensor = this._signature?.outputs?.[name];
if (tensor != null) {
return tensor.name;
}
return name;
Expand Down
19 changes: 10 additions & 9 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,10 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
NamedTensorMap): NamedTensorMap {
if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
// The input is already a NamedTensorMap.
if (this.signature != null && this.signature.inputs != null) {
for (const input in this.signature.inputs) {
const tensor = this.signature.inputs[input];
const signatureInputs = this.signature?.inputs;
if (signatureInputs != null) {
for (const input in signatureInputs) {
const tensor = signatureInputs[input];
if (tensor.resourceId != null) {
inputs[input] = this.resourceIdToCapturedInput[tensor.resourceId];
}
Expand All @@ -410,10 +411,9 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements

let inputIndex = 0;
return this.inputNodes.reduce((map, inputName) => {
const signature =
this.signature ? this.signature.inputs[inputName] : null;
if (signature != null && signature.resourceId != null) {
map[inputName] = this.resourceIdToCapturedInput[signature.resourceId];
const resourceId = this.signature?.inputs?.[inputName]?.resourceId;
if (resourceId != null) {
map[inputName] = this.resourceIdToCapturedInput[resourceId];
} else {
map[inputName] = (inputs as Tensor[])[inputIndex++];
}
Expand Down Expand Up @@ -454,10 +454,11 @@ export class GraphModel<ModelURL extends Url = string | io.IOHandler> implements
this.resourceIdToCapturedInput = {};

if (this.initializerSignature) {
const outputNames = Object.keys(this.initializerSignature.outputs);
const signatureOutputs = this.initializerSignature.outputs;
const outputNames = Object.keys(signatureOutputs);
for (let i = 0; i < outputNames.length; i++) {
const outputName = outputNames[i];
const tensorInfo = this.initializerSignature.outputs[outputName];
const tensorInfo = signatureOutputs[outputName];
this.resourceIdToCapturedInput[tensorInfo.resourceId] = outputs[i];
}
}
Expand Down
66 changes: 56 additions & 10 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ const SIMPLE_HTTP_MODEL_LOADER = {
}
};

const NO_INPUT_SIGNATURE_MODEL_LOADER = {
load: async () => {
return {
modelTopology: SIMPLE_MODEL,
weightSpecs: weightsManifest,
weightData: bias.dataSync(),
format: 'tfjs-graph-model',
generatedBy: '1.15',
convertedBy: '1.3.1',
userDefinedMetadata: {signature: {outputs: SIGNATURE.outputs}}
};
}
};

const CUSTOM_OP_MODEL: tensorflow.IGraphDef = {
node: [
{
Expand Down Expand Up @@ -479,10 +493,11 @@ describe('loadGraphModelSync', () => {
weightsManifest: [{paths: [], weights: weightsManifest}],
};
expect(() => {
return loadGraphModelSync([modelJson] as unknown as [io.ModelJSON,
ArrayBuffer]);
}).toThrowMatching(err =>
err.message.includes('weights must be the second element'));
return loadGraphModelSync(
[modelJson] as unknown as [io.ModelJSON, ArrayBuffer]);
})
.toThrowMatching(
err => err.message.includes('weights must be the second element'));
});

it('Throws an error if modelJSON is missing \'modelTopology\'', () => {
Expand All @@ -492,8 +507,9 @@ describe('loadGraphModelSync', () => {
const weights = new Int32Array([5]).buffer;
expect(() => {
return loadGraphModelSync([badInput as io.ModelJSON, weights]);
}).toThrowMatching(err =>
err.message.includes('missing \'modelTopology\''));
})
.toThrowMatching(
err => err.message.includes('missing \'modelTopology\''));
});

it('Throws an error if modelJSON is missing \'weightsManifest\'', () => {
Expand All @@ -503,16 +519,16 @@ describe('loadGraphModelSync', () => {
const weights = new Int32Array([5]).buffer;
expect(() => {
return loadGraphModelSync([badInput as io.ModelJSON, weights]);
}).toThrowMatching(err =>
err.message.includes('missing \'weightsManifest\''));
})
.toThrowMatching(
err => err.message.includes('missing \'weightsManifest\''));
});

it('Throws an error if modelSource is an unknown format', () => {
const badInput = {foo: 'bar'};
expect(() => {
return loadGraphModelSync(badInput as io.ModelArtifacts);
}).toThrowMatching(err =>
err.message.includes('Unknown model format'));
}).toThrowMatching(err => err.message.includes('Unknown model format'));
});

it('Expect an error when moderUrl is null', () => {
Expand Down Expand Up @@ -774,6 +790,36 @@ describe('Model', () => {
});
});

describe('no signature input model', () => {
beforeEach(() => {
spyIo.getLoadHandlers.and.returnValue([NO_INPUT_SIGNATURE_MODEL_LOADER]);
spyIo.browserHTTPRequest.and.returnValue(NO_INPUT_SIGNATURE_MODEL_LOADER);
});

it('load', async () => {
const loaded = await model.load();
expect(loaded).toBe(true);
});

describe('predict', () => {
it('should generate default output', async () => {
await model.load();
const input = tfc.tensor2d([1, 1], [2, 1], 'int32');
const output = model.execute({'Input': input});
expect((output as tfc.Tensor).dataSync()[0]).toEqual(3);
});
});

describe('execute', () => {
it('should generate default output', async () => {
await model.load();
const input = tfc.tensor2d([1, 1], [2, 1], 'int32');
const output = model.execute(input);
expect((output as tfc.Tensor).dataSync()[0]).toEqual(3);
});
});
});

describe('structured outputs model', () => {
beforeEach(() => {
spyIo.getLoadHandlers.and.returnValue([STRUCTURED_OUTPUTS_MODEL_LOADER]);
Expand Down