Skip to content

Commit

Permalink
webgpu: Fix a bug in softmax (#7607)
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 authored Apr 20, 2023
1 parent f35881e commit 78a3a03
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tfjs-backend-webgpu/src/softmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,17 @@ export class SoftmaxProgram implements WebGPUProgram {
}
workgroupBarrier();
let reduceSize = min(cols, blockSize);
for (var currSize = reduceSize >> 1; currSize > 0; currSize = currSize >> 1) {
var reduceSize = min(cols, blockSize);
for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
reduceSize = currSize + (reduceSize & 1);
if (tid < currSize) {
buf[tid] = max(buf[tid], buf[tid + currSize]);
buf[tid] = max(buf[tid], buf[tid + reduceSize]);
}
workgroupBarrier();
}
if (tid == 0) {
rowMaxShared = max(buf[0], buf[reduceSize - 1]);
rowMaxShared = buf[0];
}
workgroupBarrier();
Expand Down
6 changes: 6 additions & 0 deletions tfjs-core/src/ops/softmax_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ describeWithFlags('softmax', ALL_ENVS, () => {
expectArraysClose(await y.data(), [0.5, 0.5]);
});

it('odd number of inputs', async () => {
const y = tf.softmax(tf.tensor1d([-400, -400, 0, -400, -400, -400, -400]));

expectArraysClose(await y.data(), [0, 0, 1, 0, 0, 0, 0]);
});

it('Huge difference between probabilities', async () => {
const y = tf.softmax(tf.tensor1d([-1000, +1000]));

Expand Down

0 comments on commit 78a3a03

Please sign in to comment.