diff --git a/image_classification/.eslintrc.js b/image_classification/.eslintrc.js index c02d313a..41955769 100644 --- a/image_classification/.eslintrc.js +++ b/image_classification/.eslintrc.js @@ -1,6 +1,5 @@ module.exports = { globals: { 'MLGraphBuilder': 'readonly', - 'tf': 'readonly', }, }; diff --git a/image_classification/efficientnet_fp16_nchw.js b/image_classification/efficientnet_fp16_nchw.js index 9f0d87f4..af9c3a52 100644 --- a/image_classification/efficientnet_fp16_nchw.js +++ b/image_classification/efficientnet_fp16_nchw.js @@ -147,12 +147,9 @@ export class EfficientNetFP16Nchw { const pool1 = this.builder_.averagePool2d(await conv22); const reshape = this.builder_.reshape(pool1, [1, 1280]); const gemm = this.buildGemm_(reshape, '0'); - if (contextOptions.deviceType === 'npu') { - return this.builder_.cast(await gemm, 'float32'); - } else { - const softmax = this.builder_.softmax(await gemm); - return this.builder_.cast(softmax, 'float32'); - } + const softmax = this.builder_.softmax(await gemm); + + return this.builder_.cast(softmax, 'float32'); } async build(outputOperand) { diff --git a/image_classification/index.html b/image_classification/index.html index 85f3b87f..9cd3e375 100644 --- a/image_classification/index.html +++ b/image_classification/index.html @@ -238,9 +238,6 @@

No model selected

- diff --git a/image_classification/main.js b/image_classification/main.js index 4c33ab6e..99cd3e7f 100644 --- a/image_classification/main.js +++ b/image_classification/main.js @@ -231,18 +231,6 @@ async function renderCamStream() { // Get top 3 classes of labels from output buffer function getTopClasses(buffer, labels) { - // Currently we need to fallback softmax to tf.softmax because - // NPU dosen't support softmax. - // TODO: Remove this workaround once NPU supports softmax. - if (deviceType === 'npu') { - // Softmax - buffer = tf.tidy(() => { - const a = - tf.tensor(buffer, netInstance.outputDimensions, 'float32'); - const b = tf.softmax(a); - return b.dataSync(); - }); - } const probs = Array.from(buffer); const indexes = probs.map((prob, index) => [prob, index]); const sorted = indexes.sort((a, b) => { diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js index 6b52350b..29e63e3f 100644 --- a/image_classification/mobilenet_nchw.js +++ b/image_classification/mobilenet_nchw.js @@ -153,12 +153,8 @@ export class MobileNetV2Nchw { {groups: 1280, strides: [7, 7]}); const conv5 = this.buildConv_(await conv4, '104', false); const reshape = this.builder_.reshape(await conv5, [1, 1000]); - if (contextOptions.deviceType === 'npu') { - return this.builder_.cast(reshape, 'float32'); - } else { - const softmax = this.builder_.softmax(reshape); - return this.builder_.cast(softmax, 'float32'); - } + const softmax = this.builder_.softmax(reshape); + return this.builder_.cast(softmax, 'float32'); } } diff --git a/image_classification/resnet50v1_fp16_nchw.js b/image_classification/resnet50v1_fp16_nchw.js index 58ff60be..b3f80f02 100644 --- a/image_classification/resnet50v1_fp16_nchw.js +++ b/image_classification/resnet50v1_fp16_nchw.js @@ -118,12 +118,8 @@ export class ResNet50V1FP16Nchw { const pool2 = this.builder_.averagePool2d(await bottleneck16); const reshape = this.builder_.reshape(pool2, [1, 2048]); const gemm = this.buildGemm_(reshape, '0'); - if (contextOptions.deviceType === 'npu') { - return this.builder_.cast(await gemm, 'float32'); - } else { - const softmax = this.builder_.softmax(await gemm); - return this.builder_.cast(softmax, 'float32'); - } + const softmax = this.builder_.softmax(await gemm); + return this.builder_.cast(softmax, 'float32'); } async build(outputOperand) {