Skip to content

Commit

Permalink
[js/web] JSEP LayerNormalization and InstanceNormalization kernels (#…
Browse files Browse the repository at this point in the history
…16830)

### Description
Added two kernels for Layer and Instance norm

Also requested the adapter's maximum limit for `maxBufferSize` when requesting the GPU device,
since it defaults to 256MB, which fails to allocate the 600MB buffer needed
when running fp32 StableDiffusion weights.


### Motivation and Context
These two are used in StableDiffusion and many other networks
  • Loading branch information
dakenf authored and jchen351 committed Aug 12, 2023
1 parent 0efded5 commit 2b814a4
Show file tree
Hide file tree
Showing 15 changed files with 490 additions and 38 deletions.
2 changes: 2 additions & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Do not modify directly.*
| Gemm | ai.onnx(7-8,9-10,11+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
| LeakyRelu | ai.onnx(6-15,16+) | |
| MatMul | ai.onnx(1-12,13+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |
Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ export class WebGpuBackend {
maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize,
maxComputeWorkgroupsPerDimension: adapter.limits.maxComputeWorkgroupsPerDimension,
maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize,
}
maxBufferSize: adapter.limits.maxBufferSize,
maxComputeInvocationsPerWorkgroup: adapter.limits.maxComputeInvocationsPerWorkgroup,
maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX,
maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY,
maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ,
},
};
// WebGPU Spec: Timestamp Queries Inside Passes
// https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TensorViewImpl implements TensorView {
class ComputeContextImpl implements ComputeContext {
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
get kernelCustomData(): {[key: string]: unknown} {
return this.backend.currentKernelCustomData;
}
Expand All @@ -71,6 +72,7 @@ class ComputeContextImpl implements ComputeContext {
let dataIndex = (contextDataOffset >> 2);
this.opKernelContext = heapU32[dataIndex++];
const inputCount = heapU32[dataIndex++];
this.outputCount = heapU32[dataIndex++];
this.customDataOffset = heapU32[dataIndex++];
this.customDataSize = heapU32[dataIndex++];

Expand Down
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gelu} from './ops/gelu';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import * as pool from './ops/pool';
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
Expand Down Expand Up @@ -58,6 +60,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['MatMul', [matMul]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
Expand Down
172 changes: 172 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {ShaderHelper} from './common';

/** Attributes for the InstanceNormalization kernel. */
export interface InstanceNormAttributes extends AttributeWithCacheKey {
  /** value added to the variance for numerical stability */
  epsilon: number;
  /** data layout of the input tensor X */
  format: 'NHWC'|'NCHW';
}

/**
 * Validates inputs for InstanceNormalization: X, scale and B (bias) are all required.
 * Only fp32 is currently supported (fp16 is pending the "shader-f16" WGSL extension).
 *
 * @throws if the input count is not exactly 3 or any input is not fp32.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length !== 3) {
    throw new Error('instanceNorm requires 3 inputs.');
  }

  // Check all three inputs; the original skipped the bias (inputs[2]) even
  // though the generated shader reads it as the same element type.
  if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float ||
      inputs[2].dataType !== DataType.float) {
    throw new Error('inputs should be float type');
  }
};

/**
 * Builds the NCHW InstanceNormalization program: each shader invocation
 * normalizes one (n, c) group over its spatial elements, then applies the
 * per-channel scale and shift.
 */
const createInstanceNormProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
      const xShape = inputs[0].dims;
      const scale = inputs[1];
      const bias = inputs[2];

      const outputShape = xShape;
      const outputSize = ShapeUtil.size(outputShape);
      // NCHW: normalization runs over dims[2:] (the spatial dims).
      const axis = 2;
      const normCount = ShapeUtil.sizeToDimension(xShape, axis);   // N * C groups
      const normSize = ShapeUtil.sizeFromDimension(xShape, axis);  // elements per group (H * W * ...)
      const C = xShape[1];

      // Per the ONNX InstanceNormalization spec, scale and B are 1-D tensors of
      // size C. (The previous check compared them against normSize == H*W, which
      // is only coincidentally equal to C for some shapes.)
      const scaleSize = ShapeUtil.size(scale.dims);
      const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
      if (scaleSize !== C || (bias && biasSize !== C)) {
        throw new Error(`Size of scale and bias (if provided) must equal the channel count C == ${C}.
Got scale size of ${scaleSize} and bias size of ${biasSize}`);
      }

      const dataType = tensorTypeToWsglType(inputs[0].dataType);

      // NOTE: the guard is 'offset >= outputSize'. The previous form
      // 'offset + normSize >= outputSize' also rejected the last valid group,
      // for which offset + normSize == outputSize holds exactly.
      const getShaderSource = (shaderHelper: ShaderHelper) => `
  const C: u32 = ${C};
  const normSize: u32 = ${normSize};
  const normSizeTyped: ${dataType} = ${normSize};
  const epsilon: f32 = ${attributes.epsilon};
  @group(0) @binding(0) var<storage, read> x : array<${dataType}>;
  @group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
  @group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
  @group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
  ${shaderHelper.mainStart()}
    let offset = global_idx * normSize;
    if (offset >= ${outputSize}) { return; }
    var mean: ${dataType} = 0;
    for (var h: u32 = 0u; h < normSize; h++) {
      mean = mean + x[h + offset];
    }
    mean = mean / normSizeTyped;
    var squaredNorm: ${dataType} = 0;
    for (var h: u32 = 0u; h < normSize; h++) {
      let deviation: ${dataType} = x[h + offset] - mean;
      squaredNorm = squaredNorm + deviation * deviation;
    }
    let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon);
    let channelScale = invStdDev * scale[global_idx % C];
    let channelShift = bias[global_idx % C] - mean * channelScale;
    for (var j: u32 = 0; j < normSize; j++) {
      output[j + offset] = x[j + offset] * channelScale + channelShift;
    }
  }`;
      return {
        ...metadata,
        outputs: [
          {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
        ],
        getShaderSource,
        dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)})
      };
    };

const createInstanceNormNHWCProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
const xShape = inputs[0].dims;
const outputShape = xShape;
const outputSize = ShapeUtil.size(outputShape);
const N = xShape[0];
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;

const dataType = tensorTypeToWsglType(inputs[0].dataType);

const normCount = C * N;
const getShaderSource = (shaderHelper: ShaderHelper) => `
const N: u32 = ${N};
const H: u32 = ${H};
const C: u32 = ${C};
const normSizeTyped: ${dataType} = ${H};
const imageSize: u32 = ${H * C};
const epsilon: f32 = ${attributes.epsilon};
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
@group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
@group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
${shaderHelper.mainStart()}
let currentImageNumber = global_idx / C;
let currentChannelNumber = global_idx % C;
// offset is channel num * N
let offset = currentImageNumber * imageSize;
if (offset >= ${outputSize}) { return; }
var mean: ${dataType} = 0;
for (var i: u32 = 0u; i < H; i++) {
mean = mean + x[offset + i * C + currentChannelNumber];
}
mean = mean / normSizeTyped;
var squaredNorm: ${dataType} = 0;
for (var i: u32 = 0u; i < H; i++) {
let deviation: f32 = x[offset + i * C + currentChannelNumber] - mean;
squaredNorm = squaredNorm + deviation * deviation;
}
let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon);
let channelScale = invStdDev * scale[currentChannelNumber];
let channelShift = bias[currentChannelNumber] - mean * channelScale;
for (var i: u32 = 0u; i < H; i++) {
let currentOffset = offset + i * C + currentChannelNumber;
output[currentOffset] = x[currentOffset] * channelScale + channelShift;
}
}`;
return {
...metadata,
outputs: [
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)})
};
};

