optimize types
fs-eire committed Sep 13, 2022
1 parent e9775fe commit 073695f
Showing 1 changed file with 41 additions and 54 deletions.
95 changes: 41 additions & 54 deletions js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
@@ -9,33 +9,27 @@ import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {WORKGROUP_SIZE} from './common';

-type ElementwiseFunctionImplementation =
-// name, builtin function call.
-// eg. ['Abs', 'abs']
-[string, string]|
-// name, function call builder, extra implementation (optional)
-// eg. ['Neg', a => `-${a}`]
-[string, (variableName: string) => string, string?];
+type BuiltinFunctionName = string;
+type ElementwiseCustomExpression = (expression: string) => string;
+type ElementwiseFunctionCall = BuiltinFunctionName|ElementwiseCustomExpression;

const createElementwiseProgramShader =
-(functionImplementation: ElementwiseFunctionImplementation, datasize: number): string => {
+(datasize: number, funcCall: ElementwiseFunctionCall, additionalImplementation?: string): string => {
const vecSize = Math.ceil(datasize / 4);
-let funcImpl: string;
-let funcCall = functionImplementation[1];
-if (typeof funcCall === 'function') {
-funcImpl = functionImplementation[2] ?? '';
-funcCall = funcCall('a');
+
+let expression = '';
+if (typeof funcCall === 'string') {
+expression = `${funcCall}(a)`;
} else {
-funcImpl = '';
-funcCall = `${funcCall}(a)`;
+expression = funcCall('a');
}
return `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> inputData : array<vec4<f32>>;
@group(0) @binding(1) var<storage, write> outputData : array<vec4<f32>>;
-${funcImpl}
+${additionalImplementation ?? ''}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
@@ -46,39 +40,41 @@ const createElementwiseProgramShader =
}
let a = inputData[global_id.x];
-outputData[global_id.x] = ${funcCall};
+outputData[global_id.x] = ${expression};
}`;
};

const createElementwiseProgramInfo =
-(metadata: ProgramMetadata, input: Tensor, functionImplementation: ElementwiseFunctionImplementation):
+(metadata: ProgramMetadata, input: Tensor, funcCall: ElementwiseFunctionCall, additionalImplementation?: string):
ProgramInfo => ({
...metadata,
-shaderSource: createElementwiseProgramShader(functionImplementation, input.size),
+shaderSource: createElementwiseProgramShader(input.size, funcCall, additionalImplementation),
outputs: [{dims: input.dims, type: input.type, gpuDataType: GpuDataType.default}],
dispatchGroup: (inputTensors) =>
({x: Math.ceil(inputTensors[0].size / 64 /* workgroup size */ / 4 /* vec size */)})
});

const createElementwiseProgramInfoLoader =
-(input: Tensor, functionImplementation: ElementwiseFunctionImplementation,
+(input: Tensor, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string,
cacheKey?: string): ProgramInfoLoader => {
-const metadata:
-ProgramMetadata = {name: functionImplementation[0], inputTypes: [GpuDataType.default], cacheHint: cacheKey};
-return {...metadata, get: () => createElementwiseProgramInfo(metadata, input, functionImplementation)};
+const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: cacheKey};
+return {
+...metadata,
+get: () => createElementwiseProgramInfo(metadata, input, funcCall, additionalImplementation)
+};
};

export const abs = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Abs', 'abs']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Abs', 'abs'), inputs);

export const acos = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Acos', 'acos']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Acos', 'acos'), inputs);

export const asin = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Asin', 'asin']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Asin', 'asin'), inputs);

export const atan = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Atan', 'atan']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Atan', 'atan'), inputs);

