From a1fbcfd51ad59b0cf7136e9cddf0f516557458dc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 1 Apr 2022 17:24:10 -0700 Subject: [PATCH] leaky relu --- .../backends/webgpu/op-resolve-rules.ts | 2 +- .../onnxjs/backends/webgpu/ops/unary-op.ts | 31 +++++++++++-------- js/web/test/suite-test-list.jsonc | 12 +++---- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index a47e5144ae250..4717b4944c019 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -38,7 +38,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['Identity', '', '1+', unaryOps.identity], // ['ImageScaler', '', '1+', imageScaler, parseImageScalerAttributes], // ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes], - // ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes], + ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes], // ['Less', '', '7+', binaryOps.less], // ['Log', '', '6+', unaryOps.log], // ['MatMul', '', '1+', matMul, parseMatMulAttributes], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts index 4c321975bf35c..1a8b738125554 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts @@ -134,23 +134,28 @@ export const exp = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Pro export const floor = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => handler.run(createElementwiseProgramInfoLoader(inputs[0], 'floor'), inputs); -// export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]): -// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, 
inputs[0], glslFloor()), inputs)]; +export interface LeakyReluAttributes extends AttributeWithCacheKey { + readonly alpha: number; +} -// export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]): -// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity()), inputs)]; +export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: LeakyReluAttributes): + Promise<Tensor[]> => handler.run( + createElementwiseProgramInfoLoader( + inputs[0], 'leaky_relu', ` + let leaky_relu_alpha_: f32 = f32(${attributes.alpha}); -// export interface LeakyReluAttributes extends AttributeWithCacheKey { -// readonly alpha: number; -// } + fn leaky_relu_(a: f32) -> f32 { + return select(a, a * leaky_relu_alpha_, a < 0.0); + } -// export const leakyRelu = -// (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: LeakyReluAttributes): Tensor[] => [handler.run( -// createElementwiseProgramInfoLoader(handler, inputs[0], glslLeakyRelu(attributes.alpha), attributes.cacheKey), -// inputs)]; + fn leaky_relu(v: vec4<f32>) -> vec4<f32> { + return vec4<f32>(leaky_relu_(v.x), leaky_relu_(v.y), leaky_relu_(v.z), leaky_relu_(v.w)); + }`, + attributes.cacheKey), + inputs); -// export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes => -// createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)}); +export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes => + createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)}); // export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]): // Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs)]; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index faa4b7f9ea6e0..efd6b00d6f300 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -349,7 +349,7 @@ // "test_depthtospace_example", 
"test_elu_example", "test_elu", - "test_elu_default" + "test_elu_default", // "test_flatten_axis0", // "test_flatten_axis1", // "test_flatten_axis2", @@ -372,9 +372,9 @@ // "test_equal_bcast", // "test_equal", // "test_identity", - // "test_leakyrelu_default", - // "test_leakyrelu_example", - // "test_leakyrelu", + "test_leakyrelu_default", + "test_leakyrelu_example", + "test_leakyrelu" // "test_lrn_default", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-2 (100x increase), it passes the test. // "test_lrn", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-3 (10x increase), it passes the test. // "test_matmul_2d", @@ -521,7 +521,7 @@ //"depth-to-space.jsonc", //"equal.jsonc", "exp.jsonc", - "floor.jsonc" + "floor.jsonc", //"global-average-pool.jsonc", //"gemm.jsonc", //"greater.jsonc", @@ -534,7 +534,7 @@ //"neg.jsonc", //"not.jsonc", //"or.jsonc", - //"leaky-relu.jsonc", + "leaky-relu.jsonc" //"reduce-min.jsonc", //"relu.jsonc", //"pad.jsonc",