From 0c50cd52653680fae77ff42e180621404dd04b27 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 1 Mar 2023 11:23:13 +0800 Subject: [PATCH] webgpu: Fix conv2d gradient cases failures --- tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts b/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts index f104a23f6f6..3c99755cb50 100644 --- a/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts +++ b/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts @@ -55,7 +55,8 @@ export function conv2DBackpropInput(args: { ]; let program: Conv2DDerInputProgram|Conv2DDerInputMMProgram; // TODO: Experiment when to use Conv2DDerInputMMProgram algorithm. - if (env().getBool('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE')) { + if (env().getBool('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE') || + convInfo.dataFormat !== 'channelsLast') { program = new Conv2DDerInputProgram(convInfo); } else { program = new Conv2DDerInputMMProgram(convInfo);