Skip to content

Commit

Permalink
[webgpu] Update ELU_DER (#7745)
Browse files Browse the repository at this point in the history
  • Loading branch information
hujiajie authored Jun 8, 2023
1 parent 806dfea commit e8feff4
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ const ATAN2 = 'let resultTemp = atan2(a, b);';
const COMPLEX_MULTIPLY_REAL = 'let resultTemp = areal * breal - aimag * bimag;';
const COMPLEX_MULTIPLY_IMAG = 'let resultTemp = areal * bimag + aimag * breal;';
const DIV = 'let resultTemp = a / b;';
const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);';
const ELU_DER_VEC4 =
'return select(a * (b + vec4<f32>(1.0)), a, b >= vec4<f32>(0.));';
const ELU_DER = 'let resultTemp = select(a * (b + 1.0), a, b >= b - b);';
const EQUAL = 'return f32(a == b);';
const EQUAL_VEC4 = 'return vec4<f32>(a == b);';
const GREATER = 'return f32(a > b);';
Expand Down Expand Up @@ -247,7 +245,8 @@ export function getBinaryOpString(
doOpSnippet = DIV;
break;
case BinaryOpType.ELU_DER:
return useVec4 ? ELU_DER_VEC4 : ELU_DER;
doOpSnippet = ELU_DER;
break;
case BinaryOpType.EQUAL:
return useVec4 ? EQUAL_VEC4 : EQUAL;
case BinaryOpType.GREATER:
Expand Down

0 comments on commit e8feff4

Please sign in to comment.