Skip to content

Commit

Permalink
[js/webgpu] support Cast operator (microsoft#16489)
Browse files Browse the repository at this point in the history
### Description
support `Cast` operator for webgpu backend.

Cast operator for webgpu backend currently only supports f32, u32, i32
and bool.
  • Loading branch information
fs-eire authored Aug 19, 2023
1 parent cb8db69 commit b10674c
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 10 deletions.
1 change: 1 addition & 0 deletions web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Do not modify directly.*
| Atan | ai.onnx(7+) | |
| Atanh | ai.onnx(9+) | |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation |
| Cast | ai.onnx(6-8,9-12,13-18,19+) | |
| Ceil | ai.onnx(6-12,13+) | |
| Clip | ai.onnx(6-10,11,12,13+) | |
| Concat | ai.onnx(1-3,4-10,11-12,13+) | |
Expand Down
1 change: 1 addition & 0 deletions web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Atanh', [unaryOps.atanh]],
// TODO: support new attributes for AveragePool-10
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
['Ceil', [unaryOps.ceil]],
['ClipV10', [unaryOps.clipV10]],
['Clip', [unaryOps.clip]],
Expand Down
51 changes: 41 additions & 10 deletions web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type ElementwiseCustomExpression = (expression: string) => string;
type ElementwiseFunctionCall = BuiltinFunctionName|ElementwiseCustomExpression;

const createElementwiseProgramShader =
(shaderHelper: ShaderHelper, datasize: number, funcCall: ElementwiseFunctionCall,
additionalImplementation?: string): string => {
(shaderHelper: ShaderHelper, datasize: number, inputDataType: number, outputDataType: number,
funcCall: ElementwiseFunctionCall, additionalImplementation?: string): string => {
const vecSize = Math.ceil(datasize / 4);

let expression = '';
Expand All @@ -25,8 +25,8 @@ const createElementwiseProgramShader =
expression = funcCall('a');
}

const input = inputVariable('inputData', DataType.float, [vecSize], 4);
const output = outputVariable('outputData', DataType.float, [vecSize], 4);
const input = inputVariable('inputData', inputDataType, [vecSize], 4);
const output = outputVariable('outputData', outputDataType, [vecSize], 4);

return `
${shaderHelper.declareVariables(input, output)}
Expand All @@ -42,23 +42,23 @@ const createElementwiseProgramShader =
};

const createElementwiseProgramInfo =
(metadata: ProgramMetadata, input: TensorView, funcCall: ElementwiseFunctionCall,
(metadata: ProgramMetadata, input: TensorView, outputDataType: number, funcCall: ElementwiseFunctionCall,
additionalImplementation?: string): ProgramInfo => ({
...metadata,
getShaderSource: shaderHelper =>
createElementwiseProgramShader(shaderHelper, ShapeUtil.size(input.dims), funcCall, additionalImplementation),
outputs: [{dims: input.dims, dataType: input.dataType, gpuDataType: GpuDataType.default}],
getShaderSource: shaderHelper => createElementwiseProgramShader(
shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation),
outputs: [{dims: input.dims, dataType: outputDataType, gpuDataType: GpuDataType.default}],
dispatchGroup: (inputTensors) =>
({x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)})
});

const createElementwiseProgramInfoLoader =
(input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string,
cacheKey?: string): ProgramInfoLoader => {
cacheKey?: string, outputDataType: number = input.dataType): ProgramInfoLoader => {
const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: cacheKey};
return {
...metadata,
get: () => createElementwiseProgramInfo(metadata, input, funcCall, additionalImplementation)
get: () => createElementwiseProgramInfo(metadata, input, outputDataType, funcCall, additionalImplementation)
};
};

Expand Down Expand Up @@ -89,6 +89,37 @@ export const atanh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atanh', 'atanh'));
};

export interface CastAttributes extends AttributeWithCacheKey {
readonly to: number;
readonly saturate?: boolean;
}

export const parseCastAttributes = (attributes: Record<string, unknown>): CastAttributes =>
createAttributeWithCacheKey(attributes as {to: number});


export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
case DataType.float:
func = 'vec4<f32>';
break;
case DataType.uint32:
func = 'vec4<u32>';
break;
case DataType.int32:
func = 'vec4<i32>';
break;
case DataType.bool:
func = 'vec4<bool>';
break;
default:
throw new RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${attributes.to}`);
}
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to));
};

