diff --git a/tfjs-layers/src/base_callbacks.ts b/tfjs-layers/src/base_callbacks.ts index 50eea0c8a56..e79cbf11eb2 100644 --- a/tfjs-layers/src/base_callbacks.ts +++ b/tfjs-layers/src/base_callbacks.ts @@ -488,7 +488,8 @@ export function standardizeCallbacks( } // Convert custom callback configs to custom callback objects. const callbackConfigs = - generic_utils.toList(callbacks) as CustomCallbackArgs[]; + generic_utils.toList( + callbacks) as CustomCallbackArgs[]; return callbackConfigs.map( callbackConfig => new CustomCallback(callbackConfig, yieldEvery)); } diff --git a/tfjs-layers/src/engine/topology.ts b/tfjs-layers/src/engine/topology.ts index 85ba63ccd66..ffa581560ba 100644 --- a/tfjs-layers/src/engine/topology.ts +++ b/tfjs-layers/src/engine/topology.ts @@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable { */ protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor| SymbolicTensor[]): void { - inputs = generic_utils.toList(inputs); + const inputsList = generic_utils.toList(inputs); if (this.inputSpec == null || this.inputSpec.length === 0) { return; } const inputSpec = generic_utils.toList(this.inputSpec); - if (inputs.length !== inputSpec.length) { + if (inputsList.length !== inputSpec.length) { throw new ValueError( `Layer ${this.name} expects ${inputSpec.length} inputs, ` + - `but it received ${inputs.length} input tensors. ` + + `but it received ${inputsList.length} input tensors. ` + `Input received: ${inputs}`); } - for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) { - const x = inputs[inputIndex]; + for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) { + const x = inputsList[inputIndex]; const spec: InputSpec = inputSpec[inputIndex]; if (spec == null) { continue; @@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable { // Ensure inputs are all the same type. const inputsList = generic_utils.toList(inputs); - let allAreSymbolic = true; - for (const input of inputsList) { - if (!(input instanceof SymbolicTensor)) { - allAreSymbolic = false; - break; - } - } - let noneAreSymbolic = true; - for (const input of inputsList) { - if (input instanceof SymbolicTensor) { - noneAreSymbolic = false; - break; - } - } + const allAreSymbolic = checkAllSymbolic(inputs); + const noneAreSymbolic = checkNoneSymbolic(inputs); if (allAreSymbolic === noneAreSymbolic) { throw new ValueError( @@ -1017,31 +1005,12 @@ export abstract class Layer extends serialization.Serializable { // Actually call the layer, collecting output(s), mask(s), and shape(s). if (noneAreSymbolic) { - let output = this.call(inputs as Tensor | Tensor[], kwargs); + let output = this.call(inputs, kwargs); // Apply masks to the output tensors if the layer supports it. if (this.supportsMasking) { // TODO(mattsoulanille): pass the input tensors' masks to computeMask - const outputMask = this.computeMask(inputs as Tensor | Tensor[]); - if (output instanceof Array && outputMask instanceof Array) { - if (output.length !== outputMask.length) { - throw new Error(`${this.name} output ${output.length} tensors ` - + `but ${outputMask.length} masks for those tensors`); - } - for (let i = 0; i < output.length; i++) { - output[i].kerasMask = outputMask[i]; - } - } else if (outputMask instanceof Array) { - throw new Error(`{this.name} output a single tensor ` - + `but ${outputMask.length} masks`); - } else if (output instanceof Array) { - for (const out of output) { - out.kerasMask = outputMask.clone(); - } - outputMask.dispose(); // Only keep the clones to avoid leaking - } else { - output.kerasMask = outputMask; - } + this.setMaskMetadata(inputs, output); } // If the layer returns tensors from its inputs, unmodified, @@ -1097,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable { If the input tensor(s) had no previous history, this does nothing. */ - this.addInboundNode( - inputs as SymbolicTensor | SymbolicTensor[], output, null, null, + this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs); this._refCount++; @@ -1419,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable { return mask; } + private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[], + previousMask?: Tensor|Tensor[]): void { + if (!this.supportsMasking) { + return; + } + + const outputMasks = this.computeMask(inputs, previousMask); + if (outputs instanceof Array && outputMasks instanceof Array) { + if (outputs.length !== outputMasks.length) { + throw new Error(`${this.name} outputs ${outputs.length} tensors ` + + `but ${outputMasks.length} masks for those tensors`); + } + for (let i = 0; i < outputs.length; i++) { + outputs[i].kerasMask = outputMasks[i]; + } + } else if (outputMasks instanceof Array) { + throw new Error(`{this.name} outputs a single tensor ` + + `but ${outputMasks.length} masks`); + } else if (outputs instanceof Array) { + throw new Error(`{this.name} outputs ${outputs.length} tensors ` + + `but only one mask`); + } else { + outputs.kerasMask = outputMasks; + } + } + /** * Internal method to create an inbound node for the layer. * @@ -1666,3 +1660,29 @@ export function getSourceInputs( } } } + +type MaybeSymbolic = SymbolicTensor | Tensor; + +function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[] + ): tensors is SymbolicTensor | SymbolicTensor[] { + let allAreSymbolic = true; + for (const tensor of generic_utils.toList(tensors)) { + if (!(tensor instanceof SymbolicTensor)) { + allAreSymbolic = false; + break; + } + } + return allAreSymbolic; +} + +function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[] + ): tensors is Tensor | Tensor[] { + let noneAreSymbolic = true; + for (const tensor of generic_utils.toList(tensors)) { + if (tensor instanceof SymbolicTensor) { + noneAreSymbolic = false; + break; + } + } + return noneAreSymbolic; +} diff --git a/tfjs-layers/src/utils/generic_utils.ts b/tfjs-layers/src/utils/generic_utils.ts index 2d06342c8e1..c128097f5ed 100644 --- a/tfjs-layers/src/utils/generic_utils.ts +++ b/tfjs-layers/src/utils/generic_utils.ts @@ -76,7 +76,7 @@ export function singletonOrArray(xs: T[]): T|T[] { * @param x target object to be normalized. */ // tslint:disable-next-line:no-any -export function toList(x: any): any[] { +export function toList(x: T|T[]): T[] { if (Array.isArray(x)) { return x; }