Skip to content

Commit

Permalink
Merge pull request webmachinelearning#237 from Honry/remove-softmax
Browse files Browse the repository at this point in the history
Remove softmax workaround for NPU
  • Loading branch information
huningxin authored May 15, 2024
2 parents 4228ce8 + 3341168 commit f4bb7c0
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 34 deletions.
1 change: 0 additions & 1 deletion image_classification/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module.exports = {
globals: {
'MLGraphBuilder': 'readonly',
'tf': 'readonly',
},
};
9 changes: 3 additions & 6 deletions image_classification/efficientnet_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,9 @@ export class EfficientNetFP16Nchw {
const pool1 = this.builder_.averagePool2d(await conv22);
const reshape = this.builder_.reshape(pool1, [1, 1280]);
const gemm = this.buildGemm_(reshape, '0');
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(await gemm, 'float32');
} else {
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}
const softmax = this.builder_.softmax(await gemm);

return this.builder_.cast(softmax, 'float32');
}

async build(outputOperand) {
Expand Down
3 changes: 0 additions & 3 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,6 @@ <h2 class="text-uppercase text-info">No model selected</h2>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/umd/popper.min.js"
integrity="sha384-9/reFTGAW83EW2RDu2S0VKaIzap3H66lZH81PoYlFhbGU+6BZp6G7niu735Sk7lN"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"
integrity="sha256-28ZvjeNGrGNEIj9/2D8YAPE6Vm5JSvvDs+LI4ED31x8="
crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"
integrity="sha384-B4gt1jrGC7Jh4AgTPSdUtOBvfO8shuf57BaghqFfPlYxofvL8/KUEfYiJOMMV+rV"
crossorigin="anonymous"></script>
Expand Down
12 changes: 0 additions & 12 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,6 @@ async function renderCamStream() {

// Get top 3 classes of labels from output buffer
function getTopClasses(buffer, labels) {
// Currently we need to fall back to tf.softmax because
// NPU doesn't support softmax.
// TODO: Remove this workaround once NPU supports softmax.
if (deviceType === 'npu') {
// Softmax
buffer = tf.tidy(() => {
const a =
tf.tensor(buffer, netInstance.outputDimensions, 'float32');
const b = tf.softmax(a);
return b.dataSync();
});
}
const probs = Array.from(buffer);
const indexes = probs.map((prob, index) => [prob, index]);
const sorted = indexes.sort((a, b) => {
Expand Down
8 changes: 2 additions & 6 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,8 @@ export class MobileNetV2Nchw {
{groups: 1280, strides: [7, 7]});
const conv5 = this.buildConv_(await conv4, '104', false);
const reshape = this.builder_.reshape(await conv5, [1, 1000]);
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(reshape, 'float32');
} else {
const softmax = this.builder_.softmax(reshape);
return this.builder_.cast(softmax, 'float32');
}
const softmax = this.builder_.softmax(reshape);
return this.builder_.cast(softmax, 'float32');
}
}

Expand Down
8 changes: 2 additions & 6 deletions image_classification/resnet50v1_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,8 @@ export class ResNet50V1FP16Nchw {
const pool2 = this.builder_.averagePool2d(await bottleneck16);
const reshape = this.builder_.reshape(pool2, [1, 2048]);
const gemm = this.buildGemm_(reshape, '0');
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(await gemm, 'float32');
} else {
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}

async build(outputOperand) {
Expand Down

0 comments on commit f4bb7c0

Please sign in to comment.