Feat: Implementation of the DeepSeek blockwise quantization for fp8 tensors #1763

Open · wants to merge 5 commits into base: main
150 changes: 150 additions & 0 deletions benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -0,0 +1,150 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
quantize_,
)


def benchmark_microseconds(f, *args):
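# do_bench reports the median runtime in milliseconds; multiplying by 1e3 converts to microseconds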
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_rowwise_problem(m: int, n: int, k: int, device):
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=device)
A_scale = torch.randn((m,), dtype=torch.half, device=device)
B = torch.randint(-128, 127, size=(n, 4 * k // 8), dtype=torch.int8, device=device)
B_scale = torch.randn((n,), dtype=torch.half, device=device)
C = None

return A, A_scale, B, B_scale, C


def get_blockwise_problem(
m: int, n: int, k: int, block_size: int, dtype: torch.dtype, device
):
assert (
n % block_size == 0 and k % block_size == 0
), "N and K dims must be divisible by block_size"
A = (448.0 * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=device)
B = (448.0 * (2 * torch.rand(n, k, device=device) - 1)).to(dtype)
B_scale = torch.randn(
(n // block_size, k // block_size), dtype=torch.half, device=device
)

return A, A_scale, B, B_scale


def benchmark_latency(
m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
):
A_ref = torch.randn((m, k), dtype=torch.half, device=device)
B_ref = torch.randn((n, k), dtype=torch.half, device=device)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k, device)
rowwise_time = benchmark_microseconds(
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
)

A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size, dtype, device)
blockwise_time = benchmark_microseconds(blockwise_fp8_gemm, A, A_scale, B, B_scale)

return {
"m": m,
"k": k,
"n": n,
"block_size": block_size,
"dtype": dtype,
"fp16_latency (ms)": fp16_time,
"rowwise_latency (ms)": rowwise_time,
"blockwise_latency (ms)": blockwise_time,
"rowwise_speedup": fp16_time / rowwise_time,
"blockwise_speedup": fp16_time / blockwise_time,
}


def benchmark_precision(
m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
):
lin = torch.nn.Linear(k, n, False, device, torch.half)
A = torch.randn((m, k), dtype=torch.half, device=device)
W = lin.weight
output = A @ W.T

A_q, A_s = fp8_blockwise_act_quant(A, block_size, dtype)
W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

quantize_(lin, int8_dynamic_activation_int4_weight())
output_rowwise = lin(A)

return {
"m": m,
"k": k,
"n": n,
"block_size": block_size,
"dtype": dtype,
"error_rowwise (dB)": compute_error(output, output_rowwise),
"error_blockwise (dB)": compute_error(output, output_blockwise),
}


def get_device_available_dtypes():
sm = torch.cuda.get_device_capability()
available_dtypes = []

if sm[0] == 8 and sm[1] == 0: # A100
available_dtypes.append(torch.float8_e5m2)
elif sm[0] == 9 and sm[1] == 0: # H100
available_dtypes.append(torch.float8_e5m2)
elif sm[0] == 8 and sm[1] == 9: # L4
available_dtypes.append(torch.float8_e4m3fn)
available_dtypes.append(torch.float8_e5m2)

print(
f"Available data types for device with compute capability {sm}: {available_dtypes}"
)
return available_dtypes


if __name__ == "__main__":
device = torch.device("cuda")
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)
block_size_vals = (128, 128, 128, 128)

latency_results = []
precision_results = []

available_dtypes = get_device_available_dtypes()

for m in tqdm([1 << i for i in range(10)]):
for dtype in available_dtypes:
for n, k, block_size in zip(n_vals, k_vals, block_size_vals):
latency_results.append(
benchmark_latency(m, k, n, block_size, dtype, device)
)
precision_results.append(
benchmark_precision(m, k, n, block_size, dtype, device)
)

df_latency = pd.DataFrame(latency_results)
df_precision = pd.DataFrame(precision_results)

df_latency.to_csv("blockwise_triton_latency_results.csv", index=False)
df_precision.to_csv("blockwise_triton_precision_results.csv", index=False)

print(df_latency.to_markdown(index=False))
print(df_precision.to_markdown(index=False))
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -17,6 +17,7 @@ tabulate # QOL for printing tables to stdout
tiktoken
blobfile
lm_eval
triton
# sam
diskcache
pycocotools
49 changes: 49 additions & 0 deletions test/prototype/test_blockwise_triton.py
@@ -0,0 +1,49 @@
import pytest
import torch

from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_dequant,
fp8_blockwise_weight_quant,
)

BLOCKWISE_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
def test_blockwise_quant_dequant(_, N, K):
x = torch.randn(N, K).cuda()
qx, s = fp8_blockwise_weight_quant(x)
x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
error = torch.norm(x - x_reconstructed) / torch.norm(x)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.05, "Quant-Dequant error is too high"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
def test_blockwise_fp8_gemm(M, N, K):
A = torch.randn(M, K).cuda()
B = torch.randn(N, K).cuda()

C = A @ B.T

A_q, A_s = fp8_blockwise_act_quant(A)
B_q, B_s = fp8_blockwise_weight_quant(B)

C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
print(C_q, C)
error = torch.norm(C - C_q) / torch.norm(C)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.05, "Quantized GEMM error is too high"
31 changes: 31 additions & 0 deletions torchao/prototype/blockwise_fp8/README.md
@@ -0,0 +1,31 @@
# Blockwise Quantization Implementation

## Overview

This directory contains a prototype implementation of the blockwise quantization scheme introduced by DeepSeek. Activations are quantized in 1x128 groups (one scale per token per 128 channels) and weight matrices in 128x128 blocks, each block carrying its own FP8 scale. Keeping scales local to small blocks limits quantization error while retaining the memory and throughput benefits of FP8.

## Quantization Process

### Activation Quantization
- Activations are quantized in groups of size 1x128, i.e. one scale per token per 128 channels.
- Keeping the scales local to each group limits quantization error while reducing memory footprint and letting the GEMM run in FP8.

### Weight Matrix Quantization
- Weight matrices are quantized in blocks of size 128x128.
- Each 128x128 block is stored in FP8 with its own scale, which balances precision and performance (a minimal sketch of the procedure is shown below).
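
For intuition, here is a minimal PyTorch sketch of 128x128 weight quantization. It is an illustrative, hypothetical helper (`fp8_blockwise_weight_quant_sketch`), not the prototype's actual `fp8_blockwise_weight_quant`: each tile gets its own scale chosen so that the tile's largest magnitude maps to the FP8 format's maximum finite value (448 for `float8_e4m3fn`).

```python
import torch

def fp8_blockwise_weight_quant_sketch(
    w: torch.Tensor,
    block_size: int = 128,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    # Illustration only; the prototype's real implementation lives in blockwise_quantization.py.
    n, k = w.shape
    assert n % block_size == 0 and k % block_size == 0
    fp8_max = torch.finfo(dtype).max  # 448.0 for float8_e4m3fn
    # View w as a grid of (block_size x block_size) tiles.
    tiles = w.reshape(n // block_size, block_size, k // block_size, block_size)
    # One scale per tile: the tile's max |value| maps to fp8_max.
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=torch.finfo(w.dtype).tiny)
    scale = amax / fp8_max
    q = (tiles / scale).to(dtype).reshape(n, k)           # FP8 payload
    return q, scale.reshape(n // block_size, k // block_size)  # one scale per block
```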

## Kernel Implementation in Triton

The blockwise FP8 GEMM kernel is implemented in Triton, a Python-embedded language for writing efficient GPU kernels. It consumes the FP8 activation and weight tensors together with their per-block scales, accumulates in FP32, and applies the scales block by block along the K dimension, so dequantization is fused into the matmul. The separate helpers (`fp8_blockwise_act_quant`, `fp8_blockwise_weight_quant`, `fp8_blockwise_weight_dequant`) produce and consume these scaled FP8 tensors.
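
A minimal end-to-end call of the prototype looks roughly like the snippet below. The function names come from this PR's `__init__.py`; the argument order follows the benchmark script, and the block size / FP8 dtype defaults are assumed:

```python
import torch
from torchao.prototype.blockwise_fp8 import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

M, N, K, block_size = 4, 3584, 640, 128
A = torch.randn(M, K, device="cuda", dtype=torch.half)  # activations
B = torch.randn(N, K, device="cuda", dtype=torch.half)  # weights (out_features x in_features)

A_q, A_s = fp8_blockwise_act_quant(A, block_size)     # FP8 payload + one scale per 1x128 group
B_q, B_s = fp8_blockwise_weight_quant(B, block_size)  # FP8 payload + one scale per 128x128 block
C = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)            # approximately A @ B.T
```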

## Illustration

![Blockwise Quantization Illustration](https://arxiv.org/html/2412.19437v1/x7.png)

*Illustration of the blockwise quantization process.*

## Original Paper

For detailed motivations and technical specifications, please refer to the original paper:
- [DeepSeek-V3 Technical Report (arXiv:2412.19437)](https://arxiv.org/html/2412.19437v1)

15 changes: 15 additions & 0 deletions torchao/prototype/blockwise_fp8/__init__.py
@@ -0,0 +1,15 @@
from .blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from .blockwise_linear import BlockwiseQuantLinear
from .blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
fp8_blockwise_weight_dequant,
)

__all__ = [
"blockwise_fp8_gemm",
"BlockwiseQuantLinear",
"fp8_blockwise_act_quant",
"fp8_blockwise_weight_quant",
"fp8_blockwise_weight_dequant",
]
80 changes: 80 additions & 0 deletions torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py
@@ -0,0 +1,80 @@
import torch
import triton
import triton.language as tl
from triton import Config

# Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
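# BLOCK_SIZE_K is pinned to 128 so each GEMM K-tile lines up with one quantization block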

fp8_gemm_configs = [
Config(
{"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
num_stages=num_stages,
num_warps=8,
)
for block_m in [16, 32, 64]
for block_n in [32, 64, 128]
for num_stages in [3, 4, 5, 6]
]


@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def blockwise_fp8_gemm_kernel(
a_ptr,
b_ptr,
c_ptr,
a_s_ptr,
b_s_ptr,
M,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
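# a_s holds one scale per (row, K-block): shape (M, K // BLOCK_SIZE_K), row-major
# b_s holds one scale per (N-block, K-block): shape (N // BLOCK_SIZE_K, K // BLOCK_SIZE_K), row-major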
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1

c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)


def blockwise_fp8_gemm(
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
):
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]),
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
blockwise_fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c