Skip to content

Commit

Permalink
[converter] Avoid spying on properties of the 'tfOps' module (#6563)
Browse files Browse the repository at this point in the history
According to the ESModule standard, properties of modules are immutable. TypeScript 4 will enforce this rule. This in particular affects tests for executors in tfjs-converter, in which we often spy on tfOps.

This PR removes all instances of spyOn(tfOps,...) and replaces them with a separate spyOps mock / fake which is passed to the executeOp function.

It also removes spying on the io module in graph_model_test.ts.

This PR was part of the larger TS4 upgrade PR (#6346, #5561), but I'm splitting that PR into pieces that can be merged while we're still using TS3 because it's too large to keep up-to-date.

This PR also bumps lib to es2019 in the root tsconfig to allow using Object.fromEntries. This shouldn't affect the code we ship since it's still compiled to the es2017 target.

Note that this PR does not bump TypeScript to version 4. It leaves it at 3.
  • Loading branch information
mattsoulanille authored Jun 29, 2022
1 parent b83e5b4 commit ccca854
Show file tree
Hide file tree
Showing 45 changed files with 1,147 additions and 726 deletions.
3 changes: 2 additions & 1 deletion tfjs-converter/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
"opn": "~5.1.0",
"protobufjs": "~6.11.3",
"ts-node": "~8.8.2",
"typescript": "3.5.3"
"typescript": "3.5.3",
"yalc": "~1.0.0-pre.50"
},
"scripts": {
"build": "bazel build :tfjs-converter_pkg",
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/scripts/kernels_to_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,16 @@ function getKernelMappingForFile(source: SourceFile) {
const callExprs =
clausePart.getDescendantsOfKind(SyntaxKind.CallExpression);
const tfOpsCallExprs =
callExprs.filter(expr => expr.getText().match(/tfOps/));
callExprs.filter(expr => expr.getText().match(/ops/));
const tfSymbols: Set<string> = new Set();
for (const tfOpsCall of tfOpsCallExprs) {
const tfOpsCallStr = tfOpsCall.getText();
const functionCallMatcher = /(tfOps\.([\w\.]*)\()/g;
const functionCallMatcher = /(ops\.([\w\.]*)\()/g;
const matches = tfOpsCallStr.match(functionCallMatcher);
if (matches != null && matches.length > 0) {
for (const match of matches) {
// extract the method name (and any namespaces used to call it)
const symbolMatcher = /(tfOps\.([\w\.]*)\()/;
const symbolMatcher = /(ops\.([\w\.]*)\()/;
const symbol = match.match(symbolMatcher)[2];
tfSymbols.add(symbol);
}
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/src/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package(default_visibility = ["//visibility:public"])
TEST_SRCS = [
"**/*_test.ts",
"run_tests.ts",
"operations/executors/spy_ops.ts",
]

# Used for test-snippets
Expand Down
26 changes: 15 additions & 11 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
private initializer: GraphExecutor;
private resourceManager: ResourceManager;
private signature: tensorflow.ISignatureDef;
private readonly io: typeof io;

// Returns the version information for the tensorflow model GraphDef.
get modelVersion(): string {
Expand Down Expand Up @@ -93,7 +94,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
*/
constructor(
private modelUrl: ModelURL,
private loadOptions: io.LoadOptions = {}) {
private loadOptions: io.LoadOptions = {}, tfio = io) {
this.io = tfio;
if (loadOptions == null) {
this.loadOptions = {};
}
Expand All @@ -107,14 +109,16 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
// Path is an IO Handler.
this.handler = path as IOHandler;
} else if (this.loadOptions.requestInit != null) {
this.handler = io.browserHTTPRequest(path as string, this.loadOptions) as
IOHandler;
this.handler = this.io
.browserHTTPRequest(path as string, this.loadOptions) as IOHandler;
} else {
const handlers = io.getLoadHandlers(path as string, this.loadOptions);
const handlers =
this.io.getLoadHandlers(path as string, this.loadOptions);
if (handlers.length === 0) {
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
handlers.push(io.browserHTTPRequest(path as string, this.loadOptions));
handlers.push(
this.io.browserHTTPRequest(path as string, this.loadOptions));
} else if (handlers.length > 1) {
throw new Error(
`Found more than one (${handlers.length}) load handlers for ` +
Expand Down Expand Up @@ -171,8 +175,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
this.signature = signature;

this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap =
io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
const weightMap = this.io.decodeWeights(
this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(
OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
Expand Down Expand Up @@ -243,7 +247,7 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):
Promise<io.SaveResult> {
if (typeof handlerOrURL === 'string') {
const handlers = io.getSaveHandlers(handlerOrURL);
const handlers = this.io.getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new Error(
`Cannot find any save handlers for URL '${handlerOrURL}'`);
Expand Down Expand Up @@ -452,8 +456,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
* @doc {heading: 'Models', subheading: 'Loading'}
*/
export async function loadGraphModel(
modelUrl: string|io.IOHandler,
options: io.LoadOptions = {}): Promise<GraphModel> {
modelUrl: string|io.IOHandler, options: io.LoadOptions = {},
tfio = io): Promise<GraphModel> {
if (modelUrl == null) {
throw new Error(
'modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
Expand All @@ -466,7 +470,7 @@ export async function loadGraphModel(
if (options.fromTFHub && typeof modelUrl === 'string') {
modelUrl = getTFHubUrl(modelUrl);
}
const model = new GraphModel(modelUrl, options);
const model = new GraphModel(modelUrl, options, tfio);
await model.load();
return model;
}
Expand Down
57 changes: 31 additions & 26 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {deregisterOp, registerOp} from '../operations/custom_op/register';
import {GraphNode} from '../operations/types';

import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';
import {RecursiveSpy, spyOnAllFunctions} from '../operations/executors/spy_ops';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -368,6 +369,12 @@ describe('loadSync', () => {
});

describe('loadGraphModel', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
spyIo = spyOnAllFunctions(io);
});

it('Pass a custom io handler', async () => {
const customLoader: tfc.io.IOHandler = {
load: async () => {
Expand Down Expand Up @@ -397,11 +404,11 @@ describe('loadGraphModel', () => {

it('Pass a fetchFunc', async () => {
const fetchFunc = () => {};
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
await loadGraphModel(MODEL_URL, {fetchFunc});
expect(tfc.io.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
});
});

Expand Down Expand Up @@ -436,13 +443,16 @@ describe('loadGraphModelSync', () => {
});

describe('Model', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
model = new GraphModel(MODEL_URL);
spyIo = spyOnAllFunctions(io);
model = new GraphModel(MODEL_URL, undefined, spyIo);
});

describe('custom model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
registerOp('CustomOp', (nodeValue: GraphNode) => {
Expand Down Expand Up @@ -484,11 +494,10 @@ describe('Model', () => {

describe('simple model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
SIMPLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
});
it('load', async () => {
const loaded = await model.load();
Expand Down Expand Up @@ -621,7 +630,7 @@ describe('Model', () => {
describe('dispose', () => {
it('should dispose the weights', async () => {
const numOfTensors = tfc.memory().numTensors;
model = new GraphModel(MODEL_URL);
model = new GraphModel(MODEL_URL, undefined, spyIo);

await model.load();
model.dispose();
Expand All @@ -639,7 +648,7 @@ describe('Model', () => {

describe('relative path', () => {
beforeEach(() => {
model = new GraphModel(RELATIVE_MODEL_URL);
model = new GraphModel(RELATIVE_MODEL_URL, undefined, spyIo);
});

it('load', async () => {
Expand All @@ -649,22 +658,22 @@ describe('Model', () => {
});

it('should loadGraphModel', async () => {
const model = await loadGraphModel(MODEL_URL);
const model = await loadGraphModel(MODEL_URL, undefined, spyIo);
expect(model).not.toBeUndefined();
});

it('should loadGraphModel with request options', async () => {
const model = await loadGraphModel(
MODEL_URL, {requestInit: {credentials: 'include'}});
expect(tfc.io.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
MODEL_URL, {requestInit: {credentials: 'include'}}, spyIo);
expect(spyIo.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
requestInit: {credentials: 'include'}
});
expect(model).not.toBeUndefined();
});

it('should call loadGraphModel for TfHub Module', async () => {
const url = `${HOST}/model/1`;
const model = await loadGraphModel(url, {fromTFHub: true});
const model = await loadGraphModel(url, {fromTFHub: true}, spyIo);
expect(model).toBeDefined();
});

Expand All @@ -686,11 +695,10 @@ describe('Model', () => {

describe('control flow model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CONTROL_FLOW_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
});

describe('save', () => {
Expand Down Expand Up @@ -777,11 +785,10 @@ describe('Model', () => {
};
describe('dynamic shape model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
});

it('should throw error if call predict directly', async () => {
Expand Down Expand Up @@ -822,11 +829,10 @@ describe('Model', () => {
});
describe('dynamic shape model with metadata', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_NEW_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
});

it('should be success if call executeAsync with signature key',
Expand All @@ -848,11 +854,10 @@ describe('Model', () => {

describe('Hashtable model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
HASHTABLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
});
it('should be successful if call executeAsync', async () => {
await model.load();
Expand Down
26 changes: 13 additions & 13 deletions tfjs-converter/src/operations/executors/arithmetic_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,66 +27,66 @@ import {getParamValue} from './utils';

export const executeOp: InternalOpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Tensor[] => {
context: ExecutionContext, ops = tfOps): Tensor[] => {
switch (node.op) {
case 'BiasAdd':
case 'AddV2':
case 'Add': {
return [tfOps.add(
return [ops.add(
(getParamValue('a', node, tensorMap, context) as Tensor),
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'AddN': {
return [tfOps.addN((
return [ops.addN((
getParamValue('tensors', node, tensorMap, context) as Tensor[]))];
}
case 'FloorMod':
case 'Mod':
return [tfOps.mod(
return [ops.mod(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
case 'Mul':
return [tfOps.mul(
return [ops.mul(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
case 'RealDiv':
case 'Div': {
return [tfOps.div(
return [ops.div(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'DivNoNan': {
return [tfOps.divNoNan(
return [ops.divNoNan(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'FloorDiv': {
return [tfOps.floorDiv(
return [ops.floorDiv(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Sub': {
return [tfOps.sub(
return [ops.sub(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Minimum': {
return [tfOps.minimum(
return [ops.minimum(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Maximum': {
return [tfOps.maximum(
return [ops.maximum(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Pow': {
return [tfOps.pow(
return [ops.pow(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'SquaredDifference': {
return [tfOps.squaredDifference(
return [ops.squaredDifference(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
Expand Down
Loading

0 comments on commit ccca854

Please sign in to comment.