Commit ca8ef7a: set tile M threshold
qjia7 committed Dec 16, 2024
1 parent d6277ea commit ca8ef7a
1 changed file with 8 additions and 6 deletions: onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
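In brief, the change replaces the magic numbers in the tile-size selection with a named constant, raises the tiled-path condition from M > 1 to M > kMinMForTileOptimization, and adds per-path cache hints. A minimal standalone sketch of the resulting selection logic follows; TileChoice and ChooseTile are hypothetical names invented here for illustration, not part of the source:

#include <cstdint>

// From the diff: tiling over M only pays off past this threshold.
constexpr unsigned int kMinMForTileOptimization = 4;

// TileChoice/ChooseTile are hypothetical, for illustration only.
struct TileChoice {
  uint32_t tile_m;              // rows of output computed per tile
  bool use_tiled_block32_path;  // tiled shader also needs block_size == 32
};

TileChoice ChooseTile(uint32_t M, uint32_t block_size) {
  const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
  return {tile_m, M > kMinMForTileOptimization && block_size == 32};
}

One consequence: for 1 < M <= 4 the block32 tiled shader is no longer selected (the old condition was M > 1), so whenever the tiled path runs, tile_m is 4.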
@@ -38,6 +38,8 @@ std::string QuantizedDataType(int components) {
return "array<output_element_t, 8>";
}
}

+constexpr unsigned int kMinMForTileOptimization = 4;
} // namespace

ONNX_OPERATOR_KERNEL_EX(
@@ -406,17 +408,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
uint32_t components = GetMaxComponents(N);

// Use block32 for Intel Gen12LP architecture.
const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} &&
context.AdapterInfo().architecture == std::string_view{"gen-12lp"};
const bool has_zero_points = zero_points != nullptr;

// TODO: Support output_number > 1. Some cases fail when output_number > 1.

Check warning on line 415 (GitHub Actions / Optional Lint C++): [cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
-const uint32_t tile_m = M > 4 ? 4 : 1;
+const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, is_intel};
-if (M > 1 && block_size == 32) {
+if (M > kMinMForTileOptimization && block_size == 32) {
components = 1;
constexpr uint32_t workgroup_size = 64;
constexpr uint32_t workgroup_y = 8;
@@ -425,6 +425,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y,
(M + tile_m - 1) / tile_m,
batch_count);
+program.CacheHint("T_M" + std::to_string(tile_m));
} else if (is_intel && block_size == 32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
@@ -433,8 +434,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);
+program.CacheHint("T_M" + std::to_string(tile_m));
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
+program.CacheHint("O_N" + std::to_string(output_number));
}

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
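The CacheHint calls added above replace the single .CacheHint(std::to_string(output_number)) removed in the final hunk below. The hint contributes to the compiled-program cache key, so variants that specialize the shader differently must not share a key. A hedged sketch of the resulting scheme; MakeCacheHint is invented here for illustration and only mirrors the hint strings visible in the diff:

#include <cstdint>
#include <string>

// MakeCacheHint is hypothetical; it mirrors the hint strings in the
// diff. The hint feeds the program cache key, so a shader specialized
// for tile_m == 4 is not mistakenly reused for a tile_m == 1 dispatch.
std::string MakeCacheHint(bool tiled_path, uint32_t tile_m, uint32_t output_number) {
  return tiled_path ? "T_M" + std::to_string(tile_m)
                    : "O_N" + std::to_string(output_number);
}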
@@ -446,8 +449,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32, which includes 4 uint8 values, so we multiply by 4 here.*/)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
-.AddUniformVariable({block_size})
-.CacheHint(std::to_string(output_number));
+.AddUniformVariable({block_size});
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
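As an aside on the dispatch sizing above: (N + workgroup_y - 1) / workgroup_y and (M + tile_m - 1) / tile_m are integer ceiling divisions, so a partial tile at the edge still receives a workgroup. A small self-contained illustration; CeilDiv is a hypothetical helper, the diff inlines the same expression:

#include <cstdint>

// Integer ceiling division, as used for the dispatch sizes in the diff.
constexpr uint32_t CeilDiv(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

static_assert(CeilDiv(10, 8) == 2, "a partial tile still gets a workgroup");
static_assert(CeilDiv(16, 8) == 2, "exact multiples add no extra group");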
