Add FP16, INT8, INT4, INT2 support for SSD inference TBE #1479

Closed · wants to merge 1 commit
220 changes: 216 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
@@ -8,7 +8,10 @@

import numpy as np
import torch
-from fbgemm_gpu.split_embedding_configs import SparseType  # usort:skip
+from fbgemm_gpu.split_embedding_configs import (
+    FP8QuantizationConfig,
+    SparseType,
+)  # usort:skip

# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
from numpy.random import default_rng
@@ -33,8 +36,8 @@ def get_device() -> torch.device:


def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
-    # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, Embedding,
-    #  EmbeddingBag]`.
+    # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor,
+    #  torch.nn.EmbeddingBag]`.
    return t.cpu() if use_cpu else t.cuda()


@@ -228,8 +231,14 @@ def generate_requests(


def quantize_embs(
-    weight: torch.Tensor, weight_ty: SparseType
+    weight: torch.Tensor,
+    weight_ty: SparseType,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert (
        weight.dtype == torch.float
    ), "The input tensor for the quantize_embs function needs to be a float tensor"
    weight = weight.detach()
    if weight_ty == SparseType.FP32:
        q_weight = weight.float()
        # FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
@@ -242,7 +251,21 @@
        res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
        return (res_weight, None)

    elif weight_ty == SparseType.FP8:
        assert fp8_config is not None
        # Quantize FP32 to HFP8
        res_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
            weight.float(),
            fp8_config.get("exponent_bits"),
            fp8_config.get("exponent_bias"),
            fp8_config.get("max_position"),
        )
        return (res_weight, None)

    elif weight_ty == SparseType.INT8:
        # Note that FloatToFused8BitRowwiseQuantized might have additional padding
        # for alignment if the embedding dimension is not a multiple of 4:
        # https://fburl.com/code/z009xsy6
        q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
        res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
        res_scale_shift = torch.tensor(
@@ -257,6 +280,8 @@
        return (res_weight, res_scale_shift)

    elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
        # Note that the FP32 -> INT4/INT2 conversion op below might have additional
        # padding for alignment: https://fburl.com/code/xx9kkduf
        q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
            weight,
            bit_rate=weight_ty.bit_rate(),
@@ -269,3 +294,190 @@

    else:
        raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))
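
For reference, the INT8 path above relies on the fused rowwise layout produced by FloatToFused8BitRowwiseQuantized: each row stores its quantized bytes followed by an FP16 scale and an FP16 bias, which is why the code slices the 8 trailing bytes off into res_scale_shift. Below is a minimal NumPy sketch of the same per-row affine quantization math; it assumes min/max rowwise scaling, and rowwise_int8_quantize_sketch is a hypothetical helper for illustration, not an fbgemm API.

from typing import Tuple

import numpy as np

def rowwise_int8_quantize_sketch(w: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # w: (E, D) float32 matrix; quantize each row to uint8 with a
    # per-row scale and shift, mirroring the fused rowwise scheme.
    w_min = w.min(axis=1, keepdims=True)
    w_max = w.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / 255.0
    scale[scale == 0.0] = 1.0  # guard constant rows against divide-by-zero
    q = np.rint((w - w_min) / scale).astype(np.uint8)
    # fbgemm packs the fp16 scale/shift next to each row's bytes; here they
    # are returned separately as an (E, 2) float16 array for clarity.
    scale_shift = np.stack([scale.ravel(), w_min.ravel()], axis=1).astype(np.float16)
    return q, scale_shift

Dequantization is then q * scale + shift per row, which is exactly the arithmetic the INT8 branch of dequantize_embs below applies.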


def dequantize_embs(
    weights: torch.Tensor,
    scale_shift: Optional[torch.Tensor],
    weight_ty: SparseType,
    use_cpu: bool,
    fp8_config: Optional[FP8QuantizationConfig] = None,
) -> torch.Tensor:
    print(f"weight_ty: {weight_ty}")
    assert (
        weights.dtype == torch.uint8
    ), "The input tensor for the dequantize_embs function needs to be a byte tensor"
    np_weights = weights.contiguous().cpu().numpy()

    if scale_shift is not None:
        np_scale_shift: np.ndarray = (
            scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
        )

    if weight_ty == SparseType.INT4:
        (E, D_2) = np_weights.shape
        D = D_2 * 2

        def comp(i: int) -> np.ndarray:
            subs = np_weights.view(np.uint8) >> (i * 4)
            sub_mask = subs & 0xF
            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
                -1, 1
            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
                np.float32
            )
            return result.astype(np.float32)

        comps = [comp(i) for i in range(2)]
        comps = np.stack(comps)
        comps = comps.transpose(1, 2, 0)
        comps = comps.reshape(E, D)
        return to_device(torch.tensor(comps), use_cpu)

    elif weight_ty == SparseType.INT2:
        (E, D_4) = np_weights.shape
        D = D_4 * 4

        # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
        # pyre-fixme[53]: Captured variable `weights` is not annotated.
        def comp(i: int) -> np.ndarray:
            subs = np_weights.view(np.uint8) >> (i * 2)
            sub_mask = subs & 0x3
            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
                -1, 1
            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
                np.float32
            )
            return result.astype(np.float32)

        comps = [comp(i) for i in range(4)]
        comps = np.stack(comps)
        comps = comps.transpose(1, 2, 0)
        comps = comps.reshape(E, D)
        return to_device(torch.tensor(comps), use_cpu)

    elif weight_ty == SparseType.INT8:
        (E, D) = np_weights.shape
        comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
            -1, 1
        ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
        return to_device(torch.tensor(comps), use_cpu)

    elif weight_ty == SparseType.FP8:
        assert fp8_config is not None
        assert scale_shift is None
        # Dequantize HFP8 to FP32
        comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
            weights,
            fp8_config.get("exponent_bits"),
            fp8_config.get("exponent_bias"),
        )
        return to_device(comps, use_cpu)

    elif weight_ty == SparseType.FP16:
        assert scale_shift is None
        comps = np_weights.view(np.half)
        return to_device(torch.tensor(comps), use_cpu)

    elif weight_ty == SparseType.FP32:
        assert scale_shift is None
        comps = np_weights.view(np.float32)
        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
        return to_device(torch.tensor(comps), use_cpu)
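
As a usage sketch, quantize_embs and dequantize_embs are intended to round-trip within quantization error. The snippet below assumes fbgemm_gpu is installed with its quantization ops registered, and picks a dimension that satisfies the byte-granularity packing constraint (an even D for INT4).

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_embedding_utils import dequantize_embs, quantize_embs

E, D = 100, 128
w = torch.randn(E, D, dtype=torch.float32)

# Pack into rowwise INT4: a uint8 payload plus an fp16 scale/shift per row.
q, scale_shift = quantize_embs(w, SparseType.INT4)

# Unpack back to float32 on the CPU and inspect the quantization error.
w_hat = dequantize_embs(q, scale_shift, SparseType.INT4, use_cpu=True)
print(f"max abs error: {(w - w_hat).abs().max().item():.4f}")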


def fake_quantize_embs(
    weights: torch.Tensor,
    scale_shift: Optional[torch.Tensor],
    dequant_weights: torch.Tensor,
    weight_ty: SparseType,
    use_cpu: bool,
    fp8_config: Optional[FP8QuantizationConfig] = None,
) -> None:
    assert (
        weights.dtype == torch.uint8
    ), "The input tensor for the fake_quantize_embs function needs to be a byte tensor"
    np_weights = weights.contiguous().cpu().numpy()

    if scale_shift is not None:
        np_scale_shift: np.ndarray = (
            scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
        )

    if weight_ty == SparseType.INT4:
        (E, D_2) = np_weights.shape
        D = D_2 * 2

        def comp(i: int) -> np.ndarray:
            subs = np_weights.view(np.uint8) >> (i * 4)
            sub_mask = subs & 0xF
            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
                -1, 1
            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
                np.float32
            )
            return result.astype(np.float32)

        comps = [comp(i) for i in range(2)]
        comps = np.stack(comps)
        comps = comps.transpose(1, 2, 0)
        comps = comps.reshape(E, D)
        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

    elif weight_ty == SparseType.INT2:
        (E, D_4) = np_weights.shape
        D = D_4 * 4

        # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
        # pyre-fixme[53]: Captured variable `weights` is not annotated.
        def comp(i: int) -> np.ndarray:
            subs = np_weights.view(np.uint8) >> (i * 2)
            sub_mask = subs & 0x3
            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
                -1, 1
            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
                np.float32
            )
            return result.astype(np.float32)

        comps = [comp(i) for i in range(4)]
        comps = np.stack(comps)
        comps = comps.transpose(1, 2, 0)
        comps = comps.reshape(E, D)
        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

    elif weight_ty == SparseType.INT8:
        (E, D) = np_weights.shape
        comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
            -1, 1
        ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

    elif weight_ty == SparseType.FP8:
        assert fp8_config is not None
        assert scale_shift is None
        # Quantize FP32 to HFP8
        comps = torch.ops.fbgemm.FloatToHFP8Quantized(
            dequant_weights.detach().float(),
            fp8_config.get("exponent_bits"),
            fp8_config.get("exponent_bias"),
            fp8_config.get("max_position"),
        )
        weights.copy_(comps)

        # Dequantize HFP8 to FP32
        comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
            comps,
            fp8_config.get("exponent_bits"),
            fp8_config.get("exponent_bias"),
        )
        dequant_weights.copy_(to_device(comps, use_cpu))

    elif weight_ty == SparseType.FP16:
        assert scale_shift is None
        comps = dequant_weights.detach().half().cpu().numpy().view(np.uint8)
        weights.copy_(torch.tensor(comps))
    elif weight_ty == SparseType.FP32:
        assert scale_shift is None
        comps = dequant_weights.detach().float().cpu().numpy().view(np.uint8)
        weights.copy_(torch.tensor(comps))
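
A usage sketch of the in-place contract: for the integer types, fake_quantize_embs reads the packed bytes from weights and writes the dequantized values into dequant_weights, so a module's float weights end up holding exactly what the quantized TBE kernel would see at inference time. This assumes fbgemm_gpu is installed; the shapes are arbitrary.

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_embedding_utils import fake_quantize_embs, quantize_embs

E, D = 50, 64
emb = torch.nn.EmbeddingBag(E, D, mode="sum")

# Pack the float weights into the rowwise INT4 format first ...
packed, scale_shift = quantize_embs(emb.weight.detach().float(), SparseType.INT4)

# ... then overwrite the float weights in place with their dequantized values.
fake_quantize_embs(packed, scale_shift, emb.weight.detach(), SparseType.INT4, use_cpu=True)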
@@ -1731,7 +1731,7 @@ def __init__(
        if not weight_ty.is_float():
            assert (
                dim % (8 / weight_ty.bit_rate()) == 0
-            ), "For quantized types we need to at least pack at byte granularity"
+            ), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"

        def max_ty_D(ty: SparseType) -> int:
            return max(
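
The tightened assertion above encodes the sub-byte packing rule: 8 / bit_rate elements share one byte, so INT4 tables need an even dim and INT2 tables need dim divisible by 4. A tiny sketch of the same check, assuming SparseType.bit_rate() behaves as used above; packs_at_byte_granularity is a hypothetical helper.

from fbgemm_gpu.split_embedding_configs import SparseType

def packs_at_byte_granularity(dim: int, weight_ty: SparseType) -> bool:
    # 8 // bit_rate elements are packed per byte; dim must be a multiple of that.
    return dim % (8 // weight_ty.bit_rate()) == 0

assert packs_at_byte_granularity(128, SparseType.INT4)      # two nibbles per byte
assert not packs_at_byte_granularity(129, SparseType.INT4)  # an odd dim cannot pack
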
99 changes: 10 additions & 89 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -22,6 +22,7 @@
from fbgemm_gpu.split_embedding_configs import FP8QuantizationConfig
from fbgemm_gpu.split_embedding_utils import (
    b_indices,
    fake_quantize_embs,
    generate_requests,
    get_table_batched_offsets_from_dense,
    quantize_embs,
@@ -3280,95 +3281,15 @@ def execute_nbit_forward_(
                np.stack([scales, shifts], axis=1).astype(np.float16).view(np.uint8)
            )

        for t in range(T):
            (weights, scale_shift) = cc.split_embedding_weights()[t]
-            np_weights = weights.contiguous().cpu().numpy()
-
-            if scale_shift is not None:
-                scale_shift: np.ndarray = (
-                    scale_shift.cpu()
-                    .contiguous()
-                    .numpy()
-                    .view(np.float16)
-                    .astype(np.float32)
-                )
-
-            if weights_ty_list[t] == SparseType.INT4:
-                (E, D_2) = np_weights.shape
-                D = D_2 * 2
-
-                def comp(i: int) -> np.ndarray:
-                    subs = np_weights.view(np.uint8) >> (i * 4)
-                    sub_mask = subs & 0xF
-                    result = sub_mask.astype(np.float32) * scale_shift[:, 0].reshape(
-                        -1, 1
-                    ).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
-                        np.float32
-                    )
-                    return result.astype(np.float32)
-
-                comps = [comp(i) for i in range(2)]
-                comps = np.stack(comps)
-                comps = comps.transpose(1, 2, 0)
-                comps = comps.reshape(E, D)
-                bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))
-
-            elif weights_ty_list[t] == SparseType.INT2:
-                (E, D_4) = np_weights.shape
-                D = D_4 * 4
-
-                # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
-                # pyre-fixme[53]: Captured variable `weights` is not annotated.
-                def comp(i: int) -> np.ndarray:
-                    subs = np_weights.view(np.uint8) >> (i * 2)
-                    sub_mask = subs & 0x3
-                    result = sub_mask.astype(np.float32) * scale_shift[:, 0].reshape(
-                        -1, 1
-                    ).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
-                        np.float32
-                    )
-                    return result.astype(np.float32)
-
-                comps = [comp(i) for i in range(4)]
-                comps = np.stack(comps)
-                comps = comps.transpose(1, 2, 0)
-                comps = comps.reshape(E, D)
-                bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))
-
-            elif weights_ty_list[t] == SparseType.INT8:
-                (E, D) = np_weights.shape
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
-                comps = np_weights.astype(np.float32) * scale_shift[:, 0].reshape(
-                    -1, 1
-                ).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
-                    np.float32
-                )
-                bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))
-
-            elif weights_ty_list[t] == SparseType.FP8:
-                # Quantize FP32 to HFP8
-                comps = torch.ops.fbgemm.FloatToHFP8Quantized(
-                    bs[t].weight.detach().float(),
-                    fp8_config.get("exponent_bits"),
-                    fp8_config.get("exponent_bias"),
-                    fp8_config.get("max_position"),
-                )
-                weights.copy_(comps)
-
-                # Dequantize HFP8 to FP32
-                comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
-                    comps,
-                    fp8_config.get("exponent_bits"),
-                    fp8_config.get("exponent_bias"),
-                )
-                bs[t].weight.data.copy_(comps)
-
-            elif weights_ty_list[t] == SparseType.FP16:
-                comps = bs[t].weight.detach().half().cpu().numpy().view(np.uint8)
-                weights.copy_(torch.tensor(comps))
-            elif weights_ty_list[t] == SparseType.FP32:
-                comps = bs[t].weight.detach().float().cpu().numpy().view(np.uint8)
-                weights.copy_(torch.tensor(comps))
+            fake_quantize_embs(
+                weights,
+                scale_shift,
+                bs[t].weight.detach(),
+                weights_ty_list[t],
+                use_cpu=False,
+                # pyre-fixme[61]: `fp8_config` is undefined, or not always defined.
+                fp8_config=fp8_config if has_fp8_weight else None,
+            )

        if not use_cpu:
            fc2 = (