Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Support uniforms for instance-norm #18929

Merged
merged 5 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {instanceNorm} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
Expand Down Expand Up @@ -82,7 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['Less', [binaryOps.less]],
Expand Down
175 changes: 99 additions & 76 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

export interface InstanceNormAttributes extends AttributeWithCacheKey {
export interface InstanceNormAttributes {
epsilon: number;
format: 'NHWC'|'NCHW';
}
Expand All @@ -21,41 +20,48 @@ const metadata = {
const createInstanceNormProgramInfo =
(inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
const xShape = inputs[0].dims;

const outputShape = xShape;
const axis = 2;
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
const components = getMaxComponents(normSize);
const normPackedSize = normSize / components;
const C = xShape[1];
const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components);
const variables = [x, scale, bias, output];
const dataType = x.type.value;
const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
const workgroupSize = 64;
const getShaderSource = (shaderHelper: ShaderHelper) => `

const C: u32 = ${C};
const normSize: u32 = ${normSize};
const epsilon: f32 = ${attributes.epsilon};
const inputShape = [xShape[0], xShape[1], normPackedSize];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank', 'rank'];
guschmue marked this conversation as resolved.
Show resolved Hide resolved
const programUniforms: ProgramUniform[] =
[{type: 'uint32', data: C}, {type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}];
programUniforms.push(
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputs[1].dims),
...createTensorShapeVariables(inputs[2].dims), ...createTensorShapeVariables(inputShape));
guschmue marked this conversation as resolved.
Show resolved Hide resolved

const getShaderSource = (shaderHelper: ShaderHelper) => {
const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims.length);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length);
const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
const variables = [x, scale, bias, output];
const dataType = x.type.value;
const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
const workgroupSize = 64;

const uniforms: UniformsArrayType =
[{name: 'C', type: 'u32'}, {name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}];
return `
var<workgroup> meanShared : f32;
var<workgroup> squaredNormShared : f32;
var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
const workgroupSize = ${workgroupSize}u;
${shaderHelper.declareVariables(...variables)}
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart(workgroupSize)}
let norm = global_idx / workgroupSize;
let batch = norm / C;
let channel = norm % C;
let batch = norm / uniforms.C;
guschmue marked this conversation as resolved.
Show resolved Hide resolved
let channel = norm % uniforms.C;
let localIndex = local_id.x;

// initialize workgroup memory
var initial = ${f32Type}(0);
for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
}
workgroupShared[localIndex] = initial;
Expand All @@ -69,13 +75,13 @@ const createInstanceNormProgramInfo =
workgroupBarrier();
}
if (localIndex == 0) {
meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize);
meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
}
workgroupBarrier();

// reinitialize workgroup memory.
initial = ${f32Type}(0);
for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
initial = initial + deviation * deviation;
}
Expand All @@ -94,23 +100,26 @@ const createInstanceNormProgramInfo =
}
workgroupBarrier();

let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon);
let invStdDev = 1 / sqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
f32Type}(channelShift));
f32Type}(channelShift));
${output.set('batch', 'channel', 'h', 'value')};
}
}`;
};
return {
...metadata,
guschmue marked this conversation as resolved.
Show resolved Hide resolved
shaderCache: {hint: attributes.cacheKey},
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
shaderCache: {hint: `${attributes.epsilon}`, inputDependencies},
guschmue marked this conversation as resolved.
Show resolved Hide resolved
getRunData: () => ({
outputs: [
{dims: outputShape, dataType: inputs[0].dataType},
],
dispatchGroup: {x: normCount}
dispatchGroup: {x: normCount},
programUniforms
}),
getShaderSource,
};
Expand All @@ -120,10 +129,6 @@ const computeMean =
(context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
epsilon: number) => {
const components = getMaxComponents(c);
const inputHelper = inputVariable('input', input.dataType, input.dims, components);
const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);

const WG = 64;
// we will store channel scale and channel shift in [2, components] matrix
// or in vec2 when components == 1
Expand All @@ -133,90 +138,106 @@ const computeMean =
const unitsOfWork = n * c / components;
const wgSize = Math.ceil(h / WG);

const getMeanShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
const imageSize: u32 = ${h * c / components};
const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
guschmue marked this conversation as resolved.
Show resolved Hide resolved
const meanProgramUniforms: ProgramUniform[] = [
{type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)},
{type: 'uint32', data: Math.floor(h * c / components)}
];

const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
const inputHelper = inputVariable('input', input.dataType, input.dims, components);
return `
${shaderHelper.declareVariables(inputHelper)}
guschmue marked this conversation as resolved.
Show resolved Hide resolved
@group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
@group(0) @binding(2) var<uniform> uniforms: Uniforms;

