Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] FP16 LayerNorm, InstanceNorm, SkipLayerNorm #17630

Merged
merged 8 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,69 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
export const createTensorShapeVariables = (dims: readonly number[]):
ProgramUniform[] => [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}];

/**
* A helper function to get maximum vector size for specified data length
* @param size
*/
export const getMaxComponents = (size: number) => {
// we cannot use vec3 type since it has alignment of 16 bytes
if (size % 4 === 0) {
return 4;
} else if (size % 2 === 0) {
return 2;
}

return 1;
};

/**
* A helper function that initializes variable as a scalar or vector. e.g. f32(0) or vec4f(0,0,0,0)
* @param dataType
* @param components
* @param value
*/
export const fillVector = (dataType = 'f32', components?: number, value = '0') => {
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
if (!components || components === 1) {
return `${dataType}(${value})`;
}

return `vec${components}<${dataType}>(${value})`;
};

/**
* A helper function that casts value or vector to f32
* @param dataType
* @param components
* @param value
*/
export const castToF32 = (dataType: string, components: number, value: string) => {
if (dataType === 'f32') {
return value;
}
if (components === 1) {
return `f32(${value})`;
}

return `vec${components}f(${value})`;
};

/**
* A helper function that returns scalar or sums all components of a vector
* @param name
* @param components
*/
export const sumVector = (name: string, components: number) => {
if (components === 4) {
return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
} else if (components === 2) {
return `(${name}.x + ${name}.y)`;
} else if (components === 3) {
return `(${name}.x + ${name}.y + ${name}.z)`;
}

return name;
};

/**
* A helper function to get a IndicesHelper for a given input or output.
*
Expand Down
187 changes: 139 additions & 48 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
Expand Down Expand Up @@ -111,77 +112,167 @@ const createInstanceNormProgramInfo =
};
};

const computeMean =
(context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
epsilon: number) => {
const components = getMaxComponents(c);
const inputHelper = inputVariable('input', input.dataType, input.dims, components);
const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);

const WG = 64;
// we will store channel scale and channel shift in [2, components] matrix
// or in vec2 when components == 1
const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
const unitsOfWork = n * c / components;
const wgSize = Math.ceil(h / WG);

const getMeanShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
const imageSize: u32 = ${h * c / components};

${shaderHelper.declareVariables(inputHelper)}
@group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;

${shaderHelper.mainStart(WG)}
let currentImageNumber = global_idx / ${WG} / C;
let currentChannelNumber = (global_idx / ${WG}) % C;
let wgId = global_idx % ${WG};
let wgOffset = wgId * ${wgSize};
if (wgOffset >= H) {
return;
}
let wgMax = min(wgOffset + ${wgSize}, H);

let offset = currentImageNumber * imageSize + currentChannelNumber;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = wgOffset; i < wgMax; i++) {
let value = ${sumCastType}(input[offset + i * C]);
sum += value;
squaredSum += value * value;
}
output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
}`;

const meanValues = context.compute(
{
name: 'InstanceNormComputeMean',
shaderCache: {hint: JSON.stringify({components, n, h, c})},
getRunData: () => ({
outputs: [
{dims: [n, c, WG, 2], dataType: DataType.float},
],
dispatchGroup: {x: n * c / components},
}),
getShaderSource: getMeanShaderSource,
},
{inputs: [input], outputs: [-1]})[0];
const getShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
const imageSize: u32 = ${WG * c / components};
const epsilon: f32 = ${epsilon};

@group(0) @binding(0) var<storage, read> input : array<${outputType}>;
@group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
@group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
@group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;

${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)}
let currentImageNumber = global_idx / C;
let currentChannelNumber = global_idx % C;

let offset = currentImageNumber * imageSize;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < ${WG}; i++) {
let value = input[offset + i + currentChannelNumber * ${WG}];
sum += value[0];
squaredSum += value[1];
}
sum = sum / f32(H);
squaredSum = squaredSum / f32(H);
let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon);
let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;

output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
}`;

return context.compute(
{
name: 'InstanceNormComputeChannelScaleShift',
shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})},
getRunData: () => ({
outputs: [
{dims: [n, c, 2], dataType: DataType.float},
],
dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
}),
getShaderSource,
},
{inputs: [meanValues, scale, bias], outputs: [-1]})[0];
};

const createInstanceNormNHWCProgramInfo =
(inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
(context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => {
const xShape = inputs[0].dims;
const outputShape = xShape;
const outputSize = ShapeUtil.size(outputShape);
const N = xShape[0];
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;

const components = getMaxComponents(C);
const outputSize = ShapeUtil.size(outputShape) / components;
const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;
// first compute mean
const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);

const normCount = C * N;
const getShaderSource = (shaderHelper: ShaderHelper) => `
const N: u32 = ${N};
const H: u32 = ${H};
const C: u32 = ${C};
const normSizeTyped: ${dataType} = ${H};
const imageSize: u32 = ${H * C};
const epsilon: f32 = ${attributes.epsilon};
const C: u32 = ${C / components};

@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
@group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
@group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
@group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
@group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
@group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;

${shaderHelper.mainStart()}
let currentImageNumber = global_idx / C;
let currentImageNumber = global_idx / (C * H);
let currentChannelNumber = global_idx % C;

// offset is channel num * N
let offset = currentImageNumber * imageSize;
if (offset >= ${outputSize}) { return; }
var mean: ${dataType} = 0;

for (var i: u32 = 0u; i < H; i++) {
mean = mean + x[offset + i * C + currentChannelNumber];
}
mean = mean / normSizeTyped;

var squaredNorm: ${dataType} = 0;
for (var i: u32 = 0u; i < H; i++) {
let deviation: f32 = x[offset + i * C + currentChannelNumber] - mean;
squaredNorm = squaredNorm + deviation * deviation;
}
let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon);
let channelScale = invStdDev * scale[currentChannelNumber];
let channelShift = bias[currentChannelNumber] - mean * channelScale;
for (var i: u32 = 0u; i < H; i++) {
let currentOffset = offset + i * C + currentChannelNumber;
output[currentOffset] = x[currentOffset] * channelScale + channelShift;
}
let scaleOffset = currentImageNumber * C + currentChannelNumber;
let scale = scaleInput[scaleOffset];
output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
}`;
return {
...metadata,
shaderCache: {hint: attributes.cacheKey},
getRunData: () => ({
outputs: [
{dims: outputShape, dataType: inputs[0].dataType},
],
dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}
}),
getShaderSource,
};
context.compute(
{
name: 'InstanceNormalization',
shaderCache: {hint: `${attributes.cacheKey}`},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
}),
getShaderSource,
},
{inputs: [inputs[0], channelScaleShift]});
};

export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes =>
createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format});

export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
if (attributes.format === 'NHWC') {
context.compute(createInstanceNormNHWCProgramInfo(context.inputs, attributes));
createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
} else {
context.compute(createInstanceNormProgramInfo(context.inputs, attributes));
}
Expand Down
Loading