diff --git a/tfjs-converter/src/executor/graph_executor.ts b/tfjs-converter/src/executor/graph_executor.ts index 67db57ef901..7943744fb56 100644 --- a/tfjs-converter/src/executor/graph_executor.ts +++ b/tfjs-converter/src/executor/graph_executor.ts @@ -33,7 +33,7 @@ interface NodeWithContexts { node: Node; } -// Default, not dump. Sync, dump by execute. Async, dump by executeAsync. +// Default, dump is off. Sync, dump by execute. Async, dump by executeAsync. enum DumpMode { Default = -1, Sync, @@ -52,11 +52,10 @@ export class GraphExecutor implements FunctionExecutor { private _functions: {[key: string]: Graph} = {}; private _functionExecutorMap: {[key: string]: FunctionExecutor} = {}; private _resourceManager: ResourceManager; - private intermediateTensors: NamedTensorsMap = {}; - private keepIdsForExecuteAsync: Set; + // Variables with Async suffix is used for dumping by executeAsync. + private idsKeepForAsync: Set; + private tensorsPendingDisposal: Tensor[]; private tensorsMap: NamedTensorsMap; - private keepInputTensorsForExecute: Tensor[]; - private keepTensorsForExecute: Tensor[]; private dumpMode = DumpMode.Default; get weightIds(): number[] { @@ -191,13 +190,16 @@ export class GraphExecutor implements FunctionExecutor { this.graph, this.weightMap, executionInfo); } - private keepTensors(keepTensors: Tensor[], tensors: Tensor[]) { - if (this.dumpMode !== DumpMode.Sync || tensors == null) { + private keepTensors( + tensorsToKeep: Tensor[], tensorsPendingDisposal: Tensor[] = null) { + if (this.dumpMode !== DumpMode.Sync || tensorsToKeep == null) { return; } - tensors.forEach(tensor => { + tensorsToKeep.forEach(tensor => { if (tensor && !tensor.kept) { - keepTensors.push(tensor); + if (tensorsPendingDisposal) { + tensorsPendingDisposal.push(tensor); + } keep(tensor); } }); @@ -254,8 +256,7 @@ export class GraphExecutor implements FunctionExecutor { this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap); if (this.dumpMode === DumpMode.Sync) { - this.keepTensorsForExecute = []; - this.keepInputTensorsForExecute = []; + this.tensorsPendingDisposal = []; } Object.keys(inputs).forEach(name => { @@ -264,7 +265,7 @@ export class GraphExecutor implements FunctionExecutor { tensors[index] = inputs[name]; tensorsMap[nodeName] = tensors; // Input tensors should be disposed by user. - this.keepTensors(this.keepInputTensorsForExecute, tensors); + this.keepTensors(tensors); }); const tensorsToKeep = this.getFrozenTensorIds(tensorsMap); @@ -281,7 +282,7 @@ export class GraphExecutor implements FunctionExecutor { `Please use model.executeAsync() instead.`); } tensorsMap[node.name] = tensors; - this.keepTensors(this.keepTensorsForExecute, tensors); + this.keepTensors(tensors, this.tensorsPendingDisposal); this.checkTensorForDisposal( node.name, node, tensorsMap, context, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount); @@ -297,7 +298,7 @@ export class GraphExecutor implements FunctionExecutor { if (this.dumpMode === DumpMode.Sync) { this.tensorsMap = tensorsMap; } else { - this.tensorsMap = {}; + this.tensorsMap = null; } return result; } @@ -342,15 +343,8 @@ export class GraphExecutor implements FunctionExecutor { if (count === 1) { if (this.dumpMode === DumpMode.Default) { tensor.dispose(); - } else { - const [nodeName, index] = - getNodeNameAndIndex(node.name, context); - if (this.intermediateTensors[nodeName]) { - this.intermediateTensors[nodeName][index] = tensor; - } else { - this.intermediateTensors[nodeName] = []; - this.intermediateTensors[nodeName][index] = tensor; - } + } else if (this.dumpMode === DumpMode.Async) { + this.tensorsPendingDisposal.push(tensor); } delete intermediateTensorConsumerCount[tensor.id]; } else if (count != null) { @@ -380,20 +374,19 @@ export class GraphExecutor implements FunctionExecutor { } disposeIntermediateTensors() { - for (const key in this.intermediateTensors) { - this.intermediateTensors[key].forEach(tensor => tensor.dispose()); + if (this.dumpMode === DumpMode.Default) { + return; } - this.intermediateTensors = {}; - if (this.dumpMode === DumpMode.Sync) { - if (this.keepTensorsForExecute) { - this.keepTensorsForExecute.forEach(tensor => { - tensor.dispose(); - }); - this.keepTensorsForExecute = null; - } - } else if (this.dumpMode === DumpMode.Async) { + if (this.tensorsPendingDisposal) { + this.tensorsPendingDisposal.forEach(tensor => { + tensor.dispose(); + }); + this.tensorsPendingDisposal = null; + } + if (this.dumpMode === DumpMode.Async) { this.disposeTensorsMap(); } + this.tensorsMap = null; this.dumpMode = DumpMode.Default; } @@ -406,7 +399,7 @@ export class GraphExecutor implements FunctionExecutor { const tensorArray = this.tensorsMap[key]; tensorArray.forEach(tensor => { if (tensor && !tensor.kept && !tensor.isDisposed && - !this.keepIdsForExecuteAsync.has(tensor.id)) { + !this.idsKeepForAsync.has(tensor.id)) { tensor.dispose(); } }); @@ -448,6 +441,7 @@ export class GraphExecutor implements FunctionExecutor { const keepTensorForDump = env().getBool('KEEP_INTERMEDIATE_TENSORS'); if (keepTensorForDump) { this.dumpMode = DumpMode.Async; + this.tensorsPendingDisposal = []; } } catch (e) { console.warn(e.message); @@ -468,7 +462,7 @@ export class GraphExecutor implements FunctionExecutor { // dispose all the intermediate tensors const outputIds = results.map(t => t.id); const inputIds = Object.keys(inputs).map(name => inputs[name].id); - this.keepIdsForExecuteAsync = + this.idsKeepForAsync = new Set([...outputIds, ...inputIds, ...this.weightIds]); if (this.dumpMode !== DumpMode.Async) { this.disposeTensorsMap(); @@ -476,7 +470,7 @@ export class GraphExecutor implements FunctionExecutor { // dispose the context for the root executor if (this.parent == null) { - context.dispose(this.keepIdsForExecuteAsync); + context.dispose(this.idsKeepForAsync); } return results;