Support FP8 scale calculation with scalar and cleanup #2593

Closed · wants to merge 1 commit

204 changes: 5 additions & 199 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
@@ -594,197 +594,6 @@ KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
}

// Cutlass tensorwise kernel
template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool FAST_ACCUM>
at::Tensor f8f8bf16_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::float_e4m3_t;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::float_e4m3_t;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
ElementOutput>::value; // Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)

using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = cutlass::gemm::collective::
KernelScheduleAuto; // Kernel to launch based on the default setting in
// the Collective Builder

using MainLoopSchedule = cute::conditional_t<
FAST_ACCUM,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementInputB,
LayoutInputB,
AlignmentInputB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAuto,
MainLoopSchedule>::CollectiveOp;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{}));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{}));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{}));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
stride_a},
{{scale.data_ptr<float>(), 0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output}};
Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

at::Tensor f8f8bf16(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale,
bool use_fast_accum) {
auto M = XQ.size(0);
// auto K = XQ.size(1);
// auto N = WQ.size(0);
if (use_fast_accum) {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, true>(XQ, WQ, scale);
}
} else {
if (M <= 128) {
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, false>(XQ, WQ, scale);
} else {
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false>(XQ, WQ, scale);
}
}
}

template <
int TB_M,
int TB_N,
@@ -794,7 +603,7 @@ template <
int TBS_K,
bool PONG,
bool FAST_ACCUM>
at::Tensor f8f8bf16_tensorwise_impl(
at::Tensor f8f8bf16_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale) {
@@ -978,21 +787,18 @@ at::Tensor f8f8bf16_tensorwise_impl(
return Y;
}

at::Tensor f8f8bf16_tensorwise(
at::Tensor f8f8bf16(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
double scale,
bool use_fast_accum) {
KernelMode kernel = get_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
return f8f8bf16_impl<128, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
} else {
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
XQ, WQ, scale);
return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false, true>(XQ, WQ, scale);
}
}
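
For context, a brief usage sketch (not part of the diff): with the Tensor-scale overload removed above, f8f8bf16 now refers to the tensorwise kernel and takes a plain float scale, as registered in quantize.cpp below. The shapes, values, and variable names here are illustrative only.

import torch

# Illustrative shapes; XQ is a row-major [M, K] activation, WQ a row-major
# [N, K] weight, both already quantized to FP8 e4m3.
M, N, K = 256, 512, 1024
XQ = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
WQ = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
scale = 0.05  # now a Python float rather than a Tensor

# Schema in this PR: f8f8bf16(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor
Y = torch.ops.fbgemm.f8f8bf16(XQ, WQ, scale, use_fast_accum=True)  # bf16 output, [M, N]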

60 changes: 32 additions & 28 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -34,11 +34,6 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale);

// Cutlass kernel
at::Tensor f8f8bf16(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor scale,
bool use_fast_accum = true);
at::Tensor f8f8bf16_tensorwise(
at::Tensor XQ,
at::Tensor WQ,
double scale,
@@ -67,7 +62,12 @@ at::Tensor f8i4bf16_rowwise(
at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale);
std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);

std::vector<at::Tensor> quantize_fp8_per_tensor(
std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
at::Tensor input,
c10::optional<at::Tensor> bs, // batch size
c10::optional<at::Tensor> scale_ub); // scale upperbound

std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale(
at::Tensor input,
c10::optional<at::Tensor> bs, // batch size
c10::optional<at::Tensor> scale_ub); // scale upperbound
@@ -103,10 +103,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("i8i8bf16(Tensor XQ, Tensor WQ, float scale, int split_k=1) -> Tensor");

m.def(
"f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor");

m.def(
"f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
"f8f8bf16(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");

m.def(
"f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor");
@@ -139,8 +136,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// quantize_ops with
// torch.ops.load_library
m.def(
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor[]");
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, float)");
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
m.def(
"quantize_fp8_per_tensor_tensor_scale(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, Tensor)");
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale);
m.def(
"quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]");
m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
@@ -166,10 +168,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
#ifndef USE_ROCM
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("i8i8bf16", i8i8bf16);
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
}

@@ -197,13 +201,22 @@ at::Tensor f8f8bf16_rowwise_meta(
return Y;
}

std::vector<at::Tensor> quantize_fp8_per_tensor_meta(
std::tuple<at::Tensor, double> quantize_fp8_per_tensor_meta(
at::Tensor X,
c10::optional<at::Tensor> bs,
c10::optional<at::Tensor> scale_ub) {
auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
auto scale = 0.0;
return std::tuple<at::Tensor, double>{Y, scale};
}

std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale_meta(
at::Tensor X,
c10::optional<at::Tensor> bs,
c10::optional<at::Tensor> scale_ub) {
auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
auto scale = at::empty({}, X.options().dtype(at::kBFloat16));
return {Y, scale};
return std::tuple<at::Tensor, at::Tensor>{Y, scale};
}

at::Tensor f8f8bf16_cublas_meta(
@@ -220,17 +233,6 @@ at::Tensor f8f8bf16_cublas_meta(
}

at::Tensor f8f8bf16_meta(
at::Tensor X,
at::Tensor W,
at::Tensor scale,
bool use_fast_accum = true) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor f8f8bf16_tensorwise_meta(
at::Tensor X,
at::Tensor W,
double scale,
@@ -255,10 +257,12 @@ at::Tensor f8i4bf16_rowwise_meta(

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("i8i8bf16", i8i8bf16_meta);
m.impl("f8f8bf16", f8f8bf16_meta);
m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta);
m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta);
m.impl("f8f8bf16", f8f8bf16_meta);
m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
m.impl(
"quantize_fp8_per_tensor_tensor_scale",
quantize_fp8_per_tensor_tensor_scale_meta);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
}
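
As a final illustration (not from the PR), the Meta registrations above allow shape and dtype propagation without touching a GPU, for example on meta-device tensors; the shapes below are arbitrary.

import torch

X = torch.empty(128, 256, dtype=torch.bfloat16, device="meta")
XQ, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(X)  # dispatches to quantize_fp8_per_tensor_meta
print(XQ.shape, XQ.dtype)  # torch.Size([128, 256]) torch.float8_e4m3fn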