Support FP8 scale calculation with scalar and cleanup
Summary:
Follow-up to D57263833: support FP8 scale calculation with a scalar, and merge the two FP8 tensorwise GEMMs into one.

Note that besides `Sm90ScalarBroadcast` in CUTLASS, the AMD CK f8f8bf16 GEMM also requires the scale to be passed as a scalar rather than as a scalar tensor, so this support is needed on both the NVIDIA and AMD sides (see the call-flow sketch below).

Differential Revision: D57367680
jiawenliu64 authored and facebook-github-bot committed May 15, 2024
1 parent be752af commit 0cf43d9
Showing 4 changed files with 134 additions and 240 deletions.
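For context, here is a minimal host-side sketch of the call flow this commit enables, with the scale carried as a plain scalar end to end. The declarations mirror the updated quantize.cpp interface below; the wrapper name `fp8_tensorwise_gemm_example` and the choice of the product of the two per-tensor scales as the combined GEMM scale are illustrative assumptions, not part of this commit.

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <tuple>

// Declarations mirroring the updated quantize.cpp interface in this commit.
std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
    at::Tensor input,
    c10::optional<at::Tensor> bs,
    c10::optional<at::Tensor> scale_ub);
at::Tensor f8f8bf16(at::Tensor XQ, at::Tensor WQ, double scale, bool use_fast_accum);

// Quantize both operands per-tensor, then run the merged tensorwise GEMM,
// passing the combined scale as a double rather than as a scalar tensor.
at::Tensor fp8_tensorwise_gemm_example(at::Tensor X, at::Tensor W) {
  auto [XQ, x_scale] = quantize_fp8_per_tensor(X, c10::nullopt, c10::nullopt);
  auto [WQ, w_scale] = quantize_fp8_per_tensor(W, c10::nullopt, c10::nullopt);
  return f8f8bf16(XQ, WQ, /*scale=*/x_scale * w_scale, /*use_fast_accum=*/true);
}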
204 changes: 5 additions & 199 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
@@ -594,197 +594,6 @@ KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
}

// Cutlass tensorwise kernel
template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool FAST_ACCUM>
at::Tensor f8f8bf16_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::float_e4m3_t;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::float_e4m3_t;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
ElementOutput>::value; // Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)

using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = cutlass::gemm::collective::
KernelScheduleAuto; // Kernel to launch based on the default setting in
// the Collective Builder

using MainLoopSchedule = cute::conditional_t<
FAST_ACCUM,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementInputB,
LayoutInputB,
AlignmentInputB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAuto,
MainLoopSchedule>::CollectiveOp;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{}));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{}));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{}));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
stride_a},
{{scale.data_ptr<float>(), 0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output}};
Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

at::Tensor f8f8bf16(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale,
bool use_fast_accum) {
auto M = XQ.size(0);
// auto K = XQ.size(1);
// auto N = WQ.size(0);
if (use_fast_accum) {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, true>(XQ, WQ, scale);
}
} else {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, false>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false>(XQ, WQ, scale);
}
}
}

template <
int TB_M,
int TB_N,
@@ -794,7 +603,7 @@ template <
int TBS_K,
bool PONG,
bool FAST_ACCUM>
at::Tensor f8f8bf16_tensorwise_impl(
at::Tensor f8f8bf16_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale) {
@@ -978,21 +787,18 @@ at::Tensor f8f8bf16_tensorwise_impl(
return Y;
}

at::Tensor f8f8bf16_tensorwise(
at::Tensor f8f8bf16(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale,
bool use_fast_accum) {
KernelMode kernel = get_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
return f8f8bf16_impl<128, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
} else {
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
XQ, WQ, scale);
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false, true>(XQ, WQ, scale);
}
}
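A brief migration note on the entry point above (an illustrative sketch, not part of this diff): callers that previously held the tensorwise scale in a 1-element float tensor can move the value to the host and pass a scalar to the merged `f8f8bf16`, at the cost of a device-to-host synchronization. The adapter name `f8f8bf16_from_tensor_scale` is hypothetical.

#include <ATen/ATen.h>

at::Tensor f8f8bf16(at::Tensor XQ, at::Tensor WQ, double scale, bool use_fast_accum);

// Hypothetical adapter: older callers held the tensorwise scale in a
// 1-element float tensor; item<double>() copies its value to the host
// (synchronizing) so it can be passed to the scalar-scale entry point above.
at::Tensor f8f8bf16_from_tensor_scale(
    at::Tensor XQ, at::Tensor WQ, at::Tensor scale_tensor, bool use_fast_accum) {
  double scale = scale_tensor.item<double>();  // device-to-host copy + sync
  return f8f8bf16(XQ, WQ, scale, use_fast_accum);
}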

60 changes: 32 additions & 28 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -34,11 +34,6 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale);

// Cutlass kernel
at::Tensor f8f8bf16(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor scale,
bool use_fast_accum = true);
at::Tensor f8f8bf16_tensorwise(
at::Tensor XQ,
at::Tensor WQ,
double scale,
@@ -67,7 +62,12 @@ at::Tensor f8i4bf16_rowwise(
at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale);
std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);

std::vector<at::Tensor> quantize_fp8_per_tensor(
std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
at::Tensor input,
c10::optional<at::Tensor> bs, // batch size
c10::optional<at::Tensor> scale_ub); // scale upperbound

std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale(
at::Tensor input,
c10::optional<at::Tensor> bs, // batch size
c10::optional<at::Tensor> scale_ub); // scale upperbound
@@ -103,10 +103,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("i8i8bf16(Tensor XQ, Tensor WQ, float scale, int split_k=1) -> Tensor");

m.def(
"f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor");

m.def(
"f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
"f8f8bf16(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");

m.def(
"f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor");
@@ -139,8 +136,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// quantize_ops with
// torch.ops.load_library
m.def(
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor[]");
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, float)");
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
m.def(
"quantize_fp8_per_tensor_tensor_scale(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, Tensor)");
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale);
m.def(
"quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]");
m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
@@ -166,10 +168,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
#ifndef USE_ROCM
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("i8i8bf16", i8i8bf16);
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
}

@@ -197,13 +201,22 @@ at::Tensor f8f8bf16_rowwise_meta(
return Y;
}

std::vector<at::Tensor> quantize_fp8_per_tensor_meta(
std::tuple<at::Tensor, double> quantize_fp8_per_tensor_meta(
at::Tensor X,
c10::optional<at::Tensor> bs,
c10::optional<at::Tensor> scale_ub) {
auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
auto scale = 0.0;
return std::tuple<at::Tensor, double>{Y, scale};
}

std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale_meta(
at::Tensor X,
c10::optional<at::Tensor> bs,
c10::optional<at::Tensor> scale_ub) {
auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
auto scale = at::empty({}, X.options().dtype(at::kBFloat16));
return {Y, scale};
return std::tuple<at::Tensor, at::Tensor>{Y, scale};
}

at::Tensor f8f8bf16_cublas_meta(
@@ -220,17 +233,6 @@ at::Tensor f8f8bf16_cublas_meta(
}

at::Tensor f8f8bf16_meta(
at::Tensor X,
at::Tensor W,
at::Tensor scale,
bool use_fast_accum = true) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor f8f8bf16_tensorwise_meta(
at::Tensor X,
at::Tensor W,
double scale,
@@ -255,10 +257,12 @@ at::Tensor f8i4bf16_rowwise_meta(

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("i8i8bf16", i8i8bf16_meta);
m.impl("f8f8bf16", f8f8bf16_meta);
m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta);
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta);
m.impl("f8f8bf16", f8f8bf16_meta);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale_meta);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
}
(Diffs for the remaining 2 changed files are not shown.)
