Adjust fp8 CK GEMM heuristic #2912

Closed · wants to merge 1 commit
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (1 addition, 1 deletion)

@@ -47,7 +47,7 @@ def bench_with_rotating_buffer(self, fn, args):

         # torch.cuda.get_device_properties does not have L2 cache size,
         # so hard code an overapproximation of L2 cache size to ensure L2 cache flush
-        total_buffer_size = 16 * 1024 * 1024
+        total_buffer_size = 10000 * 1024 * 1024
 
         # Use pickle to serialize model input to estimate total sizes of input
         input_sizes = len(pickle.dumps(args))
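For context, the rotating-buffer technique keeps L2-resident inputs from flattering the timings: the harness clones the inputs enough times to overflow the assumed cache size and cycles through the clones, so each iteration sees cold data. A minimal Python sketch of the idea; the helper name, signature, and timing details are illustrative, not the actual FBGEMM implementation:

import pickle

import torch


def bench_with_rotating_buffer(fn, args, total_buffer_size=10000 * 1024 * 1024):
    # Illustrative sketch of the rotating-buffer trick, not FBGEMM's code.
    # Estimate the footprint of one input set by serializing it.
    input_size = len(pickle.dumps(args))
    # Clone the inputs enough times to overflow the assumed cache size.
    num_copies = total_buffer_size // input_size + 1
    arg_sets = [
        [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
        for _ in range(num_copies)
    ]
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for cloned_args in arg_sets:
        fn(*cloned_args)  # each call reads inputs that are cold in L2
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_copies  # mean milliseconds per call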
@@ -16,23 +16,79 @@ fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave
     at::Tensor w_scale,
     at::Tensor Y) {
   // This kernel works well for many medium to large shapes.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      256,
-      224,
-      256,
-      128,
-      16,
-      16,
-      7,
-      8,
-      S<8, 32, 1>,
-      S<8, 32, 1>,
-      S<1, 32, 1, 8>,
-      S<8, 8, 1>,
-      1,
-      2,
-      ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v3>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+
+  bool mnpad = (M % 224 != 0) || (N % 256 != 0);
+  bool kpad = K % 128 != 0;
+
+  if (kpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        224,
+        256,
+        128,
+        16,
+        16,
+        7,
+        8,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else if (mnpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        224,
+        256,
+        128,
+        16,
+        16,
+        7,
+        8,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::MNPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        224,
+        256,
+        128,
+        16,
+        16,
+        7,
+        8,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
 }
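The branch structure above follows one rule: an unaligned K forces the fully padded MNKPadding instance, a misaligned M or N alone only needs MNPadding, and exactly tiled shapes take the unpadded fast path. A Python sketch of the same decision, using this instance's 224x256x128 tile as defaults (the function name is illustrative):

def choose_gemm_specialization(M, N, K, tile_m=224, tile_n=256, tile_k=128):
    # Sketch of the padding heuristic encoded by the if/else chain above.
    if K % tile_k != 0:
        # An unaligned K changes the reduction loop, so pad all three dims.
        return "MNKPadding"
    if M % tile_m != 0 or N % tile_n != 0:
        # Only the output tiling is affected; padding M and N suffices.
        return "MNPadding"
    # The shape divides the tile exactly: use the fastest, unpadded instance.
    return "Default"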
@@ -15,24 +15,55 @@ fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 256 != 0) || (N % 224 != 0) || (K % 128 != 0);
+
   // This kernel seems optimal in the most purely compute bound tasks.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      256,
-      256,
-      224,
-      128,
-      16,
-      16,
-      8,
-      7,
-      S<8, 32, 1>,
-      S<8, 32, 1>,
-      S<1, 64, 1, 4>,
-      S<8, 8, 1>,
-      2,
-      1,
-      ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v3>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        256,
+        224,
+        128,
+        16,
+        16,
+        8,
+        7,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 64, 1, 4>,
+        S<8, 8, 1>,
+        2,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        256,
+        224,
+        128,
+        16,
+        16,
+        8,
+        7,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 64, 1, 4>,
+        S<8, 8, 1>,
+        2,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
 }
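Both heuristics derive M via size_to_dim_(XQ.dim() - 1, XQ.sizes()), the product of all leading dimensions of the activation, so batched 3-D inputs are routed as an equivalent 2-D GEMM. A Python equivalent for reference (the helper name is illustrative):

import math


def gemm_m_from_activation(xq_shape):
    # size_to_dim_(k, sizes) multiplies the first k sizes; with k = dim - 1
    # this flattens every leading dimension of XQ into the GEMM M.
    return math.prod(xq_shape[:-1])


# A [4, 512, 4096] activation is dispatched as a GEMM with M = 4 * 512 = 2048.
assert gemm_m_from_activation([4, 512, 4096]) == 2048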
@@ -101,7 +101,7 @@ template <
     ck::BlockGemmPipelineScheduler LOOP_SCHED,
     ck::BlockGemmPipelineVersion PIPELINE_VERSION,
     ck::tensor_operation::device::GemmSpecialization GEMM_SPEC =
-        ck::tensor_operation::device::GemmSpecialization::MNKPadding>
+        ck::tensor_operation::device::GemmSpecialization::MNPadding>
 using DeviceGemmHelper =
     ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
         ALayout,
@@ -175,8 +175,6 @@ at::Tensor f8f8bf16_rowwise_impl(
   auto cde_element_op = CDEElementOp{};
 
   constexpr ck::index_t NumDTensor = ck::Number<2>{};
-  constexpr auto I0 =
-      ck::Number<0>{}; // Used to indicate 0 stride for row and col broadcast.
 
   auto argument = gemm.MakeArgument(
       reinterpret_cast<ADataType*>(XQ.data_ptr()),
@@ -190,7 +188,7 @@
       K,
       StrideA,
       StrideB,
-      std::array<ck::index_t, NumDTensor>{I0, I0},
+      std::array<ck::index_t, NumDTensor>{0, 0},
       StrideE,
       a_element_op,
       b_element_op,
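The zero strides passed for the two D tensors ask CK to broadcast the rowwise scales over the [M, N] output tile: x_scale carries one value per output row and w_scale one per output column. A rough PyTorch rendering of the epilogue's effect (a sketch of the math, not the CK code path):

import torch


def rowwise_scaled_epilogue(acc, x_scale, w_scale):
    # acc: [M, N] fp32 accumulator; x_scale: [M]; w_scale: [N].
    # The unsqueezes mirror the stride-0 row/col broadcast requested above.
    return (acc * x_scale[:, None] * w_scale[None, :]).to(torch.bfloat16)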