Rename
axinging committed Oct 26, 2022
1 parent 37b9489 commit ca5dbc9
Showing 1 changed file with 31 additions and 37 deletions.
68 changes: 31 additions & 37 deletions tfjs-converter/src/executor/graph_executor.ts
@@ -33,7 +33,7 @@ interface NodeWithContexts {
node: Node;
}

// Default, not dump. Sync, dump by execute. Async, dump by executeAsync.
// Default, dump is off. Sync, dump by execute. Async, dump by executeAsync.
enum DumpMode {
Default = -1,
Sync,
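The enum introduced here is only partly visible in this hunk; a sketch of the three modes as described by the comment above and by later references in the diff (the Async member and its value are inferred):

```ts
// Sketch only: reconstructed from the comment and later DumpMode.Async uses.
enum DumpMode {
  Default = -1,  // dumping disabled; intermediate tensors are disposed eagerly
  Sync,          // dump requested through execute()
  Async          // dump requested through executeAsync()
}
```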
@@ -52,11 +52,10 @@ export class GraphExecutor implements FunctionExecutor {
private _functions: {[key: string]: Graph} = {};
private _functionExecutorMap: {[key: string]: FunctionExecutor} = {};
private _resourceManager: ResourceManager;
private intermediateTensors: NamedTensorsMap = {};
private keepIdsForExecuteAsync: Set<number>;
// Variables with the Async suffix are used for dumping by executeAsync.
private idsKeepForAsync: Set<number>;
private tensorsPendingDisposal: Tensor[];
private tensorsMap: NamedTensorsMap;
private keepInputTensorsForExecute: Tensor[];
private keepTensorsForExecute: Tensor[];
private dumpMode = DumpMode.Default;

get weightIds(): number[] {
@@ -191,13 +190,16 @@ export class GraphExecutor implements FunctionExecutor {
this.graph, this.weightMap, executionInfo);
}

private keepTensors(keepTensors: Tensor[], tensors: Tensor[]) {
if (this.dumpMode !== DumpMode.Sync || tensors == null) {
private keepTensors(
tensorsToKeep: Tensor[], tensorsPendingDisposal: Tensor[] = null) {
if (this.dumpMode !== DumpMode.Sync || tensorsToKeep == null) {
return;
}
tensors.forEach(tensor => {
tensorsToKeep.forEach(tensor => {
if (tensor && !tensor.kept) {
keepTensors.push(tensor);
if (tensorsPendingDisposal) {
tensorsPendingDisposal.push(tensor);
}
keep(tensor);
}
});
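keepTensors() builds on tf.keep(), which pins a tensor so the engine's automatic cleanup (for example tf.tidy) leaves it alone. A stand-alone sketch of that pattern, with illustrative names that are not part of the converter API:

```ts
import * as tf from '@tensorflow/tfjs';

// Keep tensors alive for a later dump and optionally remember them so the
// caller can dispose them explicitly afterwards (illustrative helper).
function keepForDump(tensors: tf.Tensor[], pendingDisposal?: tf.Tensor[]): void {
  for (const t of tensors) {
    if (t != null && !t.kept) {
      if (pendingDisposal != null) {
        pendingDisposal.push(t);  // caller disposes these in a later pass
      }
      tf.keep(t);
    }
  }
}

// Usage: the kept tensor survives the tidy() scope and is disposed manually.
const pending: tf.Tensor[] = [];
const result = tf.tidy(() => {
  const intermediate = tf.scalar(2).square();
  keepForDump([intermediate], pending);
  return intermediate.add(1);
});
pending.forEach(t => t.dispose());
result.dispose();
```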
@@ -254,8 +256,7 @@
this.weightMap, tensorArrayMap, tensorListMap,
this.functionExecutorMap);
if (this.dumpMode === DumpMode.Sync) {
this.keepTensorsForExecute = [];
this.keepInputTensorsForExecute = [];
this.tensorsPendingDisposal = [];
}

Object.keys(inputs).forEach(name => {
@@ -264,7 +265,7 @@
tensors[index] = inputs[name];
tensorsMap[nodeName] = tensors;
// Input tensors should be disposed by user.
this.keepTensors(this.keepInputTensorsForExecute, tensors);
this.keepTensors(tensors);
});

const tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
@@ -281,7 +282,7 @@
`Please use model.executeAsync() instead.`);
}
tensorsMap[node.name] = tensors;
this.keepTensors(this.keepTensorsForExecute, tensors);
this.keepTensors(tensors, this.tensorsPendingDisposal);
this.checkTensorForDisposal(
node.name, node, tensorsMap, context, tensorsToKeep,
outputNodeNames, intermediateTensorConsumerCount);
@@ -297,7 +298,7 @@
if (this.dumpMode === DumpMode.Sync) {
this.tensorsMap = tensorsMap;
} else {
this.tensorsMap = {};
this.tensorsMap = null;
}
return result;
}
@@ -342,15 +343,8 @@
if (count === 1) {
if (this.dumpMode === DumpMode.Default) {
tensor.dispose();
} else {
const [nodeName, index] =
getNodeNameAndIndex(node.name, context);
if (this.intermediateTensors[nodeName]) {
this.intermediateTensors[nodeName][index] = tensor;
} else {
this.intermediateTensors[nodeName] = [];
this.intermediateTensors[nodeName][index] = tensor;
}
} else if (this.dumpMode === DumpMode.Async) {
this.tensorsPendingDisposal.push(tensor);
}
delete intermediateTensorConsumerCount[tensor.id];
} else if (count != null) {
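The branch above is a small reference-count scheme: when the last consumer of an intermediate tensor is reached, the tensor is either disposed immediately (no dump) or parked in tensorsPendingDisposal (async dump). A simplified, self-contained sketch, with a boolean standing in for the DumpMode check:

```ts
import * as tf from '@tensorflow/tfjs';

// counts maps tensor id -> remaining consumers; deferDisposal stands in for
// the DumpMode.Async branch in the real executor (illustrative helper).
function releaseIntermediate(
    tensor: tf.Tensor, counts: {[id: number]: number}, deferDisposal: boolean,
    pendingDisposal: tf.Tensor[]): void {
  const count = counts[tensor.id];
  if (count === 1) {
    // Last consumer: free the tensor now, or park it so it can be dumped first.
    if (deferDisposal) {
      pendingDisposal.push(tensor);
    } else {
      tensor.dispose();
    }
    delete counts[tensor.id];
  } else if (count != null) {
    counts[tensor.id] = count - 1;
  }
}
```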
@@ -380,20 +374,19 @@
}

disposeIntermediateTensors() {
for (const key in this.intermediateTensors) {
this.intermediateTensors[key].forEach(tensor => tensor.dispose());
if (this.dumpMode === DumpMode.Default) {
return;
}
this.intermediateTensors = {};
if (this.dumpMode === DumpMode.Sync) {
if (this.keepTensorsForExecute) {
this.keepTensorsForExecute.forEach(tensor => {
tensor.dispose();
});
this.keepTensorsForExecute = null;
}
} else if (this.dumpMode === DumpMode.Async) {
if (this.tensorsPendingDisposal) {
this.tensorsPendingDisposal.forEach(tensor => {
tensor.dispose();
});
this.tensorsPendingDisposal = null;
}
if (this.dumpMode === DumpMode.Async) {
this.disposeTensorsMap();
}
this.tensorsMap = null;
this.dumpMode = DumpMode.Default;
}
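For context, a hedged sketch of how this dump path is driven from user code. The KEEP_INTERMEDIATE_TENSORS flag and disposeIntermediateTensors() appear in this diff; calling them through the GraphModel wrapper, the getIntermediateTensors() accessor, and the model URL are assumptions for illustration:

```ts
import * as tf from '@tensorflow/tfjs';

async function dumpIntermediates(): Promise<void> {
  // Ask the executor to keep intermediate tensors instead of disposing them.
  // Assumes the flag is registered in this build of tfjs-core.
  tf.env().set('KEEP_INTERMEDIATE_TENSORS', true);

  // Hypothetical model URL, for illustration only.
  const model = await tf.loadGraphModel('https://example.com/model.json');
  const input = tf.zeros([1, 224, 224, 3]);
  const output = model.execute(input) as tf.Tensor;

  // Inspect the retained tensors here, e.g. via a getIntermediateTensors()
  // accessor if your tfjs-converter version provides one (assumption).

  // Release everything the dump kept alive, then the user-owned tensors.
  model.disposeIntermediateTensors();
  input.dispose();
  output.dispose();
}
```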

@@ -406,7 +399,7 @@
const tensorArray = this.tensorsMap[key];
tensorArray.forEach(tensor => {
if (tensor && !tensor.kept && !tensor.isDisposed &&
!this.keepIdsForExecuteAsync.has(tensor.id)) {
!this.idsKeepForAsync.has(tensor.id)) {
tensor.dispose();
}
});
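disposeTensorsMap() boils down to selective disposal against a keep-set of tensor ids (outputs, inputs, weights). A minimal stand-alone sketch of that pattern, with illustrative names:

```ts
import * as tf from '@tensorflow/tfjs';

// Dispose every tensor in the map except those whose ids are in keepIds
// (outputs, inputs and weights in the real executor), skipping tensors that
// are already kept or disposed (illustrative helper).
function disposeAllExcept(
    tensorsMap: {[nodeName: string]: tf.Tensor[]},
    keepIds: Set<number>): void {
  for (const nodeName of Object.keys(tensorsMap)) {
    for (const tensor of tensorsMap[nodeName]) {
      if (tensor != null && !tensor.kept && !tensor.isDisposed &&
          !keepIds.has(tensor.id)) {
        tensor.dispose();
      }
    }
  }
}
```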
@@ -448,6 +441,7 @@
const keepTensorForDump = env().getBool('KEEP_INTERMEDIATE_TENSORS');
if (keepTensorForDump) {
this.dumpMode = DumpMode.Async;
this.tensorsPendingDisposal = [];
}
} catch (e) {
console.warn(e.message);
@@ -468,15 +462,15 @@
// dispose all the intermediate tensors
const outputIds = results.map(t => t.id);
const inputIds = Object.keys(inputs).map(name => inputs[name].id);
this.keepIdsForExecuteAsync =
this.idsKeepForAsync =
new Set<number>([...outputIds, ...inputIds, ...this.weightIds]);
if (this.dumpMode !== DumpMode.Async) {
this.disposeTensorsMap();
}

// dispose the context for the root executor
if (this.parent == null) {
context.dispose(this.keepIdsForExecuteAsync);
context.dispose(this.idsKeepForAsync);
}

return results;
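Graphs with control-flow or dynamic ops take the executeAsync() path shown in this hunk instead; the dump and cleanup flow is the same. A brief hedged sketch (the GraphModel-level methods are assumed to forward to this executor):

```ts
import * as tf from '@tensorflow/tfjs';

// Async counterpart of the earlier sketch: executeAsync() is required for
// graphs with dynamic ops; dump and dispose work the same way.
async function runAndDumpAsync(model: tf.GraphModel, input: tf.Tensor): Promise<void> {
  tf.env().set('KEEP_INTERMEDIATE_TENSORS', true);  // assumed to be registered
  const output = await model.executeAsync(input) as tf.Tensor;
  // ...inspect retained intermediate tensors here...
  model.disposeIntermediateTensors();  // frees what the dump kept alive
  output.dispose();
}
```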
