From 17400ce4fc296be5b40b3b7ebb2c9bdbdb98168c Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 18 Mar 2024 14:16:03 +0800 Subject: [PATCH] Use reshape to re-implement squeeze --- nsnet2/nsnet2.js | 10 ++++++++-- style_transfer/fast_style_transfer_net.js | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/nsnet2/nsnet2.js b/nsnet2/nsnet2.js index 7d41d4c2..0d862498 100644 --- a/nsnet2/nsnet2.js +++ b/nsnet2/nsnet2.js @@ -50,7 +50,10 @@ export class NSNet2 { }); const [gru94, gru93] = this.builder_.gru(transpose31, weight192, recurrentWeight193, frames, this.hiddenSize, {bias: bias194, recurrentBias: recurrentBias194, initialHiddenState: initialState92, returnSequence: true}); - const squeeze95 = this.builder_.squeeze(gru93, {axes: [1]}); + // Use reshape to implement squeeze(gru93, {axes: [1]}); + const squeezed95_shape = gru93.shape(); + squeezed95_shape.splice(1, 1); + const squeeze95 = this.builder_.reshape(gru93, squeezed95_shape); const initialState155 = this.builder_.input('initialState155', { type: 'float32', dataType: 'float32', @@ -58,7 +61,10 @@ export class NSNet2 { }); const [gru157, gru156] = this.builder_.gru(squeeze95, weight212, recurrentWeight213, frames, this.hiddenSize, {bias: bias214, recurrentBias: recurrentBias214, initialHiddenState: initialState155, returnSequence: true}); - const squeeze158 = this.builder_.squeeze(gru156, {axes: [1]}); + // Use reshape to implement squeeze(gru156, {axes: [1]}); + const squeeze158_shape = gru156.shape(); + squeeze158_shape.splice(1, 1); + const squeeze158 = this.builder_.reshape(gru156, squeeze158_shape); const transpose159 = this.builder_.transpose(squeeze158, {permutation: [1, 0, 2]}); const relu163 = this.builder_.relu(this.builder_.add(this.builder_.matmul(transpose159, weight215), biasFcOut0)); const relu167 = this.builder_.relu(this.builder_.add(this.builder_.matmul(relu163, weight216), biasFcOut2)); diff --git a/style_transfer/fast_style_transfer_net.js b/style_transfer/fast_style_transfer_net.js index 532442f7..7e9fddeb 100644 --- a/style_transfer/fast_style_transfer_net.js +++ b/style_transfer/fast_style_transfer_net.js @@ -25,8 +25,14 @@ export class FastStyleTransferNet { buildInstanceNormalization_(conv2D, variableMul, variableAdd) { if ('instanceNormalization' in this.builder_) { - return this.builder_.instanceNormalization(conv2D, - {scale: this.builder_.squeeze(variableMul), bias: this.builder_.squeeze(variableAdd)}); + // Use reshape to implement squeeze(variableMul); and squeeze(variableAdd); + const mul_shape = variableMul.shape(); + const add_shape = variableAdd.shape(); + const squeezed_mul_shape = mul_shape.filter(dim => dim !==1); + const squeezed_add_shape = add_shape.filter(dim => dim !==1); + const mul_squeeze = this.builder_.reshape(variableMul, squeezed_mul_shape); + const add_squeeze = this.builder_.reshape(variableAdd, squeezed_add_shape); + return this.builder_.instanceNormalization(conv2D, {scale: mul_squeeze, bias: add_squeeze}); } else { const sub = this.builder_.sub(conv2D, this.builder_.reduceMean(conv2D, {axes: [2, 3], keepDimensions: true})); const reduceMean = this.builder_.reduceMean(this.builder_.mul(sub, sub), {axes: [2, 3], keepDimensions: true});