Skip to content

Commit

Permalink
leaky relu
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent dbe57fe commit a1fbcfd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
2 changes: 1 addition & 1 deletion js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Identity', '', '1+', unaryOps.identity],
// ['ImageScaler', '', '1+', imageScaler, parseImageScalerAttributes],
// ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes],
// ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes],
['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes],
// ['Less', '', '7+', binaryOps.less],
// ['Log', '', '6+', unaryOps.log],
// ['MatMul', '', '1+', matMul, parseMatMulAttributes],
Expand Down
31 changes: 18 additions & 13 deletions js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,28 @@ export const exp = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Pro
/**
 * Element-wise `floor` for the WebGPU backend.
 *
 * Creates the element-wise program info loader for the function name `floor` from the first input tensor and
 * hands it, together with all inputs, to the inference handler.
 *
 * @param handler - inference handler that runs the generated program.
 * @param inputs - operator inputs; `inputs[0]` is the tensor the program is built for.
 * @returns the promise produced by `handler.run`.
 */
export const floor = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const programInfoLoader = createElementwiseProgramInfoLoader(inputs[0], 'floor');
  return handler.run(programInfoLoader, inputs);
};

// export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslFloor()), inputs)];
/**
 * Attributes of the LeakyRelu operator, as produced by `parseLeakyReluAttributes`.
 * Extends `AttributeWithCacheKey`, so a `cacheKey` travels with the attribute values.
 */
export interface LeakyReluAttributes extends AttributeWithCacheKey {
// Factor the LeakyRelu shader multiplies into negative input values; values that are not negative pass through
// unchanged. The parser falls back to 0.01 when the node carries no 'alpha' attribute.
readonly alpha: number;
}

// export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity()), inputs)];
/**
 * Runs the element-wise LeakyRelu operator on the WebGPU backend.
 *
 * Builds an element-wise program named `leaky_relu` whose additional WGSL source defines the scalar helper
 * `leaky_relu_` (returns `a * alpha` when `a < 0.0`, otherwise `a`) and the `vec4<f32>` function `leaky_relu`
 * that applies the helper to each component, then passes the program to `handler.run`.
 *
 * NOTE(review): the shader declares `leaky_relu_alpha_` with a module-scope `let`. The current WGSL
 * specification only allows `let` inside functions, so this may need to become `const` - confirm against the
 * WebGPU implementation being targeted.
 *
 * @param handler - inference handler that runs the generated program.
 * @param inputs - operator inputs; `inputs[0]` is the tensor the program is built for.
 * @param attributes - parsed LeakyRelu attributes; `alpha` is baked into the shader source and `cacheKey` is
 *     forwarded to the program info loader.
 * @returns the promise produced by `handler.run`.
 */
export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: LeakyReluAttributes):
    Promise<Tensor[]> => handler.run(
        createElementwiseProgramInfoLoader(
            inputs[0], 'leaky_relu', `
let leaky_relu_alpha_: f32 = f32(${attributes.alpha});
fn leaky_relu_(a: f32) -> f32 {
return select(a, a * leaky_relu_alpha_, a < 0.0);
}
fn leaky_relu(v: vec4<f32>) -> vec4<f32> {
return vec4(leaky_relu_(v.x), leaky_relu_(v.y), leaky_relu_(v.z), leaky_relu_(v.w));
}`,
            attributes.cacheKey),
        inputs);

// export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes =>
// createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)});
/**
 * Reads the LeakyRelu attributes from a graph node.
 *
 * @param node - graph node whose attributes are read.
 * @returns the attribute object built by `createAttributeWithCacheKey`, holding `alpha` taken from the node's
 *     float attribute 'alpha', or 0.01 when that attribute is not present.
 */
export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes => {
  const alpha = node.attributes.getFloat('alpha', 0.01);
  return createAttributeWithCacheKey({alpha});
};

// export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs)];
Expand Down
12 changes: 6 additions & 6 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@
// "test_depthtospace_example",
"test_elu_example",
"test_elu",
"test_elu_default"
"test_elu_default",
// "test_flatten_axis0",
// "test_flatten_axis1",
// "test_flatten_axis2",
Expand All @@ -372,9 +372,9 @@
// "test_equal_bcast",
// "test_equal",
// "test_identity",
// "test_leakyrelu_default",
// "test_leakyrelu_example",
// "test_leakyrelu",
"test_leakyrelu_default",
"test_leakyrelu_example",
"test_leakyrelu"
// "test_lrn_default", <-- failing due to low precision. If absolute CPU error threshold is increased from 1e-4 to 1e-2 (100x increase), it passes the test.
// "test_lrn", <-- failing due to low precision. If absolute CPU error threshold is increased from 1e-4 to 1e-3 (10x increase), it passes the test.
// "test_matmul_2d",
Expand Down Expand Up @@ -521,7 +521,7 @@
//"depth-to-space.jsonc",
//"equal.jsonc",
"exp.jsonc",
"floor.jsonc"
"floor.jsonc",
//"global-average-pool.jsonc",
//"gemm.jsonc",
//"greater.jsonc",
Expand All @@ -534,7 +534,7 @@
//"neg.jsonc",
//"not.jsonc",
//"or.jsonc",
//"leaky-relu.jsonc",
"leaky-relu.jsonc"
//"reduce-min.jsonc",
//"relu.jsonc",
//"pad.jsonc",
Expand Down

0 comments on commit a1fbcfd

Please sign in to comment.