/** Wraps InstanceNormalization attributes with a cache key used for program caching. */
export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes => {
  const {epsilon, format} = attributes;
  return createAttributeWithCacheKey({epsilon, format});
};

/**
 * Runs InstanceNormalization, dispatching to the NHWC or NCHW program
 * according to the `format` attribute.
 */
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
  validateInputs(context.inputs);

  const metadata = {
    name: 'InstanceNormalization',
    inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default],
    cacheHint: attributes.cacheKey,
  };

  // Both builders share the same signature, so pick one and invoke it once.
  const buildProgramInfo =
      attributes.format === 'NHWC' ? createInstanceNormNHWCProgramInfo : createInstanceNormProgramInfo;
  context.compute(buildProgramInfo(metadata, context.inputs, attributes));
};
126 changes: 126 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {ShaderHelper} from './common';

/** Attributes for the LayerNormalization kernel. */
export interface LayerNormAttributes extends AttributeWithCacheKey {
  /** first dimension to normalize over; may be negative (normalized via ShapeUtil.normalizeAxis) */
  axis: number;
  /** value added to the variance for numerical stability */
  epsilon: number;
}

/**
 * Validates inputs for LayerNormalization: X and scale are required; bias is
 * optional. Only fp32 is currently supported.
 *
 * @throws if fewer than 2 inputs are given or X/scale are not fp32.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  // bias is optional, so exactly 2 inputs is valid; the previous check used
  // '<= 2', which rejected the 2-input case its own message said was allowed.
  if (!inputs || inputs.length < 2) {
    throw new Error('layerNorm requires at least 2 inputs.');
  }

  if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
    throw new Error('inputs should be float type');
  }
};

/**
 * Builds the LayerNormalization program. Each shader invocation normalizes one
 * group of `normSize` elements (dims[axis:]) using a single pass that
 * accumulates the sum and the sum of squares, then writes Y and — when
 * requested via outputCount — the optional Mean and InvStdDev outputs.
 */
const createLayerNormProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: LayerNormAttributes, outputCount: number):
        ProgramInfo => {
      const xShape = inputs[0].dims;
      const scale = inputs[1];
      // bias (input 2) is optional; undefined when only two inputs are supplied.
      const bias = inputs[2];

      const outputShape = xShape;
      const outputSize = ShapeUtil.size(outputShape);
      // Normalization runs over dims[axis:]; dims[:axis] enumerates the groups.
      const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length);
      const normCount = ShapeUtil.sizeToDimension(xShape, axis);
      const normSize = ShapeUtil.sizeFromDimension(xShape, axis);

      // scale (and bias, if present) must have exactly normSize elements.
      const scaleSize = ShapeUtil.size(scale.dims);
      const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
      if (scaleSize !== normSize || (bias && biasSize !== normSize)) {
        throw new Error(`Size of X.shape()[axis:] == ${normSize}.
Size of scale and bias (if provided) must match this.
Got scale size of ${scaleSize} and bias size of ${biasSize}`);
      }

      // Mean/InvStdDev outputs keep dims[:axis] and collapse dims[axis:] to 1s.
      const meanInvStdDevDim = [];
      for (let i = 0; i < xShape.length; ++i) {
        if (i < axis) {
          meanInvStdDevDim.push(xShape[i]);
        } else {
          meanInvStdDevDim.push(1);
        }
      }

      const dataType = tensorTypeToWsglType(inputs[0].dataType);

      // ONNX LayerNormalization outputs: Y (required), Mean and InvStdDev (optional).
      const hasMeanDataOutput = outputCount > 1;
      const hasInvStdOutput = outputCount > 2;
      // NOTE(review): binding indices are hard-coded — output is always
      // @binding(3), meanDataOutput @binding(4), invStdOutput @binding(5) —
      // even when bias is absent and @binding(2) is never declared. Confirm
      // this matches how the backend assigns bind-group entries when only two
      // inputs are bound.
      // NOTE(review): inside the shader, `meanSquare` is reused after the sqrt
      // to hold the (epsilon-stabilized) standard deviation, hence
      // invStdOutput = 1 / meanSquare.
      const getShaderSource = (shaderHelper: ShaderHelper) => `
const normSize: u32 = ${normSize};
const normSizeTyped: ${dataType} = ${normSize};
const epsilon: f32 = ${attributes.epsilon};
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
${bias ? `@group(0) @binding(2) var<storage, read> bias : array<${dataType}>;` : ''}
@group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
${hasMeanDataOutput ? `@group(0) @binding(4) var<storage, read_write> meanDataOutput : array<${dataType}>` : ''};
${hasInvStdOutput ? `@group(0) @binding(5) var<storage, read_write> invStdOutput : array<${dataType}>` : ''};
${shaderHelper.mainStart()}
let offset = global_idx * normSize;
if (offset >= ${outputSize}) { return; }
var mean: ${dataType} = 0;
var meanSquare: ${dataType} = 0;
for (var h: u32 = 0u; h < normSize; h++) {
mean = mean + x[h + offset];
meanSquare = meanSquare + x[h + offset] * x[h + offset];
}
mean = mean / normSizeTyped;
meanSquare = sqrt(meanSquare / normSizeTyped - mean * mean + epsilon);
for (var j: u32 = 0; j < normSize; j++) {
output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''};
}
${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''};
${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''};
}`;
      // Y first, then the optional Mean and InvStdDev outputs in spec order.
      const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
      if (hasMeanDataOutput) {
        outputs.push(
            {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
        );
      }
      if (hasInvStdOutput) {
        outputs.push(
            {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
        );
      }

      return {
        ...metadata,
        outputs,
        getShaderSource,
        dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)})
      };
    };

