From d30cf803063ac1b95f5b6fc2c035c81bad2d9785 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 13 Dec 2024 15:19:27 +0800 Subject: [PATCH 1/6] [webgpu] Optimize matmulnbits with M > 1 --- .../webgpu/quantization/matmul_nbits.cc | 128 +++++++++++++++++- .../webgpu/quantization/matmul_nbits.h | 14 ++ 2 files changed, 139 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index be18f820e2747..c234f48223690 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -437,6 +437,103 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t return Status::OK(); } +Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + const uint32_t tile_m = 4; + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. + const uint32_t a_length_per_tile = tile_size / a.NumComponents(); + constexpr uint32_t block_size = 32; + const uint32_t blocks_per_tile = tile_size / block_size; + shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" + " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" + << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" + << " } else {\n" + " return input_a_value_t(0);\n" + " }\n" + "}\n" + << "var sub_a: array," << tile_m << ">;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m << ">;\n"; + shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" + << " let row = workgroup_id.y * " << tile_m << ";\n" + << " let batch = workgroup_id.z;\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" + // Loop over shared dimension. + << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" + << " let a_col_start = tile * " << a_length_per_tile << ";\n" + << " // load one tile A data into shared memory.\n" + << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" + << " let a_col = a_col_start + a_offset;\n"; + for (uint32_t i = 0; i < tile_m; i++) { + shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; + } + shader.MainFunctionBody() << " }\n" + " workgroupBarrier();\n" + // Each thread processes one block. 
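
Note: the constants above determine how the workgroup carves up the K dimension; with the sizes this patch later dispatches, blocks_per_tile equals WorkgroupSizeX, which is why each x-thread can own exactly one quantization block per tile, as the per-block code below relies on. A quick CPU-side check of that arithmetic (illustrative values, not part of the change):

    // Tile arithmetic for the default 8x8 workgroup with vec4 inputs.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t workgroup_x = 8, workgroup_y = 8;   // WorkgroupSizeX/Y
      const uint32_t components_a = 4, components_b = 4; // vec4 packing
      const uint32_t block_size = 32;                    // quantization block size
      const uint32_t workgroup_size = workgroup_x * workgroup_y;    // 64 threads
      const uint32_t tile_size = workgroup_x * components_b * 8;    // 256 weights along K
      const uint32_t a_length_per_tile = tile_size / components_a;  // 64 vec4 loads of A
      const uint32_t blocks_per_tile = tile_size / block_size;      // 8 == workgroup_x
      printf("%u %u %u %u\n", workgroup_size, tile_size, a_length_per_tile, blocks_per_tile);
      return 0;
    }
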
+ " let b_row = col + local_id.y;\n" + << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" + " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; + } else { + // The default zero point is 8 for unsigned 4-bit quantization. + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + } + shader.MainFunctionBody() << " var scale = output_element_t(0);\n" + " var b_data = input_b_value_t(0);\n" + << " if (block < n_blocks_per_col) {\n" + << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" + << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" + << " }\n" + << " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n" + << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + shader.MainFunctionBody() << " let b_value = b_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" + " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" + " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + " let b_dequantized_values = (b_quantized_values - mat2x4("; + for (int i = 0; i < 8; i++) { + shader.MainFunctionBody() << "zero_point"; + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale;\n"; + for (uint32_t i = 0; i < tile_m; i++) { + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; + } + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " workgroupBarrier();\n" + " }\n" + << " if (local_id.y < " << tile_m << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_id.y][local_id.x][b];\n" + " }\n" + " if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n" + << " }\n" + " }\n"; + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -477,9 +574,34 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& 
context block_size == 32; const bool has_zero_points = zero_points != nullptr; - if (use_block32 && batch_count == 1 && - components_a == 4 && components_b == 4 && - !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) { + if (M > 1 && components_a == 4 && block_size == 32) { + MatMulNBitsWithLargeMProgram program{gsl::narrow(components_b), has_zero_points}; + components = 1; + const uint32_t tile_m = 4; + constexpr uint32_t workgroup_size = 64; + const uint32_t workgroup_y = 8; + const uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, + (M + tile_m - 1) / tile_m, + batch_count); + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); + } else if (use_block32 && batch_count == 1 && + components_a == 4 && components_b == 4 && + !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) { MatMulNBitsProgramPrefill program; constexpr int32_t tile_size = 16; // subgroup_size here controls how many elements of the hidden dimension we load in a cycle. 
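
Note: with the 8x8 workgroup and tile_m = 4 chosen in the new branch, each workgroup covers an output tile of 4 rows by 8 columns, and the dispatch arithmetic rounds both dimensions up. A short sketch with hypothetical shapes (M = 17, N = 4096, batch_count = 1; all illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t M = 17, N = 4096, batch_count = 1;  // hypothetical shapes
      const uint32_t tile_m = 4, workgroup_y = 8;
      // One workgroup per 8 output columns and per 4 output rows, rounded up.
      const uint32_t groups_x = (N + workgroup_y - 1) / workgroup_y;  // 512
      const uint32_t groups_y = (M + tile_m - 1) / tile_m;            // 5
      printf("dispatch = (%u, %u, %u)\n", groups_x, groups_y, batch_count);
      return 0;
    }
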
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 5f785c03f6a5e..9a58473ee0a3b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -31,6 +31,20 @@ class MatMulNBitsProgram final : public Program { bool use_block32_; }; +class MatMulNBitsWithLargeMProgram final : public Program { + public: + MatMulNBitsWithLargeMProgram(int components_b, bool has_zero_points) : Program{"MatMulNBitsWithLargeM"}, + components_b_{components_b}, + has_zero_points_{has_zero_points} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + private: + int components_b_; + bool has_zero_points_; +}; + class MatMulNBitsProgramPrefill final : public Program { public: MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} { From a349ad4fb6cd39c7112336dcebba19becb0ca4fb Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 16 Dec 2024 10:00:33 +0800 Subject: [PATCH 2/6] Remove MatMulNBitsProgramPrefill --- .../webgpu/quantization/matmul_nbits.cc | 139 ------------------ .../webgpu/quantization/matmul_nbits.h | 14 -- 2 files changed, 153 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index c234f48223690..94c85311d4095 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -322,121 +322,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - // This shader uses uniforms with the M,N,K convention from traditional matrix multiplicatiion - // M is the number of rows in A and M rows in the output. - // N is the number of columns in B and N columns in the output. - // K is the hidden/shared dimension number of columns in A and K rows in B. - // Note in matmulnbits, B matrix is already transposed, however the following remains true - // for the shader below M describes A, N describes B and K is the hidden/shared dimension. - // K4/K8 are simply K divided by 4 or 8 respectively. - shader.AdditionalImplementation() << R"INIT_SECTION( -// Matrix dimensions and quantization parameters -const TILE_SIZE : u32 = 16u; -const VALUES_PER_VEC4 : u32 = 4u; -const QUANTIZATION_BLOCK_SIZE : u32 = 32; -// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM, -// so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time. -// This uses all 16 lanes on 12th gen intel chips. 
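
Note: the derived constants the comment above describes (and which are declared just below) work out as follows; a trivial CPU check, not part of the patch:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t VALUES_PER_VEC4 = 4;
      const uint32_t QUANTIZATION_BLOCK_SIZE = 32;
      const uint32_t BLOCKS_PER_CYCLE = 2;
      // Two 32-weight blocks loaded as vec4s: 16 vec4 items per cycle,
      // one per lane on a 16-wide EU/SM.
      assert((QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4) * BLOCKS_PER_CYCLE == 16);
      // One 32-value quantization block spans 8 vec4s.
      assert(QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4 == 8);
      return 0;
    }
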
-const BLOCKS_PER_CYCLE : u32 = 2u; -const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE -const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4; - -//Shared memory -var tile_A : array, TILE_SIZE>; -var tile_B : array, TILE_SIZE>; -var tile_O : array, TILE_SIZE>; - -fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32) -{ - if (a_global >= uniforms.M) { - return; - } - let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id]; - tile_A[slot][parallel_id] = local_A; -} - -fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t -{ - // Since scales are output_value_t holding 1 for every 32 values, vec_step_idx jumps over 64 weights at - // a time or 2 scales at every step. - let scale_offset = vec_step_idx*2; - let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset); - return scales[idx+scale_idx]; -} - -fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32) -{ - if (b_global >= uniforms.N) { - return; - } - let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE)); - let idx:u32 = parallel_id; - if (idx % 2 == 0) - { - // Weights are u32 holding 8 values each, each step (vec_step_idx) jumps over 64 weights at a time. - // Therefore the weight_offset begin for the current step would be vec_step_idx * 64 if weight - // elements were holding one element each. For the case of each element holding 8 values, begin - // would become vec_step_idx * 64/8 or vec_step_idx * 8. - var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2); - let b_value = input_b[b_global*uniforms.K8+weight_offset]; - let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); - let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu); - tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale; - tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale; - tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale; - tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale; - tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0)* scale; - tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0)* scale; - tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0)* scale; - tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0)* scale; - } -} - -fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t -{ - var sum:output_value_t = 0; - for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++) - { - sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]); - } - return sum; -} -)INIT_SECTION"; - - shader.MainFunctionBody() << R"MAIN_FN( - // Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE - // appears to give a performance win on Intel Gen12LP architecture. - // This is likley because of locality of memory access, idy below in this approach - // is the same as subgroup_id or lane id, while idx is the wave_id. - // The work distribution therefore keeps memory accesses close together in - // a single wave in this approach of indexing. 
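
Note: loadB above is where the 4-bit weights are expanded: unpack4xU8 splits each byte into its low and high nibble, and the fixed zero point of 8 recenters them; the surviving shaders in this series keep the same layout. A rough CPU equivalent of that expansion (hypothetical helper name, illustration only):

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Mirrors the shader's unpack4xU8 pattern: low nibbles of the four bytes
    // come from (word & 0x0F0F0F0F), high nibbles from ((word >> 4) & 0x0F0F0F0F),
    // and the values interleave as lower[0], upper[0], lower[1], upper[1], ...
    std::array<float, 8> dequantize_word(uint32_t word, float scale, float zero_point) {
      std::array<float, 8> out{};
      for (int byte = 0; byte < 4; ++byte) {
        const uint32_t b = (word >> (8 * byte)) & 0xFFu;
        out[2 * byte + 0] = (float(b & 0x0Fu) - zero_point) * scale;  // low nibble
        out[2 * byte + 1] = (float(b >> 4) - zero_point) * scale;     // high nibble
      }
      return out;
    }

    int main() {
      // A word of 0x8 nibbles dequantizes to exactly zero with the default
      // zero point of 8 used for unsigned 4-bit quantization.
      const auto v = dequantize_word(0x88888888u, 0.5f, 8.0f);
      printf("%f\n", v[0]);  // 0.000000
      return 0;
    }
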
- let idx = u32(local_idx / TILE_SIZE); - let idy = u32(local_idx % TILE_SIZE); - let a_global_base = workgroup_id.x * TILE_SIZE; - let b_global_base = workgroup_id.y * TILE_SIZE; - let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE)); - for (var vec_step:u32 = 0; vec_step < step_count; vec_step++) - { - workgroupBarrier(); - loadA(idx, a_global_base+idx, vec_step, idy); - loadB(idx, b_global_base+idx, vec_step, idy); - workgroupBarrier(); - let result = computeDotProduct(idx, idy); - tile_O[idx][idy]+=result; - } - workgroupBarrier(); - if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) { - output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy]; - } -)MAIN_FN"; - return Status::OK(); -} - Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); @@ -599,30 +484,6 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } return context.RunProgram(program); - } else if (use_block32 && batch_count == 1 && - components_a == 4 && components_b == 4 && - !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) { - MatMulNBitsProgramPrefill program; - constexpr int32_t tile_size = 16; - // subgroup_size here controls how many elements of the hidden dimension we load in a cycle. - // MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup - // size just helps with optimal lane usage in the shader. - constexpr int32_t subgroup_size = 16; - program.SetWorkgroupSize(tile_size * subgroup_size); - program.SetDispatchGroupSize((M + tile_size - 1) / tile_size, - (N + tile_size - 1) / tile_size, - 1); - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddUniformVariables({{static_cast(M)}, - {static_cast(N)}, - {static_cast(K)}, - {static_cast(K / 4)}, - {static_cast(K / 8)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}); - return context.RunProgram(program); } else { // TODO: Support output_number > 1. Some cases are failed when output_number > 1. // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 
2 : 1; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 9a58473ee0a3b..7392403f1e893 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -45,20 +45,6 @@ class MatMulNBitsWithLargeMProgram final : public Program { - public: - MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"M", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K4", ProgramUniformVariableDataType::Uint32}, - {"K8", ProgramUniformVariableDataType::Uint32}); -}; - class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { From a7a7d9b7e083c8b7b3ecd59d29b7f84d5d375dc9 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 16 Dec 2024 10:46:36 +0800 Subject: [PATCH 3/6] remove components_a limitation --- .../webgpu/quantization/matmul_nbits.cc | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 94c85311d4095..a28fc8063eece 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -113,22 +113,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" << " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n" << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " let a_data0 = vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);\n" - " let a_data1 = vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);\n"; - break; - case 2: - shader.MainFunctionBody() << " let a_data0 = vec4(sub_a[word_offset], sub_a[word_offset + 1]);\n" - " let a_data1 = vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]);\n"; - break; - case 4: - shader.MainFunctionBody() << " let a_data0 = sub_a[word_offset];\n" - " let a_data1 = sub_a[word_offset + 1];\n"; - break; - default: - break; - } shader.MainFunctionBody() << " let b_value = b_data"; if (components_b_ > 1) { shader.MainFunctionBody() << "[i]"; @@ -144,9 +128,21 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << ", "; } } - shader.MainFunctionBody() << ")) * scale;\n" - " inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n" - << " word_offset += " << 8 / a.NumComponents() << ";\n" + shader.MainFunctionBody() << ")) * scale;\n"; + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n"; + break; + case 2: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1]), 
b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n"; + break; + case 4: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" << " }\n" " workgroupBarrier();\n" " }\n" @@ -329,6 +325,8 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); const uint32_t tile_m = 4; + ORT_ENFORCE(tile_m < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); + ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. const uint32_t a_length_per_tile = tile_size / a.NumComponents(); @@ -401,7 +399,19 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co } shader.MainFunctionBody() << ")) * scale;\n"; for (uint32_t i = 0; i < tile_m; i++) { - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; + break; + case 2: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n"; + break; + case 4: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; + break; + default: + break; + } } shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" << " }\n" @@ -459,7 +469,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context block_size == 32; const bool has_zero_points = zero_points != nullptr; - if (M > 1 && components_a == 4 && block_size == 32) { + if (M > 1 && block_size == 32) { MatMulNBitsWithLargeMProgram program{gsl::narrow(components_b), has_zero_points}; components = 1; const uint32_t tile_m = 4; From be81377e4cec9991ef1b7ba100f5f3523017ea73 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 16 Dec 2024 11:11:27 +0800 Subject: [PATCH 4/6] make tile_m as class member --- .../webgpu/quantization/matmul_nbits.cc | 23 +++++++++---------- .../webgpu/quantization/matmul_nbits.h | 8 ++++--- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git 
a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index a28fc8063eece..30f8a9c8468ac 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -324,8 +324,7 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - const uint32_t tile_m = 4; - ORT_ENFORCE(tile_m < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); + ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. @@ -339,10 +338,10 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co " return input_a_value_t(0);\n" " }\n" "}\n" - << "var sub_a: array," << tile_m << ">;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m << ">;\n"; + << "var sub_a: array," << tile_m_ << ">;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" - << " let row = workgroup_id.y * " << tile_m << ";\n" + << " let row = workgroup_id.y * " << tile_m_ << ";\n" << " let batch = workgroup_id.z;\n" " let n_blocks_per_col = uniforms.input_b_shape[1];\n" << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" @@ -352,7 +351,7 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co << " // load one tile A data into shared memory.\n" << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" << " let a_col = a_col_start + a_offset;\n"; - for (uint32_t i = 0; i < tile_m; i++) { + for (uint32_t i = 0; i < tile_m_; i++) { shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; } shader.MainFunctionBody() << " }\n" @@ -398,7 +397,7 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co } } shader.MainFunctionBody() << ")) * scale;\n"; - for (uint32_t i = 0; i < tile_m; i++) { + for (uint32_t i = 0; i < tile_m_; i++) { switch (a.NumComponents()) { case 1: shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; @@ -417,7 +416,7 @@ Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) co << " }\n" " workgroupBarrier();\n" " }\n" - << " if (local_id.y < " << tile_m << ") {\n" + << " if (local_id.y < " << tile_m_ << ") {\n" << " var output_value = output_value_t(0);\n" << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" << " output_value += 
inter_results[local_id.y][local_id.x][b];\n" @@ -470,12 +469,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const bool has_zero_points = zero_points != nullptr; if (M > 1 && block_size == 32) { - MatMulNBitsWithLargeMProgram program{gsl::narrow(components_b), has_zero_points}; + constexpr uint32_t tile_m = 4; + MatMulNBitsWithLargeMProgram program{tile_m, gsl::narrow(components_b), has_zero_points}; components = 1; - const uint32_t tile_m = 4; constexpr uint32_t workgroup_size = 64; - const uint32_t workgroup_y = 8; - const uint32_t workgroup_x = workgroup_size / workgroup_y; + constexpr uint32_t workgroup_y = 8; + constexpr uint32_t workgroup_x = workgroup_size / workgroup_y; program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, (M + tile_m - 1) / tile_m, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 7392403f1e893..cb2e1895c6b17 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -33,14 +33,16 @@ class MatMulNBitsProgram final : public Program { class MatMulNBitsWithLargeMProgram final : public Program { public: - MatMulNBitsWithLargeMProgram(int components_b, bool has_zero_points) : Program{"MatMulNBitsWithLargeM"}, - components_b_{components_b}, - has_zero_points_{has_zero_points} { + MatMulNBitsWithLargeMProgram(uint32_t tile_m, int components_b, bool has_zero_points) : Program{"MatMulNBitsWithLargeM"}, + tile_m_{tile_m}, + components_b_{components_b}, + has_zero_points_{has_zero_points} { } Status GenerateShaderCode(ShaderHelper& sh) const override; private: + uint32_t tile_m_; int components_b_; bool has_zero_points_; }; From d6277ea1a35109b816ca715d83ad45a432fd42e2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 16 Dec 2024 12:57:27 +0800 Subject: [PATCH 5/6] merge MatMulNBitsWithLargeMProgram to MatMulNBitsProgram --- .../webgpu/quantization/matmul_nbits.cc | 329 +++++++----------- .../webgpu/quantization/matmul_nbits.h | 30 +- 2 files changed, 136 insertions(+), 223 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 30f8a9c8468ac..6bddde8964be6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -38,8 +38,6 @@ std::string QuantizedDataType(int components) { return "array"; } } - -constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16; } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -60,33 +58,59 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - if (use_block32_) { + if ((is_intel_ || tile_m_ > 1) && block_size_ == 32) { const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. 
const uint32_t a_length_per_tile = tile_size / a.NumComponents(); - constexpr uint32_t block_size = 32; - const uint32_t blocks_per_tile = tile_size / block_size; - shader.AdditionalImplementation() << "var sub_a: array;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; - std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); - shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" - << " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n" - " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + const uint32_t blocks_per_tile = tile_size / block_size_; + if (tile_m_ == 1) { + shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" + " if (col < uniforms.input_a_shape[2]) {\n" + << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" + << " } else {\n" + " return input_a_value_t(0);\n" + " }\n" + "}\n" + << "var sub_a: array;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; + std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n"; + } else { + ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); + ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); + + shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" + " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" + << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" + << " } else {\n" + " return input_a_value_t(0);\n" + " }\n" + "}\n" + << "var sub_a: array," << tile_m_ << ">;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; + shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" + << " let row = workgroup_id.y * " << tile_m_ << ";\n" + << " let batch = workgroup_id.z;\n"; + } + shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n" << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" // Loop over shared dimension. << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" << " let a_col_start = tile * " << a_length_per_tile << ";\n" << " // load one tile A data into shared memory.\n" << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" - << " let a_col = a_col_start + a_offset;\n" - " if (a_col < uniforms.input_a_shape[2]) {\n" - << " sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n" - << " } else {\n" - " sub_a[a_offset] = input_a_value_t(0);\n" - " }\n" - " }\n" + << " let a_col = a_col_start + a_offset;\n"; + if (tile_m_ == 1) { + shader.MainFunctionBody() << " sub_a[a_offset] = mm_readA(batch, row, a_col);\n"; + } else { + for (uint32_t i = 0; i < tile_m_; i++) { + shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; + } + } + shader.MainFunctionBody() << " }\n" " workgroupBarrier();\n" // Each thread processes one block. 
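
Note: the load loop above distributes the A tile across the whole workgroup before the barrier; with the default sizes every thread issues exactly one mm_readA per tile (and per tile_m row) before each thread picks up its block below. A CPU sketch of that mapping (sizes illustrative):

    #include <cstdint>
    #include <vector>

    int main() {
      const uint32_t workgroup_size = 64;     // 8x8 workgroup
      const uint32_t a_length_per_tile = 64;  // vec4 elements of A per tile
      std::vector<uint32_t> loads(workgroup_size, 0);
      for (uint32_t local_idx = 0; local_idx < workgroup_size; ++local_idx) {
        // Same strided loop as the shader's "load one tile A data" section.
        for (uint32_t a_offset = local_idx; a_offset < a_length_per_tile;
             a_offset += workgroup_size) {
          ++loads[local_idx];  // one mm_readA per (thread, iteration)
        }
      }
      for (uint32_t n : loads) {
        if (n != 1) return 1;  // every thread loaded exactly one element
      }
      return 0;
    }
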
" let b_row = col + local_id.y;\n" @@ -111,7 +135,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" << " }\n" - << " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n" + << " var word_offset = local_id.x * " << block_size_ / a.NumComponents() << ";\n" << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; shader.MainFunctionBody() << " let b_value = b_data"; if (components_b_ > 1) { @@ -129,32 +153,62 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } shader.MainFunctionBody() << ")) * scale;\n"; - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n"; - break; - case 2: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n"; - break; - case 4: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n"; - break; - default: - break; + if (tile_m_ == 1) { + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n"; + break; + case 2: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n"; + break; + case 4: + shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n"; + break; + default: + break; + } + } else { + for (uint32_t i = 0; i < tile_m_; i++) { + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; + break; + case 2: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n"; + break; + case 4: + shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << 
"][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; + break; + default: + break; + } + } } shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" << " }\n" " workgroupBarrier();\n" - " }\n" - << " if (local_idx < " << WorkgroupSizeY() << ") {\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[local_idx][b];\n" - " }\n" - " if (col + local_idx < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" - << " }\n" " }\n"; + if (tile_m_ == 1) { + shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_idx][b];\n" + " }\n" + " if (col + local_idx < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" + << " }\n" + " }\n"; + } else { + shader.MainFunctionBody() << " if (local_id.y < " << tile_m_ << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_id.y][local_id.x][b];\n" + " }\n" + " if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n" + << " }\n" + " }\n"; + } } else { const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); const int output_element_number = y.NumComponents() * gsl::narrow(output_number_); @@ -318,116 +372,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); - const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - - ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); - ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); - const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); - const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. 
- const uint32_t a_length_per_tile = tile_size / a.NumComponents(); - constexpr uint32_t block_size = 32; - const uint32_t blocks_per_tile = tile_size / block_size; - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var sub_a: array," << tile_m_ << ">;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; - shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" - << " let row = workgroup_id.y * " << tile_m_ << ";\n" - << " let batch = workgroup_id.z;\n" - " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" - // Loop over shared dimension. - << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" - << " let a_col_start = tile * " << a_length_per_tile << ";\n" - << " // load one tile A data into shared memory.\n" - << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" - << " let a_col = a_col_start + a_offset;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; - } - shader.MainFunctionBody() << " }\n" - " workgroupBarrier();\n" - // Each thread processes one block. - " let b_row = col + local_id.y;\n" - << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" - " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; - } else { - // The default zero point is 8 for unsigned 4-bit quantization. 
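
Note: the removed block above mirrors the zero-point decoding kept in MatMulNBitsProgram: zero points are 4-bit values packed two per byte, four bytes per u32 word, with even blocks in the low nibble. The same address math on the CPU (hypothetical helper, for illustration):

    #include <cstdint>
    #include <cstdio>

    // Same arithmetic as the shader: bytes_per_col rounds up because a column
    // with an odd number of blocks still occupies whole bytes.
    uint32_t zero_point_nibble(const uint32_t* words, uint32_t b_row,
                               uint32_t block, uint32_t n_blocks_per_col) {
      const uint32_t bytes_per_col = (n_blocks_per_col + 1) / 2;  // 2 nibbles per byte
      const uint32_t byte_count = b_row * bytes_per_col + (block >> 1);
      const uint32_t word_index = byte_count >> 2;   // 4 bytes per u32 word
      const uint32_t byte_offset = byte_count & 0x3u;
      const uint32_t nibble_offset = block & 0x1u;   // odd block: high nibble
      const uint32_t bits = (byte_offset << 3) + (nibble_offset << 2);
      return (words[word_index] >> bits) & 0xFu;
    }

    int main() {
      const uint32_t packed[1] = {0x00000030u};  // high nibble of byte 0 holds 3
      printf("%u\n", zero_point_nibble(packed, /*b_row=*/0, /*block=*/1,
                                       /*n_blocks_per_col=*/4));  // prints 3
      return 0;
    }
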
- shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - } - shader.MainFunctionBody() << " var scale = output_element_t(0);\n" - " var b_data = input_b_value_t(0);\n" - << " if (block < n_blocks_per_col) {\n" - << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" - << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" - << " }\n" - << " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n" - << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data"; - if (components_b_ > 1) { - shader.MainFunctionBody() << "[i]"; - } - shader.MainFunctionBody() << ";\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4("; - for (int i = 0; i < 8; i++) { - shader.MainFunctionBody() << "zero_point"; - if (i < 7) { - shader.MainFunctionBody() << ", "; - } - } - shader.MainFunctionBody() << ")) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; - break; - case 2: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n"; - break; - case 4: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; - break; - default: - break; - } - } - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" - << " }\n" - " workgroupBarrier();\n" - " }\n" - << " if (local_id.y < " << tile_m_ << ") {\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[local_id.y][local_id.x][b];\n" - " }\n" - " if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n" - << " }\n" - " }\n"; - return Status::OK(); -} - Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -463,14 +407,16 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context uint32_t components = GetMaxComponents(N); // Use block32 for Intel 
Gen12LP architecture. - const bool use_block32 = context.AdapterInfo().vendor == std::string_view{"intel"} && - context.AdapterInfo().architecture == std::string_view{"gen-12lp"} && - block_size == 32; + const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} && + context.AdapterInfo().architecture == std::string_view{"gen-12lp"}; const bool has_zero_points = zero_points != nullptr; + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; + constexpr uint32_t output_number = 1; + const uint32_t tile_m = M > 4 ? 4 : 1; + MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, is_intel}; if (M > 1 && block_size == 32) { - constexpr uint32_t tile_m = 4; - MatMulNBitsWithLargeMProgram program{tile_m, gsl::narrow(components_b), has_zero_points}; components = 1; constexpr uint32_t workgroup_size = 64; constexpr uint32_t workgroup_y = 8; @@ -479,54 +425,33 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, (M + tile_m - 1) / tile_m, batch_count); - - TensorShape reshaped_a_shape{batch_count, M, K / components_a}; - TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; - TensorShape reshaped_y_shape{batch_count, M, N / components}; - - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}); - if (has_zero_points) { - program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); - } - return context.RunProgram(program); + } else if (is_intel && block_size == 32) { + components = 1; + constexpr uint32_t workgroup_size = 128; + const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 + : 1; + const uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize(data_size / components / workgroup_y); } else { - // TODO: Support output_number > 1. Some cases are failed when output_number > 1. - // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; - constexpr uint32_t output_number = 1; - MatMulNBitsProgram program{output_number, gsl::narrow(components_b), has_zero_points, use_block32}; - - if (use_block32) { - components = 1; - constexpr uint32_t workgroup_size = 128; - const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 
4 - : 1; - const uint32_t workgroup_x = workgroup_size / workgroup_y; - program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); - program.SetDispatchGroupSize(data_size / components / workgroup_y); - } else { - program.SetDispatchGroupSize(data_size / components / output_number); - } - - TensorShape reshaped_a_shape{batch_count, M, K / components_a}; - TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; - TensorShape reshaped_y_shape{batch_count, M, N / components}; + program.SetDispatchGroupSize(data_size / components / output_number); + } - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) - .AddUniformVariable({block_size}) - .CacheHint(std::to_string(output_number)); - if (has_zero_points) { - program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); - } - return context.RunProgram(program); + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. 
So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddUniformVariable({block_size}) + .CacheHint(std::to_string(output_number)); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } + return context.RunProgram(program); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index cb2e1895c6b17..8a4626083419c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -14,11 +14,13 @@ using namespace onnxruntime::webgpu; class MatMulNBitsProgram final : public Program { public: - MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points, bool use_block32) : Program{"MatMulNBits"}, - output_number_{output_number}, - components_b_{components_b}, - has_zero_points_{has_zero_points}, - use_block32_{use_block32} { + MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool is_intel) : Program{"MatMulNBits"}, + output_number_{output_number}, + block_size_{block_size}, + tile_m_{tile_m}, + components_b_{components_b}, + has_zero_points_{has_zero_points}, + is_intel_{is_intel} { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -26,25 +28,11 @@ class MatMulNBitsProgram final : public Program { private: uint32_t output_number_; - int components_b_; - bool has_zero_points_; - bool use_block32_; -}; - -class MatMulNBitsWithLargeMProgram final : public Program { - public: - MatMulNBitsWithLargeMProgram(uint32_t tile_m, int components_b, bool has_zero_points) : Program{"MatMulNBitsWithLargeM"}, - tile_m_{tile_m}, - components_b_{components_b}, - has_zero_points_{has_zero_points} { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - private: + uint32_t block_size_; uint32_t tile_m_; int components_b_; bool has_zero_points_; + bool is_intel_; }; class MatMulNBits final : public WebGpuKernel { From ca8ef7abff225ea0f9ed4964a29b24c65080bffa Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 16 Dec 2024 14:19:00 +0800 Subject: [PATCH 6/6] set tile M threshold --- .../webgpu/quantization/matmul_nbits.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 6bddde8964be6..9a49adf347a29 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -38,6 +38,8 @@ std::string QuantizedDataType(int components) { return "array"; } } + +constexpr unsigned int kMinMForTileOptimization = 4; } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -406,17 +408,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t components_b = GetMaxComponents(blob_size_in_words); uint32_t components = GetMaxComponents(N); - // Use block32 for Intel Gen12LP architecture. const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} && context.AdapterInfo().architecture == std::string_view{"gen-12lp"}; const bool has_zero_points = zero_points != nullptr; // TODO: Support output_number > 1. 
Some cases are failed when output_number > 1. - // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; constexpr uint32_t output_number = 1; - const uint32_t tile_m = M > 4 ? 4 : 1; + const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, is_intel}; - if (M > 1 && block_size == 32) { + if (M > kMinMForTileOptimization && block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 64; constexpr uint32_t workgroup_y = 8; @@ -425,6 +425,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, (M + tile_m - 1) / tile_m, batch_count); + program.CacheHint("T_M" + std::to_string(tile_m)); } else if (is_intel && block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 128; @@ -433,8 +434,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t workgroup_x = workgroup_size / workgroup_y; program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); program.SetDispatchGroupSize(data_size / components / workgroup_y); + program.CacheHint("T_M" + std::to_string(tile_m)); } else { program.SetDispatchGroupSize(data_size / components / output_number); + program.CacheHint("O_N" + std::to_string(output_number)); } TensorShape reshaped_a_shape{batch_count, M, K / components_a}; @@ -446,8 +449,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, {scales, ProgramTensorMetadataDependency::None}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) - .AddUniformVariable({block_size}) - .CacheHint(std::to_string(output_number)); + .AddUniformVariable({block_size}); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); }
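
Note: taken together, the series leaves ComputeInternal with a three-way program selection. A compact sketch of that decision (the labels are illustrative; the real code distinguishes the paths through the T_M/O_N cache hints above):

    #include <cstdint>
    #include <cstdio>

    const char* pick_path(uint32_t M, bool is_intel_gen12lp, uint32_t block_size) {
      const uint32_t kMinMForTileOptimization = 4;
      if (M > kMinMForTileOptimization && block_size == 32) {
        return "block32, tile_m = 4";  // 8x8 workgroup, 4-row output tile
      }
      if (is_intel_gen12lp && block_size == 32) {
        return "block32, tile_m = 1";  // 128-thread workgroup, one output row
      }
      return "generic path";  // output_number = 1 fallback
    }

    int main() {
      printf("%s\n", pick_path(17, false, 32));  // block32, tile_m = 4
      printf("%s\n", pick_path(1, true, 32));    // block32, tile_m = 1
      printf("%s\n", pick_path(1, false, 16));   // generic path
      return 0;
    }
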