[js/webgpu] Support uniforms for conv, conv transpose, conv grouped #18753

Merged · 30 commits merged into main from conv_convtrans_uniform · Jan 25, 2024
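For context on the diff below: this PR switches the WebGPU conv, conv transpose, and grouped conv kernels from baking attribute values (pads, strides, dilations, clip bounds, filter dims) into the generated WGSL as constants to passing them as program uniforms. The following is a minimal sketch of that pattern, assuming simplified `ProgramUniform`/`UniformsArrayType` shapes and made-up attribute values; it is not the ONNX Runtime implementation.

```typescript
// Minimal sketch, not the actual ONNX Runtime code.
// ProgramUniform / UniformsArrayType shapes are assumed from the diff.
type ProgramUniform = { type: 'int32' | 'float32'; data: number | number[] };
type UniformsArrayType = Array<{ name: string; type: 'i32' | 'f32'; length?: number }>;

// Example attribute values (hypothetical) that used to be interpolated into the WGSL as consts.
const pads = [1, 1];
const strides = [2, 2];
const dilations = [1, 1];

// CPU side: upload the values as uniforms instead.
const programUniforms: ProgramUniform[] = [
  { type: 'int32', data: pads },
  { type: 'int32', data: strides },
  { type: 'int32', data: dilations },
];

// Shader side: declare matching uniform members; the WGSL body then reads
// uniforms.pad[0], uniforms.stride[1], ... so changing attribute values no
// longer requires generating and compiling a new shader variant.
const uniforms: UniformsArrayType = [
  { name: 'pad', type: 'i32', length: 2 },
  { name: 'stride', type: 'i32', length: 2 },
  { name: 'dilation', type: 'i32', length: 2 },
];

console.log(programUniforms.length === uniforms.length); // true: one data entry per declared uniform
```

The payoff shows up in the shaderCache changes further down: the cache hint no longer has to encode every attribute value, so fewer shader variants need to be compiled.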
Commits (30). The file diff below shows changes from 5 commits.
17989b8
[js/webgpu] Support uniforms for conv, conv transpose, conv grouped
axinging Dec 4, 2023
ae9f9f6
Fix comments and remove clipMax clipMin todo
axinging Dec 14, 2023
bb8ddba
Merge branch 'main' into conv_convtrans_uniform
axinging Dec 14, 2023
dc0d998
Format
axinging Dec 14, 2023
8bb80cc
Nit
axinging Dec 14, 2023
bb47410
Fix MatmulNaive
axinging Dec 15, 2023
988bf4d
Merge branch 'main' into conv_convtrans_uniform
axinging Dec 15, 2023
a848409
Impl updateUniformsFromActivation
axinging Dec 15, 2023
603bbd7
Fix comments
axinging Dec 18, 2023
22d5fa9
Merge branch 'main' into conv_convtrans_uniform
axinging Dec 18, 2023
6ef48e5
Throw error when type not support
axinging Dec 18, 2023
d470239
Nit
axinging Dec 18, 2023
e0f1bec
Nit
axinging Dec 19, 2023
c20faea
Merge branch 'main' into conv_convtrans_uniform
axinging Dec 19, 2023
6d7cb0c
Remove activationFunction
axinging Dec 19, 2023
07784e1
REemove updateUniformsFromActivation
axinging Dec 19, 2023
bf89981
Nit
axinging Dec 19, 2023
e3bafdd
Merge branch 'main' into conv_convtrans_uniform
axinging Dec 21, 2023
6a0f4e9
Fix comments
axinging Dec 22, 2023
afd5d7b
Remove wpt uniforms in conv back
axinging Dec 22, 2023
87eaf32
Add tile info in conv2d mm
axinging Dec 22, 2023
e40da4e
Reapply cacheKey
axinging Dec 22, 2023
7f0133e
Refactor getShaderSource and use snake case
axinging Dec 25, 2023
c22da4c
Merge branch 'main' into conv_convtrans_uniform
axinging Dec 25, 2023
1afdc9b
Nit
axinging Dec 25, 2023
621db06
Merge branch 'main' into conv_convtrans_uniform
axinging Jan 16, 2024
cd936f5
Format layer-norm
axinging Jan 17, 2024
6ddeb03
Merge branch 'main' into conv_convtrans_uniform
axinging Jan 17, 2024
d3be951
Workaround windows linux cr lf format issue when colum reaches 120
axinging Jan 18, 2024
d7f3603
Revert layer-norm
axinging Jan 18, 2024
57 changes: 37 additions & 20 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -19,10 +19,11 @@
//
// modified to fit the needs of the project

import {tensorDataTypeEnumToString} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformDataElementType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
import {getActivationSnippet} from '../fuse-utils';

@@ -88,10 +89,10 @@ const conv2dCommonSnippet =
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;

let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
let xCh = ${col} % inChannels;
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
// The bounds checking is always needed since we use it to pad zero for
@@ -195,15 +196,33 @@ export const createConv2DMatMulProgramInfo =

// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const x =
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];

programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.dilations}
];
const uniforms: UniformsArrayType = [
{name: 'dimAOuter', type: 'i32'}, {name: 'dimBOuter', type: 'i32'}, {name: 'dimInner', type: 'i32'},
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
{name: 'dilation', type: 'i32', length: 2}
];
const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type'];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: tensorDataType, data: attributes.clipMax!}, {type: tensorDataType, data: attributes.clipMin!});
uniforms.push(
{name: 'clipMax', type: x.type.value as UniformDataElementType},
{name: 'clipMin', type: x.type.value as UniformDataElementType});
}

programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];

let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
@@ -218,6 +237,7 @@ export const createConv2DMatMulProgramInfo =
inputVariables.push(bias);

programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');

declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
@@ -226,9 +246,14 @@
}
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
programUniforms.push(...createTensorShapeVariables(outputShape));

return {
name: 'Conv2DMatMul',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {
hint:
`${attributes.format};${attributes.activation};${innerElementSize};${fitAOuter};${fitBOuter};${fitInner}`,
inputDependencies
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
@@ -239,15 +264,7 @@
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)}
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${declareFunctions}
${
conv2dCommonSnippet(
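The first file's changes also move the fused Clip activation bounds into uniforms and shrink the shader-cache hint to just the fields that still change the generated WGSL (tensor shapes are handled by the `'rank'` input dependencies). Below is a rough sketch of that selection logic, with simplified types and a hypothetical `pushClipUniforms` helper that is not part of the ONNX Runtime code:

```typescript
// Sketch under assumptions: types simplified, helper name invented for illustration.
type ProgramUniform = { type: string; data: number | number[] };
interface ConvAttributesLike {
  activation: string;
  clipMin?: number;
  clipMax?: number;
  format: 'NHWC' | 'NCHW';
}

// Only a Clip activation contributes extra uniforms; other activations (or none)
// leave the uniform list untouched.
function pushClipUniforms(
    programUniforms: ProgramUniform[], attributes: ConvAttributesLike, tensorDataType: string): void {
  if (attributes.activation === 'Clip') {
    programUniforms.push(
        {type: tensorDataType, data: attributes.clipMax!},
        {type: tensorDataType, data: attributes.clipMin!});
  }
}

// The cache hint now only encodes what changes the WGSL text itself.
function cacheHint(attributes: ConvAttributesLike, innerElementSize: number): string {
  return `${attributes.format};${attributes.activation};${innerElementSize}`;
}

// Example: a Clip fusion with bounds [0, 6] (ReLU6-style clamping).
const uniforms: ProgramUniform[] = [];
pushClipUniforms(uniforms, {activation: 'Clip', clipMin: 0, clipMax: 6, format: 'NHWC'}, 'float32');
console.log(uniforms);  // [{type: 'float32', data: 6}, {type: 'float32', data: 0}]
```

Because clipMin/clipMax now arrive through the uniform buffer, two fused Clip nodes with different bounds but the same layout can share one compiled pipeline.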
Second changed file: the conv transpose matmul program (conv2dTransposeCommonSnippet / createConv2DTransposeMatMulProgramInfo).
@@ -19,10 +19,11 @@
//
// modified to fit the needs of the project

import {tensorDataTypeEnumToString} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {getActivationSnippet} from '../fuse-utils';

@@ -74,21 +75,21 @@ const conv2dTransposeCommonSnippet =
col % outWidth);
`;

const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
const row = isChannelsLast ? 'row' : 'col';
const col = isChannelsLast ? 'col' : 'row';

const readASnippet = `
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;

let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
let WRow = ${col} / (uniforms.filterDims[1] * inChannels);
let WCol = ${col} / inChannels % uniforms.filterDims[1];
let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]);
let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]);
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
return ${type}(0.0);
}
@@ -116,9 +117,9 @@ const conv2dTransposeCommonSnippet =

const sampleW = `
let col = colIn * ${innerElementSize};
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'};
let coordX = uniforms.filterDims[0] - 1 - row / (uniforms.filterDims[1] * inChannels);
let coordY = uniforms.filterDims[1] - 1 - (row / inChannels) % uniforms.filterDims[1];
if (${
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
@@ -186,20 +187,50 @@ export const createConv2DTransposeMatMulProgramInfo =
const innerElementSize = isVec4 ? 4 : 1;
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const filterDims =
[attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
const effectiveFilterDims = [
filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)),
filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1))
];
const pads = [
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2)
];

const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
const inputVariables = [x, w];
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
];
const uniforms: UniformsArrayType = [
{name: 'dimAOuter', type: 'i32'}, {name: 'dimBOuter', type: 'i32'}, {name: 'dimInner', type: 'i32'},
{name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2},
{name: 'filterDims', type: 'i32', length: filterDims.length}, {name: 'pads', type: 'i32', length: pads.length}
];
const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type'];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: tensorDataType, data: attributes.clipMax!}, {type: tensorDataType, data: attributes.clipMin!});
uniforms.push(
{name: 'clipMax', type: x.type.value as UniformDataElementType},
{name: 'clipMin', type: x.type.value as UniformDataElementType});
}
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));

const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
let declareFunctions = '';
if (hasBias) {
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');

declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
@@ -211,40 +242,15 @@ export const createConv2DTransposeMatMulProgramInfo =

return {
name: 'Conv2DTransposeMatMul',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: `${attributes.format};${attributes.activation};${elementsPerThread}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms
}),
getShaderSource: (shaderHelper: ShaderHelper) => `
${utilFunctions('uniforms.result_strides')}
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)};
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
${
attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
${
attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
attributes.pads[0] + attributes.pads[2]})/2,
i32(effectiveFilterDims[1]) - 1 - (${
attributes.pads[1] + attributes.pads[3]})/2);
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
${
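As a worked example of the CPU-side pads/effectiveFilterDims computation added in the conv transpose hunk above (input values are made up; the formulas mirror the diff):

```typescript
// Worked example of the computation the diff moves from WGSL constants to the CPU side.
// Input values are hypothetical.
const kernel = [3, 3];          // filterDims (H, W) taken from the kernel shape
const dilations = [2, 2];
const padsAttr = [1, 1, 1, 1];  // ONNX pads convention: [begin_h, begin_w, end_h, end_w]

// Dilation spreads the kernel taps apart: a 3x3 kernel with dilation 2 covers a 5x5 window.
const effectiveFilterDims = [
  kernel[0] + (dilations[0] <= 1 ? 0 : (kernel[0] - 1) * (dilations[0] - 1)),
  kernel[1] + (dilations[1] <= 1 ? 0 : (kernel[1] - 1) * (dilations[1] - 1)),
];  // -> [5, 5]

// The transposed conv uses this "backprop"-style padding; formula copied from the diff.
const pads = [
  effectiveFilterDims[0] - 1 - Math.floor((padsAttr[0] + padsAttr[2]) / 2),
  effectiveFilterDims[1] - 1 - Math.floor((padsAttr[1] + padsAttr[3]) / 2),
];  // -> [3, 3]

console.log(effectiveFilterDims, pads);
```

These two arrays are then uploaded as the `filterDims` and `pads` uniforms declared in the hunk above, instead of being interpolated into the shader text.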