From cfdc254e1b930b9403b9cbda1a9ed7210ec638ea Mon Sep 17 00:00:00 2001 From: xhcao Date: Wed, 27 Jul 2022 14:34:01 +0800 Subject: [PATCH] webgpu: fix notEqual error (#6669) --- tfjs-backend-webgpu/src/binary_op_util.ts | 27 +++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index 382f65cf6be..80a822afe1b 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -109,8 +109,31 @@ const INT_DIV_VEC4 = ` return vec4(resultTemp); `; -const NOT_EQUAL = 'return f32(a != b);'; -const NOT_EQUAL_VEC4 = 'return vec4(a != b);'; +const NOT_EQUAL = ` + if (isnan(a) || isnan(b)) { + return 1.0; + } + return f32(a != b); +`; +const NOT_EQUAL_VEC4 = ` + var result = vec4(a != b); + var isANaN = isnanVec4(a); + var isBNaN = isnanVec4(b); + if (isANaN.r || isBNaN.r) { + result.r = 1.0; + } + if (isANaN.g || isBNaN.g) { + result.g = 1.0; + } + if (isANaN.b || isBNaN.b) { + result.b = 1.0; + } + if (isANaN.a || isBNaN.a) { + result.a = 1.0; + } + + return result; +`; const POW = ` if(a < 0.0 && floor(b) < b) { return uniforms.NAN;