Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Use DataType as uniform cpu type #19281

Merged
merged 9 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {tensorDataTypeEnumToString} from '../wasm-common';
import {DataType, tensorDataTypeEnumToString} from '../wasm-common';

import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
Expand Down Expand Up @@ -428,10 +428,10 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
const sizeOfElement = v.type === 'float16' ? 2 : 4;
const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
let sizeOfVecOrMat;
let baseAlignment;
if (v.type === 'float16') {
if (v.type === DataType.float16) {
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
} else {
Expand All @@ -445,7 +445,7 @@ export class WebGpuBackend {
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
// length is N * SizeOf(mat2x4<f16>).
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
data.length * sizeOfElement;
});
Expand All @@ -458,15 +458,17 @@ export class WebGpuBackend {
programUniforms.forEach((v, i) => {
const offset = offsets[i];
const data = typeof v.data === 'number' ? [v.data] : v.data;
if (v.type === 'int32') {
if (v.type === DataType.int32) {
new Int32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'uint32') {
} else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'float16') {
} else if (v.type === DataType.float16) {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else {
} else if (v.type === DataType.float) {
new Float32Array(arrayBuffer, offset, data.length).set(data);
} else {
throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`);
}
});

Expand Down
7 changes: 4 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
Expand Down Expand Up @@ -189,9 +190,9 @@ export const createConv2DMatMulProgramInfo =
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.dilations}
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]},
{type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
Expand Down Expand Up @@ -197,9 +198,10 @@ export const createConv2DTransposeMatMulProgramInfo =
];

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides},
{type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims},
{type: DataType.int32, data: pads}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
Expand Down Expand Up @@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo =
const outputChannelsPerGroup = wShape[1];

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims},
{type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads},
{type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup},
{type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides},
{type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
{type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
{type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
];
if (hasBias) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

import {DataType} from '../../../../wasm-common';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
Expand Down Expand Up @@ -447,8 +448,10 @@ export const createMatmulProgramInfo =
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
const bRank = bShapeTemp.length;
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const programUniforms: ProgramUniform[] = [
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
{type: DataType.int32, data: dimInner}
];
appendActivationUniformsData(activationAttributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
Expand Down
30 changes: 15 additions & 15 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {tensorDataTypeEnumToString} from '../../../wasm-common';
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ComputeContext, GpuDataType, ProgramUniform} from '../types';

Expand Down Expand Up @@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
WG = Math.ceil(dComp / 8);
}
const elementsPerWG = Math.ceil(d / components / WG);
const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type'];
const programUniforms: ProgramUniform[] =
[{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}];
const programUniforms: ProgramUniform[] = [
{type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp},
{type: DataType.uint32, data: elementsPerWG}
];
const dataType = tensorTypeToWsglStorageType(input.dataType, components);

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down Expand Up @@ -336,11 +337,10 @@ const computeAttentionProbs =
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
z: parameters.batchSize * parameters.numHeads
};
const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type'];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize},
{type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength},
{type: tensorDataType, data: alpha}
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
{type: DataType.uint32, data: parameters.totalSequenceLength},
{type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha}
];

const inputs = [q, key];
Expand Down Expand Up @@ -430,9 +430,9 @@ const computeVxAttentionScore =
z: params.batchSize * params.numHeads
};
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength},
{type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads},
{type: 'uint32', data: params.vHiddenSize}
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
{type: DataType.uint32, data: params.vHiddenSize}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down Expand Up @@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
};
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N},
{type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize},
{type: 'uint32', data: parameters.hiddenSize},
{type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
{type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N},
{type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize},
{type: DataType.uint32, data: parameters.hiddenSize},
{type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import {env} from 'onnxruntime-common';

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand Down Expand Up @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo =
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: useShapesUniforms ?
[
{type: 'uint32', data: outputSize},
{type: DataType.uint32, data: outputSize},
...createTensorShapeVariables(yShape),
] :
[
{type: 'uint32', data: outputSize},
{type: DataType.uint32, data: outputSize},
],
}),
};
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ const createBinaryOpProgramInfo =
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
programUniforms: [
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
...createTensorShapeVariables(a.dims),
...createTensorShapeVariables(b.dims),
...createTensorShapeVariables(outputShape),
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
return typeof mappedType === 'string' ? mappedType : mappedType[1];
};

export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] =>
dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}];
export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ?
[] :
[{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}];

/**
* A helper function to get maximum vector size for specified data length
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand Down Expand Up @@ -95,14 +96,14 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
Expand Down
13 changes: 8 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
Expand Down Expand Up @@ -28,9 +29,10 @@ export const createGroupedConvProgramInfo =
const outputSize = ShapeUtil.size(outputShape);

const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations},
{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations},
{type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]},
{type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]},
{type: DataType.uint32, data: outputChannelsPerGroup}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
Expand Down Expand Up @@ -127,8 +129,9 @@ export const createGroupedConvVectorizeProgramInfo =
const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];

const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]},
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}
{type: DataType.uint32, data: outputSize},
{type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]},
{type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const createCumsumProgramInfo =
outputs: [{dims: inputShape, dataType: inputType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: 'uint32', data: outputSize}, {type: 'int32', data: axis},
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis},
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
]

Expand Down
7 changes: 5 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand Down Expand Up @@ -272,8 +273,10 @@ const createEinsumProgramInfo =
// filter is added to make sure that dimValue is never 0.
const programUniformsInit: ProgramUniform[] =
uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
programUniformsInit.push({type: 'uint32', data: outputSize});
.map(
(symbol) =>
({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
programUniformsInit.push({type: DataType.uint32, data: outputSize});
const programUniforms: ProgramUniform[] =
inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/expand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
};

const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
...createTensorShapeVariables(outputShape)
];
return {
Expand Down
7 changes: 5 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {MAX_CLIP, MIN_CLIP} from '../../util';
import {ProgramUniform} from '../types';

Expand Down Expand Up @@ -36,9 +37,11 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v
export const appendActivationUniformsData =
(attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
if (attributes.activation === 'Clip') {
programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
programUniform.push(
{type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
} else if (attributes.activation === 'HardSigmoid') {
programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!});
programUniform.push(
{type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!});
}
};

Expand Down
Loading
Loading