Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor mask computation into a separate function
Browse files Browse the repository at this point in the history
mattsoulanille committed Sep 12, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent bdf6ed1 commit 111048c
Showing 2 changed files with 58 additions and 38 deletions.
94 changes: 57 additions & 37 deletions tfjs-layers/src/engine/topology.ts
Original file line number Diff line number Diff line change
@@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
// Ensure inputs are all the same type.
const inputsList = generic_utils.toList(inputs);

let allAreSymbolic = true;
for (const input of inputsList) {
if (!(input instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
let noneAreSymbolic = true;
for (const input of inputsList) {
if (input instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
const allAreSymbolic = checkAllSymbolic(inputs);
const noneAreSymbolic = checkNoneSymbolic(inputs);

if (allAreSymbolic === noneAreSymbolic) {
throw new ValueError(
@@ -1017,31 +1005,12 @@ export abstract class Layer extends serialization.Serializable {

// Actually call the layer, collecting output(s), mask(s), and shape(s).
if (noneAreSymbolic) {
let output = this.call(inputs as Tensor | Tensor[], kwargs);
let output = this.call(inputs, kwargs);

// Apply masks to the output tensors if the layer supports it.
if (this.supportsMasking) {
// TODO(mattsoulanille): pass the input tensors' masks to computeMask
const outputMask = this.computeMask(inputs as Tensor | Tensor[]);
if (output instanceof Array && outputMask instanceof Array) {
if (output.length !== outputMask.length) {
throw new Error(`${this.name} output ${output.length} tensors `
+ `but ${outputMask.length} masks for those tensors`);
}
for (let i = 0; i < output.length; i++) {
output[i].kerasMask = outputMask[i];
}
} else if (outputMask instanceof Array) {
throw new Error(`{this.name} output a single tensor `
+ `but ${outputMask.length} masks`);
} else if (output instanceof Array) {
for (const out of output) {
out.kerasMask = outputMask.clone();
}
outputMask.dispose(); // Only keep the clones to avoid leaking
} else {
output.kerasMask = outputMask;
}
this.setMaskMetadata(inputs, output);
}

// If the layer returns tensors from its inputs, unmodified,
@@ -1097,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
If the input tensor(s) had no previous history,
this does nothing.
*/
this.addInboundNode(
inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
this.addInboundNode(inputs, output, null, null,
inputShape, outputShape, kwargs);
this._refCount++;

@@ -1419,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
return mask;
}

private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
previous_mask?: Tensor|Tensor[]): void {
if (!this.supportsMasking) {
return;
}

const outputMasks = this.computeMask(inputs, previous_mask);
if (outputs instanceof Array && outputMasks instanceof Array) {
if (outputs.length !== outputMasks.length) {
throw new Error(`${this.name} outputs ${outputs.length} tensors `
+ `but ${outputMasks.length} masks for those tensors`);
}
for (let i = 0; i < outputs.length; i++) {
outputs[i].kerasMask = outputMasks[i];
}
} else if (outputMasks instanceof Array) {
throw new Error(`{this.name} outputs a single tensor `
+ `but ${outputMasks.length} masks`);
} else if (outputs instanceof Array) {
throw new Error(`{this.name} outputs ${outputs.length} tensors `
+ `but only one mask`);
} else {
outputs.kerasMask = outputMasks;
}
}

/**
* Internal method to create an inbound node for the layer.
*
@@ -1666,3 +1660,29 @@ export function getSourceInputs(
}
}
}

type MaybeSymbolic = SymbolicTensor | Tensor;

function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
): tensors is SymbolicTensor | SymbolicTensor[] {
let allAreSymbolic = true;
for (const tensor of generic_utils.toList(tensors)) {
if (!(tensor instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
return allAreSymbolic;
}

function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
): tensors is Tensor | Tensor[] {
let noneAreSymbolic = true;
for (const tensor of generic_utils.toList(tensors)) {
if (tensor instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
return noneAreSymbolic;
}
2 changes: 1 addition & 1 deletion tfjs-layers/src/utils/generic_utils.ts
Original file line number Diff line number Diff line change
@@ -76,7 +76,7 @@ export function singletonOrArray<T>(xs: T[]): T|T[] {
* @param x target object to be normalized.
*/
// tslint:disable-next-line:no-any
export function toList(x: any): any[] {
export function toList<T>(x: T|T[]): T[] {
if (Array.isArray(x)) {
return x;
}

0 comments on commit 111048c

Please sign in to comment.