Skip to content

Commit

Permalink
webgpu: Tighten the condition to go the plain argminmax (#7742)
Browse files Browse the repository at this point in the history
* webgpu: Tighten the condition to go the plain argminmax

In PR #6778, I ever added the plain argminmax shader to improve the perf
of ArgMax[1, 1025, 2049, 19]. However, I accidently added the output
size condition, which was not verified. This PR removes that condition
and only goes to plain argminmax when the reduce length is very small(<32).
With this change, ArgMax[1,513,513,151] becomes 7.19ms from 13.24ms in
DeepLabV3-ade20k on Intel ADL.
  • Loading branch information
qjia7 authored Jun 8, 2023
1 parent e8feff4 commit 5ed9eb8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 4 additions & 5 deletions tfjs-backend-webgpu/src/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ export class ArgMinMaxProgram implements WebGPUProgram {
this.dispatchLayout = flatDispatchLayout(this.outputShape);
// The shared algorithm is mainly used for large reduce size. It fully
// utilizes the threads in one workgroup to do the reduction. However,
// when the reduce size is very small or the output shape is too large. It's
// better to use the plain algorithm to reduce the number of workgroups to
// speedup. The threthold can be further tuned.
if (util.sizeFromShape(reduceShape) < 32 ||
util.sizeFromShape(outputShape) > 1000) {
// when the reduce size is very small, it's better to use the plain
// algorithm to reduce the number of workgroups to speedup. The threthold
// can be further tuned.
if (util.sizeFromShape(reduceShape) < 32) {
this.type = 'plain';
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);
Expand Down
4 changes: 3 additions & 1 deletion tfjs-backend-webgpu/src/benchmark_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ describeWebGPU('Ops benchmarks', () => {
expect().nothing();
}

it('argMax', async () => {
// Failing on MacOS Timeout
// tslint:disable-next-line: ban
xit('argMax', async () => {
const n = 2;
const doTest = async (axis: number) => {
const tensors = new Array(n);
Expand Down

0 comments on commit 5ed9eb8

Please sign in to comment.