diff --git a/image_classification/.eslintrc.js b/image_classification/.eslintrc.js
index c02d313a..41955769 100644
--- a/image_classification/.eslintrc.js
+++ b/image_classification/.eslintrc.js
@@ -1,6 +1,5 @@
module.exports = {
globals: {
'MLGraphBuilder': 'readonly',
- 'tf': 'readonly',
},
};
diff --git a/image_classification/efficientnet_fp16_nchw.js b/image_classification/efficientnet_fp16_nchw.js
index 9f0d87f4..af9c3a52 100644
--- a/image_classification/efficientnet_fp16_nchw.js
+++ b/image_classification/efficientnet_fp16_nchw.js
@@ -147,12 +147,9 @@ export class EfficientNetFP16Nchw {
const pool1 = this.builder_.averagePool2d(await conv22);
const reshape = this.builder_.reshape(pool1, [1, 1280]);
const gemm = this.buildGemm_(reshape, '0');
- if (contextOptions.deviceType === 'npu') {
- return this.builder_.cast(await gemm, 'float32');
- } else {
- const softmax = this.builder_.softmax(await gemm);
- return this.builder_.cast(softmax, 'float32');
- }
+ const softmax = this.builder_.softmax(await gemm);
+
+ return this.builder_.cast(softmax, 'float32');
}
async build(outputOperand) {
diff --git a/image_classification/index.html b/image_classification/index.html
index 85f3b87f..9cd3e375 100644
--- a/image_classification/index.html
+++ b/image_classification/index.html
@@ -238,9 +238,6 @@
No model selected
-
diff --git a/image_classification/main.js b/image_classification/main.js
index 4c33ab6e..99cd3e7f 100644
--- a/image_classification/main.js
+++ b/image_classification/main.js
@@ -231,18 +231,6 @@ async function renderCamStream() {
// Get top 3 classes of labels from output buffer
function getTopClasses(buffer, labels) {
- // Currently we need to fallback softmax to tf.softmax because
- // NPU dosen't support softmax.
- // TODO: Remove this workaround once NPU supports softmax.
- if (deviceType === 'npu') {
- // Softmax
- buffer = tf.tidy(() => {
- const a =
- tf.tensor(buffer, netInstance.outputDimensions, 'float32');
- const b = tf.softmax(a);
- return b.dataSync();
- });
- }
const probs = Array.from(buffer);
const indexes = probs.map((prob, index) => [prob, index]);
const sorted = indexes.sort((a, b) => {
diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js
index 6b52350b..29e63e3f 100644
--- a/image_classification/mobilenet_nchw.js
+++ b/image_classification/mobilenet_nchw.js
@@ -153,12 +153,8 @@ export class MobileNetV2Nchw {
{groups: 1280, strides: [7, 7]});
const conv5 = this.buildConv_(await conv4, '104', false);
const reshape = this.builder_.reshape(await conv5, [1, 1000]);
- if (contextOptions.deviceType === 'npu') {
- return this.builder_.cast(reshape, 'float32');
- } else {
- const softmax = this.builder_.softmax(reshape);
- return this.builder_.cast(softmax, 'float32');
- }
+ const softmax = this.builder_.softmax(reshape);
+ return this.builder_.cast(softmax, 'float32');
}
}
diff --git a/image_classification/resnet50v1_fp16_nchw.js b/image_classification/resnet50v1_fp16_nchw.js
index 58ff60be..b3f80f02 100644
--- a/image_classification/resnet50v1_fp16_nchw.js
+++ b/image_classification/resnet50v1_fp16_nchw.js
@@ -118,12 +118,8 @@ export class ResNet50V1FP16Nchw {
const pool2 = this.builder_.averagePool2d(await bottleneck16);
const reshape = this.builder_.reshape(pool2, [1, 2048]);
const gemm = this.buildGemm_(reshape, '0');
- if (contextOptions.deviceType === 'npu') {
- return this.builder_.cast(await gemm, 'float32');
- } else {
- const softmax = this.builder_.softmax(await gemm);
- return this.builder_.cast(softmax, 'float32');
- }
+ const softmax = this.builder_.softmax(await gemm);
+ return this.builder_.cast(softmax, 'float32');
}
async build(outputOperand) {