Skip to content

Commit

Permalink
Improve Decode Shape Performance for AMD FP8 (pytorch#2658)
Browse files Browse the repository at this point in the history
Summary:

Add tuning config for decode workloads. Improves performance substantially for shapes with small M.

Differential Revision: D58031289
  • Loading branch information
jwfromm authored and facebook-github-bot committed Jun 3, 2024
1 parent 9bb687c commit 8e07ad2
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 20 deletions.
6 changes: 3 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def main(args: Any) -> None:
benchmark_results = []

# Test over a bunch of shapes.
M = [128, 2048, 2304, 13312, 16032, 16384]
N = [128, 2304, 4096, 8192, 13312]
K = [128, 2048, 2304, 6656, 13312, 16384]
M = [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 16384]
N = [2304, 4096, 8192, 13312]
K = [2048, 2304, 6656, 13312, 16384]

for m in M:
for n in N:
Expand Down
88 changes: 71 additions & 17 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions.hip
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ std::tuple<KernelMode, bool> get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto K = XQ.size(1);
auto N = WQ.size(0);
// Use small kernel when input matrices are small.
bool use_small_kernel = (M <= 512 && N <= 512);
bool use_small_kernel = (M <= 512 && N <= 512) || (M <= 128) || (N <= 128);
// For larger workloads, specialize to large gemm.
bool use_large_kernel =
((M >= 4096 && N >= 4096) || (M >= 8192 && N >= 2048) ||
Expand Down Expand Up @@ -268,7 +268,8 @@ template <
int KBLOCK,
int MPER_WAVE,
int NPER_WAVE,
bool PADDING = false>
bool PADDING = false,
bool TINY = false>
at::Tensor f8f8bf16_rowwise_impl(
at::Tensor XQ,
at::Tensor WQ,
Expand Down Expand Up @@ -317,8 +318,16 @@ at::Tensor f8f8bf16_rowwise_impl(

// Define derivative constants based on template parameters.
static constexpr int BLOCK_CLUSTER = BLOCK_SIZE / 4;
static constexpr int CBLOCK_N = NBLOCK / 16;
static constexpr int CBLOCK_M = BLOCK_SIZE / CBLOCK_N;
static constexpr int CBLOCK_N = TINY ? 4 : NBLOCK / 16;
static constexpr int CBLOCK_M = TINY ? 16 : BLOCK_SIZE / CBLOCK_N;

// A few modes change for tiny kernels.
static constexpr int MPER_XDL = TINY ? 16 : 32;
static constexpr int NPER_XDL = TINY ? 16 : 32;
static constexpr auto LOOP_SCHED = TINY ? ck::BlockGemmPipelineScheduler::Intrawave : ck::BlockGemmPipelineScheduler::Interwave;
using ABLOCK_TRANSFER = std::conditional_t<TINY, S<BLOCK_CLUSTER, 4, 1>, S<4, BLOCK_CLUSTER, 1>>;
using BBLOCK_TRANSFER = std::conditional_t<TINY, S<BLOCK_CLUSTER, 4, 1>, S<4, BLOCK_CLUSTER, 1>>;
using CBLOCK_TRANSFER = std::conditional_t<TINY, S<4, 4, 1>, S<8, 8, 1>>;

using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
Expand All @@ -342,18 +351,18 @@ at::Tensor f8f8bf16_rowwise_impl(
KBLOCK, // K per Block
16, // AK1
16, // BK1
32, // M per Xdl
32, // N per Xdl
MPER_XDL, // M per Xdl
NPER_XDL, // N per Xdl
MPER_WAVE, // Mxdl per Wave
NPER_WAVE, // Nxdl per Wave
S<4, 64, 1>,
ABLOCK_TRANSFER,
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
0,
S<4, 64, 1>,
BBLOCK_TRANSFER,
S<1, 0, 2>,
S<1, 0, 2>,
2,
Expand All @@ -363,8 +372,8 @@ at::Tensor f8f8bf16_rowwise_impl(
1,
1,
S<1, CBLOCK_M, 1, CBLOCK_N>,
S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Interwave,
CBLOCK_TRANSFER,
LOOP_SCHED,
ck::BlockGemmPipelineVersion::v1,
ComputeType>;

Expand Down Expand Up @@ -404,6 +413,43 @@ at::Tensor f8f8bf16_rowwise_impl(
return Y;
}

enum class RowKernelMode { Tiny, Small, Large, Default };

std::tuple<RowKernelMode, bool> get_row_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto M = XQ.size(0);
auto K = XQ.size(1);
auto N = WQ.size(0);
// Tiny kernels should be used when M and N are so small we need to load mostly from K.
// We also find this kernel is good for many shapes with small M up until a certain N.
bool use_tiny_kernel = (M <= 32 || N <= 32) || (M <= 128 && N <= 8192);
// For other cases where M is small but N is large, we have a specialized kernel.
bool use_small_kernel = (M <= 128);
// Larger workloads can load big chunks.
bool use_large_kernel =
((M >= 4096 && N >= 4096) || (M >= 8192 && N >= 2048) ||
(N >= 8192 && M >= 2048) || (K >= 8192 && M >= 2048 && N >= 2048));
// Set padding based on the selected mode.
bool use_pad;
if (use_tiny_kernel) {
// Tiny kernels use chunks of 16 in M and N and 256 in K.
// If any dimension is smaller than that we need to pad.
use_pad = M < 16 || N < 16 || K < 256;
return {RowKernelMode::Tiny, use_pad};
} else if (use_small_kernel) {
// Small kernels load chunks of 32 in M, 128 in N and 128 in K.
use_pad = M < 32 || N < 128 || K < 128;
return {RowKernelMode::Small, use_pad};
} else if (use_large_kernel) {
// Large kernels load chunks of 256 in M, 128 in K and 64 in K.
use_pad = M < 256 || N < 128 || K < 64;
return {RowKernelMode::Large, use_pad};
} else {
// Default kernel loads chunks of 128 in M and N and 64 in K.
use_pad = M < 128 || N < 128 || K < 64;
return {RowKernelMode::Default, use_pad};
}
}

at::Tensor f8f8bf16_rowwise(
at::Tensor XQ,
at::Tensor WQ,
Expand All @@ -419,18 +465,26 @@ at::Tensor f8f8bf16_rowwise(
TORCH_CHECK((x_scale.dtype() == at::kFloat) && (w_scale.dtype() == at::kFloat), "Scales must be float32.");
TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum.");
TORCH_CHECK(!(bias.has_value()), "AMD does not yet support bias.");
auto [kernel, pad] = get_kernel_mode(XQ, WQ);
auto [kernel, pad] = get_row_kernel_mode(XQ, WQ);
if (pad) {
if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_impl<256, 256, 128, 64, 4, 2, true>(XQ, WQ, x_scale, w_scale);
if (kernel == RowKernelMode::Tiny) {
return f8f8bf16_rowwise_impl<64, 16, 16, 256, 1, 1, true, true>(XQ, WQ, x_scale, w_scale);
} else if (kernel == RowKernelMode::Small) {
return f8f8bf16_rowwise_impl<128, 32, 128, 128, 1, 2, true, false>(XQ, WQ, x_scale, w_scale);
} else if (kernel == RowKernelMode::Large) {
return f8f8bf16_rowwise_impl<256, 256, 128, 64, 4, 2, true, false>(XQ, WQ, x_scale, w_scale);
} else {
return f8f8bf16_rowwise_impl<256, 128, 128, 64, 2, 2, true>(XQ, WQ, x_scale, w_scale);
return f8f8bf16_rowwise_impl<256, 128, 128, 64, 2, 2, true, false>(XQ, WQ, x_scale, w_scale);
}
} else {
if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_impl<256, 256, 128, 64, 4, 2, false>(XQ, WQ, x_scale, w_scale);
if (kernel == RowKernelMode::Tiny) {
return f8f8bf16_rowwise_impl<64, 16, 16, 256, 1, 1, false, true>(XQ, WQ, x_scale, w_scale);
} else if (kernel == RowKernelMode::Small) {
return f8f8bf16_rowwise_impl<128, 32, 128, 128, 1, 2, true, false>(XQ, WQ, x_scale, w_scale);
} else if (kernel == RowKernelMode::Large) {
return f8f8bf16_rowwise_impl<256, 256, 128, 64, 4, 2, false, false>(XQ, WQ, x_scale, w_scale);
} else {
return f8f8bf16_rowwise_impl<256, 128, 128, 64, 2, 2, false>(XQ, WQ, x_scale, w_scale);
return f8f8bf16_rowwise_impl<256, 128, 128, 64, 2, 2, false, false>(XQ, WQ, x_scale, w_scale);
}
}
}
Expand Down

0 comments on commit 8e07ad2

Please sign in to comment.