${shaderHelper.mainStart(WG)}
let currentImageNumber = global_idx / ${WG} / C;
let currentChannelNumber = (global_idx / ${WG}) % C;
let currentImageNumber = global_idx / ${WG} / uniforms.C;
let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
let wgId = global_idx % ${WG};
let wgOffset = wgId * ${wgSize};
if (wgOffset >= H) {
let wgOffset = wgId * uniforms.wg_size;
if (wgOffset >= uniforms.H) {
return;
}
let wgMax = min(wgOffset + ${wgSize}, H);
let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);

let offset = currentImageNumber * imageSize + currentChannelNumber;
let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = wgOffset; i < wgMax; i++) {
let value = ${sumCastType}(input[offset + i * C]);
let value = ${sumCastType}(input[offset + i * uniforms.C]);
sum += value;
squaredSum += value * value;
}
output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
}`;
};

const meanValues = context.compute(
{
name: 'InstanceNormComputeMean',
shaderCache: {hint: JSON.stringify({components, n, h, c})},
shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies},
getRunData: () => ({
outputs: [
{dims: [n, c, WG, 2], dataType: DataType.float},
],
dispatchGroup: {x: n * c / components},
programUniforms: meanProgramUniforms
}),
getShaderSource: getMeanShaderSource,
},
{inputs: [input], outputs: [-1]})[0];
const getShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
const imageSize: u32 = ${WG * c / components};
const epsilon: f32 = ${epsilon};

const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h},
{type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)}
];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank', 'rank'];
guschmue marked this conversation as resolved.
Show resolved Hide resolved
const getShaderSource = (shaderHelper: ShaderHelper) => {
const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
return `
@group(0) @binding(0) var<storage, read> input : array<${outputType}>;
@group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
@group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
@group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)}
let currentImageNumber = global_idx / C;
let currentChannelNumber = global_idx % C;
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
let currentImageNumber = global_idx / uniforms.C;
let currentChannelNumber = global_idx % uniforms.C;

let offset = currentImageNumber * imageSize;
let offset = currentImageNumber * uniforms.image_size;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < ${WG}; i++) {
let value = input[offset + i + currentChannelNumber * ${WG}];
sum += value[0];
squaredSum += value[1];
}
sum = sum / f32(H);
squaredSum = squaredSum / f32(H);
let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon);
sum = sum / f32(uniforms.H);
squaredSum = squaredSum / f32(uniforms.H);
let invStdDev = 1 / sqrt(squaredSum - sum * sum + f32(${epsilon}));
let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;

output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
}`;

};
return context.compute(
{
name: 'InstanceNormComputeChannelScaleShift',
shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})},
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
shaderCache: {hint: `${components};${epsilon}`, inputDependencies},
getRunData: () => ({
outputs: [
{dims: [n, c, 2], dataType: DataType.float},
],
dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
},
Expand All @@ -230,50 +251,52 @@ const createInstanceNormNHWCProgramInfo =
const N = xShape[0];
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;

const components = getMaxComponents(C);
const outputSize = ShapeUtil.size(outputShape) / components;
const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);
const programUniforms: ProgramUniform[] =
[{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
guschmue marked this conversation as resolved.
Show resolved Hide resolved

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;
// first compute mean
const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
const getShaderSource = (shaderHelper: ShaderHelper) => {
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;

const getShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${H};
const C: u32 = ${C / components};
const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

return `
@group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
@group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
@group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
struct Uniforms {H: u32, C : u32};
@group(0) @binding(3) var<uniform> uniforms: Uniforms;

${shaderHelper.mainStart()}
let currentImageNumber = global_idx / (C * H);
let currentChannelNumber = global_idx % C;
let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
let currentChannelNumber = global_idx % uniforms.C;

let scaleOffset = currentImageNumber * C + currentChannelNumber;
let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
let scale = scaleInput[scaleOffset];
output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
}`;
};
context.compute(
{
name: 'InstanceNormalization',
guschmue marked this conversation as resolved.
Show resolved Hide resolved
shaderCache: {hint: `${attributes.cacheKey}`},
shaderCache: {hint: `${components}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
},
{inputs: [inputs[0], channelScaleShift]});
};

export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes =>
createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format});

export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
if (attributes.format === 'NHWC') {
createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
Expand Down
Loading