From 1df991198a0bed8284a5a58b74fcecea576f10bd Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Mar 2024 10:33:14 -0700 Subject: [PATCH] [JS/WebGPU] Optimize MatMulNBits (#19852) ### Description Use vec<2> or vec<4>, operands in MatMulNBits ### Motivation and Context Improve performance --- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 208 ++++++++++++------ js/web/test/data/ops/matmulnbits.jsonc | 57 +++++ 2 files changed, 194 insertions(+), 71 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index ead7635cf3ac4..9bf5e4066139d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; // TODO support quantization bits not equal to 4 export interface MatMulNBitsAttributes extends AttributeWithCacheKey { @@ -51,29 +51,38 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt export const createMatMulNBitsProgramInfo = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { - const a = inputs[0]; - const b = inputs[1]; - const scales = inputs[2]; - const aRank = a.dims.length; - const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); - const outputSize = ShapeUtil.size(outputShape); - - + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n); + const m = inputShape[aRank - 2]; + const blobSize = attributes.blockSize / 8 * attributes.bits; + const blobSizeInWords = blobSize / 4; + const outputNumber = getMaxComponents(m); + const components = getMaxComponents(attributes.n); + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} ]; - programUniforms.push(...createTensorShapeVariables(a.dims)); - programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); - programUniforms.push(...createTensorShapeVariables(scales.dims)); + const aShape = inputShape.slice(); + aShape.splice(-1, 1, attributes.k / aComponents); + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(aShape)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); if (inputs.length === 4) { programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const oShape = outputShape.slice(); + oShape.splice(-1, 1, attributes.n / components); + programUniforms.push(...createTensorShapeVariables(oShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { - const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); - const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); const inputVariables = [a, b, scales]; const zeroPoints = @@ -81,86 +90,143 @@ export const createMatMulNBitsProgramInfo = if (zeroPoints) { inputVariables.push(zeroPoints); } - const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} ]; const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); - const blobSize = attributes.blockSize / 8 * attributes.bits; - const wordPerBlob = blobSize / 4; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - return ` - fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ - var result = array<${dataType}, 8>(); + + const qDqDataType = (() => { + switch (aComponents) { + case 1: + return `array<${dataType}, 8>`; + case 2: + return `mat4x2<${dataType}>`; + case 4: + return `mat2x4<${dataType}>`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })(); + + const dequantizeImpl = ` + fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} { + ${(() => { + if (aComponents === 1) { + return `var dequantized = ${qDqDataType}(${ + Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')}); + return dequantized;`; + } else { + return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')}); + return (quantized - zero_points) * scale;`; + } + })()} + }`; + const ortUnpack8x4snormImpl = ` + fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} { + var quantized: ${qDqDataType}; var offset: u32 = 0; let count: u32 = 4; for (var i: u32 = 0; i < 8u; i++) { - result[i] = ${dataType}(extractBits(value, offset, count)); + var result = ${dataType}(extractBits(value, offset, count)); + ${(() => { + switch (aComponents) { + case 1: + return 'quantized[i] = result;'; + case 2: + return 'quantized[i / 2][i % 2] = result;'; + case 4: + return 'quantized[i / 4][i % 4] = result;'; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })()} offset += count; } - return result; - } + return quantized; + }`; + + const updateZeroPointIndex = zeroPoints ? ` + zero_point_offset += 4; + if (zero_point_offset == 32) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + }` : + ''; + + return ` + ${dequantizeImpl}; + ${ortUnpack8x4snormImpl}; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var value: ${dataType} = 0.0; - let output_indices = ${output.offsetToIndices('global_idx')}; - var a_indices: ${a.type.indices} = output_indices; + var output_values: array<${output.type.value}, ${outputNumber}>; + var output_indices = ${output.offsetToIndices('global_idx')}; var n = ${output.indicesGet('output_indices', aRank - 1)}; + var m = ${output.indicesGet('output_indices', aRank - 2)}; + var a_indices: ${a.type.indices} = output_indices; // Two zero points are packed into one byte because uniforms.bits <= 4. // zero_point_offset is either 0 or 4. It is bit offset within one byte. // TODO support zero_point_offset for bits > 4 ${ zeroPoints ? ` - var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; - var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; - var zero_point_offset: u32 = 0;` : + var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : ''} - var scale_idex = n * ${nBlocksPerCol}; + var scale_index = n * ${nBlocksPerCol * components}; var b_indices: ${b.type.indices}; - ${b.indicesSet('b_indices', '0', 'n')}; - var block_offset: u32 = 0; - for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { - // The scale and zero points are computed per block. - let scale = ${scales.getByOffset('scale_idex')}; - // The default zero point is 8 for unsigned 4-bit quantization. - let zero_point: ${dataType} = ${ - zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; - ${b.indicesSet('b_indices', '1', 'block')}; - var word_offset: u32 = block_offset; - for (var word: u32 = 0; word < ${wordPerBlob}; word++) { - ${b.indicesSet('b_indices', '2', 'word')}; - let b_value = ${b.getByIndices('b_indices')}; - let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); - // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 - var offset: u32 = word_offset; - for (var i: u32 = 0; i < 8; i++) { - ${a.indicesSet('a_indices', aRank - 1, 'offset')}; - let a_value = ${a.getByIndices('a_indices')}; - let b_quantized_value = b_quantized_values[i]; - let b_dequantized_value = (b_quantized_value - zero_point) * scale; - value += a_value * b_dequantized_value; - offset++; + for (var c: u32 = 0; c < ${components}; c++) { + ${b.indicesSet('b_indices', '0', `n * ${components} + c`)}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_index')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_data = ${b.getByIndices('b_indices')}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; + let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value); + let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var j: u32 = 0; j < 8/${aComponents}; j++) { + ${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)}; + for (var k: u32 = 0; k < ${outputNumber}u; k++) { + ${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)}; + let a_data = ${a.getByIndices('a_indices')}; + output_values[k]${components > 1 ? '[c]' : ''} += ${ + aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'}; + } + offset += ${aComponents}; + } + word_offset += 8; + } } - word_offset += 8; + scale_index++; + ${updateZeroPointIndex} + block_offset += uniforms.block_size; } - scale_idex++; + // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. ${ - zeroPoints ? ` - if (zero_point_offset == 28) { - zero_point_offset = 0; - zero_point_index++; - zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; - } else { - zero_point_offset += 4; - }` : + zeroPoints ? `if (zero_point_offset % 8 > 0) { + ${updateZeroPointIndex} + }` : ''} - block_offset += uniforms.block_size; - } - ${output.setByOffset('global_idx', 'value')}; - } - `; + } + for (var k: u32 = 0u; k < ${outputNumber}u; k++) { + ${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)}; + ${output.setByIndices('output_indices', 'output_values[k]')} + } + }`; }; return { name: 'MatMulNBits', @@ -168,7 +234,7 @@ export const createMatMulNBitsProgramInfo = {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64)}, + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms }), getShaderSource diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc index c57c431afb3ce..175be78cc0818 100644 --- a/js/web/test/data/ops/matmulnbits.jsonc +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -1,4 +1,61 @@ [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, { "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", "operator": "MatMulNBits",