Skip to content

Commit

Permalink
Refactor mask computation into a separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsoulanille committed Sep 12, 2023
1 parent bdf6ed1 commit 6822278
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 44 deletions.
3 changes: 2 additions & 1 deletion tfjs-layers/src/base_callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ export function standardizeCallbacks(
}
// Convert custom callback configs to custom callback objects.
const callbackConfigs =
generic_utils.toList(callbacks) as CustomCallbackArgs[];
generic_utils.toList<BaseCallback | CustomCallbackArgs>(
callbacks) as CustomCallbackArgs[];
return callbackConfigs.map(
callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
}
Expand Down
104 changes: 62 additions & 42 deletions tfjs-layers/src/engine/topology.ts
Original file line number Diff line number Diff line change
Expand Up @@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable {
*/
protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor|
SymbolicTensor[]): void {
inputs = generic_utils.toList(inputs);
const inputsList = generic_utils.toList(inputs);
if (this.inputSpec == null || this.inputSpec.length === 0) {
return;
}
const inputSpec = generic_utils.toList(this.inputSpec);
if (inputs.length !== inputSpec.length) {
if (inputsList.length !== inputSpec.length) {
throw new ValueError(
`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
`but it received ${inputs.length} input tensors. ` +
`but it received ${inputsList.length} input tensors. ` +
`Input received: ${inputs}`);
}
for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
const x = inputs[inputIndex];
for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
const x = inputsList[inputIndex];
const spec: InputSpec = inputSpec[inputIndex];
if (spec == null) {
continue;
Expand Down Expand Up @@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
// Ensure inputs are all the same type.
const inputsList = generic_utils.toList(inputs);

let allAreSymbolic = true;
for (const input of inputsList) {
if (!(input instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
let noneAreSymbolic = true;
for (const input of inputsList) {
if (input instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
const allAreSymbolic = checkAllSymbolic(inputs);
const noneAreSymbolic = checkNoneSymbolic(inputs);

if (allAreSymbolic === noneAreSymbolic) {
throw new ValueError(
Expand Down Expand Up @@ -1017,31 +1005,12 @@ export abstract class Layer extends serialization.Serializable {

// Actually call the layer, collecting output(s), mask(s), and shape(s).
if (noneAreSymbolic) {
let output = this.call(inputs as Tensor | Tensor[], kwargs);
let output = this.call(inputs, kwargs);

// Apply masks to the output tensors if the layer supports it.
if (this.supportsMasking) {
// TODO(mattsoulanille): pass the input tensors' masks to computeMask
const outputMask = this.computeMask(inputs as Tensor | Tensor[]);
if (output instanceof Array && outputMask instanceof Array) {
if (output.length !== outputMask.length) {
throw new Error(`${this.name} output ${output.length} tensors `
+ `but ${outputMask.length} masks for those tensors`);
}
for (let i = 0; i < output.length; i++) {
output[i].kerasMask = outputMask[i];
}
} else if (outputMask instanceof Array) {
throw new Error(`{this.name} output a single tensor `
+ `but ${outputMask.length} masks`);
} else if (output instanceof Array) {
for (const out of output) {
out.kerasMask = outputMask.clone();
}
outputMask.dispose(); // Only keep the clones to avoid leaking
} else {
output.kerasMask = outputMask;
}
this.setMaskMetadata(inputs, output);
}

// If the layer returns tensors from its inputs, unmodified,
Expand Down Expand Up @@ -1097,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
If the input tensor(s) had no previous history,
this does nothing.
*/
this.addInboundNode(
inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
this.addInboundNode(inputs, output, null, null,
inputShape, outputShape, kwargs);
this._refCount++;

Expand Down Expand Up @@ -1419,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
return mask;
}

private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
previousMask?: Tensor|Tensor[]): void {
if (!this.supportsMasking) {
return;
}

const outputMasks = this.computeMask(inputs, previousMask);
if (outputs instanceof Array && outputMasks instanceof Array) {
if (outputs.length !== outputMasks.length) {
throw new Error(`${this.name} outputs ${outputs.length} tensors `
+ `but ${outputMasks.length} masks for those tensors`);
}
for (let i = 0; i < outputs.length; i++) {
outputs[i].kerasMask = outputMasks[i];
}
} else if (outputMasks instanceof Array) {
throw new Error(`{this.name} outputs a single tensor `
+ `but ${outputMasks.length} masks`);
} else if (outputs instanceof Array) {
throw new Error(`{this.name} outputs ${outputs.length} tensors `
+ `but only one mask`);
} else {
outputs.kerasMask = outputMasks;
}
}

/**
* Internal method to create an inbound node for the layer.
*
Expand Down Expand Up @@ -1666,3 +1660,29 @@ export function getSourceInputs(
}
}
}

type MaybeSymbolic = SymbolicTensor | Tensor;

function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
): tensors is SymbolicTensor | SymbolicTensor[] {
let allAreSymbolic = true;
for (const tensor of generic_utils.toList(tensors)) {
if (!(tensor instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
return allAreSymbolic;
}

function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
): tensors is Tensor | Tensor[] {
let noneAreSymbolic = true;
for (const tensor of generic_utils.toList(tensors)) {
if (tensor instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
return noneAreSymbolic;
}
2 changes: 1 addition & 1 deletion tfjs-layers/src/utils/generic_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export function singletonOrArray<T>(xs: T[]): T|T[] {
* @param x target object to be normalized.
*/
// tslint:disable-next-line:no-any
export function toList(x: any): any[] {
export function toList<T>(x: T|T[]): T[] {
if (Array.isArray(x)) {
return x;
}
Expand Down

0 comments on commit 6822278

Please sign in to comment.