Skip to content

Commit

Permalink
Use reshape to re-implement squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Mar 18, 2024
1 parent dbfab8f commit 17400ce
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 8 additions & 2 deletions nsnet2/nsnet2.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,21 @@ export class NSNet2 {
});
const [gru94, gru93] = this.builder_.gru(transpose31, weight192, recurrentWeight193, frames, this.hiddenSize,
{bias: bias194, recurrentBias: recurrentBias194, initialHiddenState: initialState92, returnSequence: true});
const squeeze95 = this.builder_.squeeze(gru93, {axes: [1]});
// Use reshape to implement squeeze(gru93, {axes: [1]});
const squeezed95_shape = gru93.shape();

Check failure on line 54 in nsnet2/nsnet2.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'squeezed95_shape' is not in camel case
squeezed95_shape.splice(1, 1);
const squeeze95 = this.builder_.reshape(gru93, squeezed95_shape);
const initialState155 = this.builder_.input('initialState155', {
type: 'float32',
dataType: 'float32',
dimensions: [1, batchSize, this.hiddenSize],
});
const [gru157, gru156] = this.builder_.gru(squeeze95, weight212, recurrentWeight213, frames, this.hiddenSize,
{bias: bias214, recurrentBias: recurrentBias214, initialHiddenState: initialState155, returnSequence: true});
const squeeze158 = this.builder_.squeeze(gru156, {axes: [1]});
// Use reshape to implement squeeze(gru156, {axes: [1]});
const squeeze158_shape = gru156.shape();

Check failure on line 65 in nsnet2/nsnet2.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'squeeze158_shape' is not in camel case
squeeze158_shape.splice(1, 1);
const squeeze158 = this.builder_.reshape(gru156, squeeze158_shape);
const transpose159 = this.builder_.transpose(squeeze158, {permutation: [1, 0, 2]});
const relu163 = this.builder_.relu(this.builder_.add(this.builder_.matmul(transpose159, weight215), biasFcOut0));
const relu167 = this.builder_.relu(this.builder_.add(this.builder_.matmul(relu163, weight216), biasFcOut2));
Expand Down
10 changes: 8 additions & 2 deletions style_transfer/fast_style_transfer_net.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ export class FastStyleTransferNet {

buildInstanceNormalization_(conv2D, variableMul, variableAdd) {
if ('instanceNormalization' in this.builder_) {
return this.builder_.instanceNormalization(conv2D,
{scale: this.builder_.squeeze(variableMul), bias: this.builder_.squeeze(variableAdd)});
// Use reshape to implement squeeze(variableMul); and squeeze(variableAdd);
const mul_shape = variableMul.shape();

Check failure on line 29 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'mul_shape' is not in camel case
const add_shape = variableAdd.shape();

Check failure on line 30 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'add_shape' is not in camel case
const squeezed_mul_shape = mul_shape.filter(dim => dim !==1);

Check failure on line 31 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'squeezed_mul_shape' is not in camel case

Check failure on line 31 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Expected parentheses around arrow function argument
const squeezed_add_shape = add_shape.filter(dim => dim !==1);

Check failure on line 32 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'squeezed_add_shape' is not in camel case

Check failure on line 32 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Expected parentheses around arrow function argument
const mul_squeeze = this.builder_.reshape(variableMul, squeezed_mul_shape);

Check failure on line 33 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'mul_squeeze' is not in camel case
const add_squeeze = this.builder_.reshape(variableAdd, squeezed_add_shape);

Check failure on line 34 in style_transfer/fast_style_transfer_net.js

View workflow job for this annotation

GitHub Actions / job (ubuntu-latest)

Identifier 'add_squeeze' is not in camel case
return this.builder_.instanceNormalization(conv2D, {scale: mul_squeeze, bias: add_squeeze});
} else {
const sub = this.builder_.sub(conv2D, this.builder_.reduceMean(conv2D, {axes: [2, 3], keepDimensions: true}));
const reduceMean = this.builder_.reduceMean(this.builder_.mul(sub, sub), {axes: [2, 3], keepDimensions: true});
Expand Down

0 comments on commit 17400ce

Please sign in to comment.