/** Wraps LayerNormalization attributes with a cache key used for program caching. */
export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes => {
  const {axis, epsilon} = attributes;
  return createAttributeWithCacheKey({axis, epsilon});
};

/**
 * Runs LayerNormalization. Inputs: X, scale and optional bias; outputs: Y plus
 * optional Mean and InvStdDev, selected via context.outputCount.
 */
export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => {
  validateInputs(context.inputs);

  const metadata = {
    name: 'LayerNormalization',
    // bias is optional, so size the input-type list from the actual input
    // count instead of hard-coding three entries (which disagreed with the
    // 2-input case handled by the shader builder).
    inputTypes: new Array<GpuDataType>(context.inputs.length).fill(GpuDataType.default),
    // Output/input counts change the generated shader (optional bias binding,
    // optional Mean/InvStdDev outputs), so they must be part of the cache key.
    cacheHint: attributes.cacheKey + context.outputCount.toString(10) + context.inputs.length.toString(10),
  };

  context.compute(createLayerNormProgramInfo(metadata, context.inputs, attributes, context.outputCount));
};
5 changes: 5 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ export interface ComputeContext {
*/
readonly customDataBuffer: Uint8Array;

/**
* the number of outputs for the node
*/
readonly outputCount: number;

compute(program: ProgramInfoLoader|ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping):
TensorView[];
output(index: number, dims: readonly number[]): number;
Expand Down
16 changes: 16 additions & 0 deletions js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,19 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
throw new Error(`unsupported logging level: ${logLevel}`);
}
};

/**
 * Maps an ONNX tensor element type to the corresponding WGSL scalar type name.
 * Throws for element types with no (enabled) WGSL counterpart.
 */
export const tensorTypeToWsglType = (type: DataType) => {
  if (type === DataType.float) {
    return 'f32';
  }
  // TODO: enable after "shader-f16" WSGL extension release
  // if (type === DataType.float16) { return 'f16'; }
  if (type === DataType.int32) {
    return 'i32';
  }
  if (type === DataType.uint32) {
    return 'u32';
  }
  throw new Error(`Unsupported type: ${type}`);
};
Loading

0 comments on commit 2b814a4

Please sign in to comment.