export interface ClipAttributes extends AttributeWithCacheKey {
readonly min: number;
@@ -88,13 +84,10 @@ export interface ClipAttributes extends AttributeWithCacheKey {
export const clip = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ClipAttributes):
Promise<Tensor[] >=>handler.run(
createElementwiseProgramInfoLoader(
-inputs[0],
-[
-'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
+inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
let clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
let clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
-`
-],
+`,
attributes.cacheKey),
inputs);

@@ -118,10 +111,10 @@ export const clipV11 = async(handler: WebGpuInferenceHandler, inputs: Tensor[]):
};

export const ceil = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Ceil', 'ceil']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Ceil', 'ceil'), inputs);

export const cos = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Cos', 'cos']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Cos', 'cos'), inputs);

export interface EluAttributes extends AttributeWithCacheKey {
readonly alpha: number;
@@ -130,9 +123,7 @@ export interface EluAttributes extends AttributeWithCacheKey {
export const elu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: EluAttributes):
Promise<Tensor[] >=>handler.run(
createElementwiseProgramInfoLoader(
-inputs[0],
-[
-'Elu', a => `elu_vf32(${a})`, `
+inputs[0], 'Elu', a => `elu_vf32(${a})`, `
let elu_alpha_: f32 = f32(${attributes.alpha});
fn elu_f32(a: f32) -> f32 {
@@ -141,19 +132,18 @@ export const elu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attr
fn elu_vf32(v: vec4<f32>) -> vec4<f32> {
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
-}`
-],
+}`,
attributes.cacheKey),
inputs);

export const parseEluAttributes = (node: Graph.Node): EluAttributes =>
createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)});

export const exp = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Exp', 'exp']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Exp', 'exp'), inputs);

export const floor = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Floor', 'floor']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Floor', 'floor'), inputs);

export interface LeakyReluAttributes extends AttributeWithCacheKey {
readonly alpha: number;
@@ -162,9 +152,7 @@ export interface LeakyReluAttributes extends AttributeWithCacheKey {
export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: EluAttributes):
Promise<Tensor[] >=>handler.run(
createElementwiseProgramInfoLoader(
-inputs[0],
-[
-'LeakyRelu', a => `leaky_relu_vf32(${a})`, `
+inputs[0], 'LeakyRelu', a => `leaky_relu_vf32(${a})`, `
let leaky_relu_alpha_: f32 = f32(${attributes.alpha});
fn leaky_relu_f32(a: f32) -> f32 {
@@ -173,37 +161,36 @@ export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[]
fn leaky_relu_vf32(v: vec4<f32>) -> vec4<f32> {
return vec4(leaky_relu_f32(v.x), leaky_relu_f32(v.y), leaky_relu_f32(v.z), leaky_relu_f32(v.w));
-}`
-],
+}`,
attributes.cacheKey),
inputs);

export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes =>
createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)});

export const log = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Log', 'log']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Log', 'log'), inputs);

export const neg = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Neg', a => `-${a}`]), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Neg', a => `-${a}`), inputs);

// export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs)];

export const relu = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[] >=>handler.run(
-createElementwiseProgramInfoLoader(inputs[0], ['Relu', a => `max(${a}, vec4(0.0))`]), inputs);
+createElementwiseProgramInfoLoader(inputs[0], 'Relu', a => `max(${a}, vec4(0.0))`), inputs);

export const sigmoid = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[] >=>handler.run(
-createElementwiseProgramInfoLoader(inputs[0], ['Sigmoid', a => `(vec4(1.0) / (vec4(1.0) + exp(-${a})))`]), inputs);
+createElementwiseProgramInfoLoader(inputs[0], 'Sigmoid', a => `(vec4(1.0) / (vec4(1.0) + exp(-${a})))`), inputs);

export const sin = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Sin', 'sin']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Sin', 'sin'), inputs);

export const sqrt = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Sqrt', 'sqrt']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Sqrt', 'sqrt'), inputs);

export const tan = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Tan', 'tan']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Tan', 'tan'), inputs);

export const tanh = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Tanh', 'tanh']), inputs);
+handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Tanh', 'tanh'), inputs);