export interface ClipAttributes extends AttributeWithCacheKey {
readonly min: number;
readonly max: number;
Expand Down
248 changes: 248 additions & 0 deletions web/test/data/ops/cast.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
[
{
"name": "Cast float32 to int32",
"operator": "Cast",
"attributes": [{ "name": "to", "type": "int", "data": 6 }],
"cases": [
{
"name": "2x3",
"inputs": [
{
"data": [0.0, 0.5, 100.0, -234.0, -7.99, 1000000000],
"dims": [2, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 0, 100, -234, -7, 1000000000],
"dims": [2, 3],
"type": "int32"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [1],
"dims": [],
"type": "int32"
}
]
}
]
},
{
"name": "Cast int32 to float32",
"operator": "Cast",
"attributes": [{ "name": "to", "type": "int", "data": 1 }],
"cases": [
{
"name": "2x3",
"inputs": [
{
"data": [0, 0, 100, -234, -7, 1000000000],
"dims": [2, 3],
"type": "int32"
}
],
"outputs": [
{
"data": [0, 0, 100, -234, -7, 1000000000],
"dims": [2, 3],
"type": "float32"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "int32"
}
],
"outputs": [
{
"data": [1],
"dims": [],
"type": "float32"
}
]
}
]
},
{
"name": "Cast int32 to uint32",
"operator": "Cast",
"attributes": [{ "name": "to", "type": "int", "data": 12 }],
"cases": [
{
"name": "2x3",
"inputs": [
{
"data": [0, -1, 100, -234, -7, 1000000000],
"dims": [2, 3],
"type": "int32"
}
],
"outputs": [
{
"data": [0, 4294967295, 100, 4294967062, 4294967289, 1000000000],
"dims": [2, 3],
"type": "uint32"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "int32"
}
],
"outputs": [
{
"data": [1],
"dims": [],
"type": "uint32"
}
]
}
]
},
{
"name": "Cast uint32 to int32",
"operator": "Cast",
"attributes": [{ "name": "to", "type": "int", "data": 6 }],
"cases": [
{
"name": "2x3",
"inputs": [
{
"data": [0, 4294967295, 100, 1000000000],
"dims": [2, 2],
"type": "uint32"
}
],
"outputs": [
{
"data": [0, -1, 100, 1000000000],
"dims": [2, 2],
"type": "int32"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "uint32"
}
],
"outputs": [
{
"data": [1],
"dims": [],
"type": "int32"
}
]
}
]
},
{
"name": "Cast int32 to bool",
"operator": "Cast",
"attributes": [{ "name": "to", "type": "int", "data": 9 }],
"cases": [
{
"name": "2x3",
"inputs": [
{
"data": [0, 1, 1, 0],
"dims": [2, 2],
"type": "int32"
}
],
"outputs": [
{
"data": [false, true, true, false],
"dims": [2, 2],
"type": "bool"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "int32"
}
],
"outputs": [
{
"data": [true],
"dims": [],
"type": "bool"
}
]
}
]
},
{
"name": "Cast bool to int32",
"operator": "Cast",
"attributes": [{ "name": "to", "type": "int", "data": 6 }],
"cases": [
{
"name": "2x3",
"inputs": [
{
"data": [false, true, true, false],
"dims": [2, 2],
"type": "bool"
}
],
"outputs": [
{
"data": [0, 1, 1, 0],
"dims": [2, 2],
"type": "int32"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [true],
"dims": [],
"type": "bool"
}
],
"outputs": [
{
"data": [1],
"dims": [],
"type": "int32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,7 @@
"asin.jsonc",
"ceil.jsonc",
//"concat.jsonc",
"cast.jsonc",
"conv.jsonc",
"cos.jsonc",
"div.jsonc",
Expand Down

0 comments on commit b10674c

Please sign in to comment.