Skip to content

Commit

Permalink
Fix lint error
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed May 7, 2024
1 parent 6c77171 commit b6d56a0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
11 changes: 1 addition & 10 deletions object_detection/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,7 @@ async function drawOutput(inputElement, outputs, labels) {
boxesList, scoresList, classesList, labels);
} else {
// Draw output for Tiny Yolo V2 model
// Transpose 'nchw' output to 'nhwc' for postprocessing
let outputBuffer = outputs.output;
if (layout === 'nchw') {
outputBuffer = tf.tidy(() => {
const a =
tf.tensor(outputBuffer, netInstance.outputDimensions, 'float32');
const b = tf.transpose(a, [0, 2, 3, 1]);
return b.dataSync();
});
}
const outputBuffer = outputs.output;
const decodeOut = Yolo2Decoder.decodeYOLOv2({numClasses: 20},
outputBuffer, inputOptions.anchors);
const boxes = Yolo2Decoder.getBoxes(decodeOut, inputOptions.margin);
Expand Down
17 changes: 11 additions & 6 deletions object_detection/tiny_yolov2_nchw.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'use strict';

import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin, toHalf} from '../common/utils.js';
import {buildConstantByNpy, computePadding2DForAutoPad, weightsOrigin} from '../common/utils.js';

// Tiny Yolo V2 model with 'nchw' layout, trained on the Pascal VOC dataset.
export class TinyYoloV2Nchw {
Expand All @@ -23,20 +23,24 @@ export class TinyYoloV2Nchw {
}

async buildConv_(input, name) {
let biasName = `${this.weightsUrl_}ConvBnFusion_BN_B_BatchNormalization_B${name}.npy`;
let weightName = `${this.weightsUrl_}ConvBnFusion_W_convolution${name}_W.npy`;
let biasName =
`${this.weightsUrl_}ConvBnFusion_BN_B_BatchNormalization_B${name}.npy`;
let weightName =
`${this.weightsUrl_}ConvBnFusion_W_convolution${name}_W.npy`;
if (name === '8') {
biasName = `${this.weightsUrl_}convolution8_B.npy`;
weightName = `${this.weightsUrl_}convolution8_W.npy`;
}

const weight = await buildConstantByNpy(this.builder_, weightName, this.targetDataType_);
const weight = await buildConstantByNpy(
this.builder_, weightName, this.targetDataType_);
const options = {autoPad: 'same-upper'};
options.padding = computePadding2DForAutoPad(
/* nchw */[input.shape()[2], input.shape()[3]],
/* oihw */[weight.shape()[2], weight.shape()[3]],
options.strides, options.dilations, 'same-upper');
options.bias = await buildConstantByNpy(this.builder_, biasName, this.targetDataType_);
options.bias = await buildConstantByNpy(
this.builder_, biasName, this.targetDataType_);
const conv = this.builder_.conv2d(input, weight, options);
if (name === '8') {
return conv;
Expand Down Expand Up @@ -92,7 +96,8 @@ export class TinyYoloV2Nchw {
const conv6 = await this.buildConv_(pool5, '6');
const conv7 = await this.buildConv_(conv6, '7');
const conv = await this.buildConv_(conv7, '8');
const transpose = this.builder_.transpose(conv, {permutation: [0, 2, 3, 1]});
const transpose = this.builder_.transpose(
conv, {permutation: [0, 2, 3, 1]});
if (this.targetDataType_ === 'float16') {
return this.builder_.cast(transpose, 'float32');
} else {
Expand Down

0 comments on commit b6d56a0

Please sign in to comment.