Skip to content

Commit

Permalink
Add FP16, INT8, INT4, INT2 support for SSD inference TBE (#1479)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1479

Continue to add non-FP32 support in SSD inference TBE.

Differential Revision: D41540641

fbshipit-source-id: f7a3856536daab054092febee32cbe02238515b9
  • Loading branch information
jianyuh authored and facebook-github-bot committed Nov 28, 2022
1 parent d2923b9 commit dab69e7
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 94 deletions.
220 changes: 216 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

import numpy as np
import torch
from fbgemm_gpu.split_embedding_configs import SparseType # usort:skip
from fbgemm_gpu.split_embedding_configs import (
FP8QuantizationConfig,
SparseType,
) # usort:skip

# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
from numpy.random import default_rng
Expand All @@ -33,8 +36,8 @@ def get_device() -> torch.device:


def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
# pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, Embedding,
# EmbeddingBag]`.
# pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor,
# torch.nn.EmbeddingBag]`.
return t.cpu() if use_cpu else t.cuda()


Expand Down Expand Up @@ -228,8 +231,14 @@ def generate_requests(


def quantize_embs(
weight: torch.Tensor, weight_ty: SparseType
weight: torch.Tensor,
weight_ty: SparseType,
fp8_config: Optional[FP8QuantizationConfig] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (
weight.dtype == torch.float
), "The input tensor for quantize_embs function needs to be float tensor"
weight = weight.detach()
if weight_ty == SparseType.FP32:
q_weight = weight.float()
# FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
Expand All @@ -242,7 +251,21 @@ def quantize_embs(
res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
return (res_weight, None)

elif weight_ty == SparseType.FP8:
assert fp8_config is not None
# Quantize FP32 to HPF8
res_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
weight.float(),
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
fp8_config.get("max_position"),
)
return (res_weight, None)

elif weight_ty == SparseType.INT8:
# Note that FloatToFused8BitRowwiseQuantized might have additional padding
# for alignment if embedding dimension is not a multiple of 4:
# https://fburl.com/code/z009xsy6
q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
res_scale_shift = torch.tensor(
Expand All @@ -257,6 +280,8 @@ def quantize_embs(
return (res_weight, res_scale_shift)

elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
# Note that FP32 -> INT4/INT2 conersion op below might have additional padding
# for alignment: https://fburl.com/code/xx9kkduf
q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
weight,
bit_rate=weight_ty.bit_rate(),
Expand All @@ -269,3 +294,190 @@ def quantize_embs(

else:
raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))


def dequantize_embs(
weights: torch.Tensor,
scale_shift: torch.Tensor,
weight_ty: SparseType,
use_cpu: bool,
fp8_config: Optional[FP8QuantizationConfig] = None,
) -> torch.Tensor:
print(f"weight_ty: {weight_ty}")
assert (
weights.dtype == torch.uint8
), "The input tensor for dequantize_embs function needs to be byte tensor"
np_weights = weights.contiguous().cpu().numpy()

if scale_shift is not None:
np_scale_shift: np.ndarray = (
scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
)

if weight_ty == SparseType.INT4:
(E, D_2) = np_weights.shape
D = D_2 * 2

def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 4)
sub_mask = subs & 0xF
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)

comps = [comp(i) for i in range(2)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = comps.reshape(E, D)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.INT2:
(E, D_4) = np_weights.shape
D = D_4 * 4

# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
# pyre-fixme[53]: Captured variable `weights` is not annotated.
def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 2)
sub_mask = subs & 0x3
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)

comps = [comp(i) for i in range(4)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = comps.reshape(E, D)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.INT8:
(E, D) = np_weights.shape
comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.FP8:
assert fp8_config is not None
assert scale_shift is None
# Dequantize HPF8 to FP32
comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
weights,
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
)
return to_device(comps, use_cpu)

elif weight_ty == SparseType.FP16:
assert scale_shift is None
comps = np_weights.view(np.half)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.FP32:
assert scale_shift is None
comps = np_weights.view(np.float32)
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
return to_device(torch.tensor(comps), use_cpu)


def fake_quantize_embs(
weights: torch.Tensor,
scale_shift: Optional[torch.Tensor],
dequant_weights: torch.Tensor,
weight_ty: SparseType,
use_cpu: bool,
fp8_config: Optional[FP8QuantizationConfig] = None,
) -> None:
assert (
weights.dtype == torch.uint8
), "The input tensor for dequantize_embs function needs to be byte tensor"
np_weights = weights.contiguous().cpu().numpy()

if scale_shift is not None:
np_scale_shift: np.ndarray = (
scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
)

if weight_ty == SparseType.INT4:
(E, D_2) = np_weights.shape
D = D_2 * 2

def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 4)
sub_mask = subs & 0xF
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)

comps = [comp(i) for i in range(2)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = comps.reshape(E, D)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

elif weight_ty == SparseType.INT2:
(E, D_4) = np_weights.shape
D = D_4 * 4

# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
# pyre-fixme[53]: Captured variable `weights` is not annotated.
def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 2)
sub_mask = subs & 0x3
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)

comps = [comp(i) for i in range(4)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = comps.reshape(E, D)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

elif weight_ty == SparseType.INT8:
(E, D) = np_weights.shape
comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

elif weight_ty == SparseType.FP8:
assert fp8_config is not None
assert scale_shift is None
# Quantize FP32 to HPF8
comps = torch.ops.fbgemm.FloatToHFP8Quantized(
dequant_weights.detach().float(),
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
fp8_config.get("max_position"),
)
weights.copy_(comps)

# Dequantize HPF8 to FP32
comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
comps,
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))

elif weight_ty == SparseType.FP16:
assert scale_shift is None
comps = dequant_weights.detach().half().cpu().numpy().view(np.uint8)
weights.copy_(torch.tensor(comps))
elif weight_ty == SparseType.FP32:
assert scale_shift is None
comps = dequant_weights.detach().float().cpu().numpy().view(np.uint8)
weights.copy_(torch.tensor(comps))
Original file line number Diff line number Diff line change
Expand Up @@ -1731,7 +1731,7 @@ def __init__(
if not weight_ty.is_float():
assert (
dim % (8 / weight_ty.bit_rate()) == 0
), "For quantized types we need to at least pack at byte granularity"
), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"

def max_ty_D(ty: SparseType) -> int:
return max(
Expand Down
99 changes: 10 additions & 89 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fbgemm_gpu.split_embedding_configs import FP8QuantizationConfig
from fbgemm_gpu.split_embedding_utils import (
b_indices,
fake_quantize_embs,
generate_requests,
get_table_batched_offsets_from_dense,
quantize_embs,
Expand Down Expand Up @@ -3280,95 +3281,15 @@ def execute_nbit_forward_(
np.stack([scales, shifts], axis=1).astype(np.float16).view(np.uint8)
)

for t in range(T):
(weights, scale_shift) = cc.split_embedding_weights()[t]
np_weights = weights.contiguous().cpu().numpy()

if scale_shift is not None:
scale_shift: np.ndarray = (
scale_shift.cpu()
.contiguous()
.numpy()
.view(np.float16)
.astype(np.float32)
)

if weights_ty_list[t] == SparseType.INT4:
(E, D_2) = np_weights.shape
D = D_2 * 2

def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 4)
sub_mask = subs & 0xF
result = sub_mask.astype(np.float32) * scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)

comps = [comp(i) for i in range(2)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = comps.reshape(E, D)
bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))

elif weights_ty_list[t] == SparseType.INT2:
(E, D_4) = np_weights.shape
D = D_4 * 4

# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
# pyre-fixme[53]: Captured variable `weights` is not annotated.
def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 2)
sub_mask = subs & 0x3
result = sub_mask.astype(np.float32) * scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)

comps = [comp(i) for i in range(4)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = comps.reshape(E, D)
bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))

elif weights_ty_list[t] == SparseType.INT8:
(E, D) = np_weights.shape
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
comps = np_weights.astype(np.float32) * scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))

elif weights_ty_list[t] == SparseType.FP8:
# Quantize FP32 to HPF8
comps = torch.ops.fbgemm.FloatToHFP8Quantized(
bs[t].weight.detach().float(),
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
fp8_config.get("max_position"),
)
weights.copy_(comps)

# Dequantize HPF8 to FP32
comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
comps,
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
)
bs[t].weight.data.copy_(comps)

elif weights_ty_list[t] == SparseType.FP16:
comps = bs[t].weight.detach().half().cpu().numpy().view(np.uint8)
weights.copy_(torch.tensor(comps))
elif weights_ty_list[t] == SparseType.FP32:
comps = bs[t].weight.detach().float().cpu().numpy().view(np.uint8)
weights.copy_(torch.tensor(comps))
fake_quantize_embs(
weights,
scale_shift,
bs[t].weight.detach(),
weights_ty_list[t],
use_cpu=False,
# pyre-fixme[61]: `fp8_config` is undefined, or not always defined.
fp8_config=fp8_config if has_fp8_weight else None,
)

if not use_cpu:
fc2 = (
Expand Down

0 comments on commit dab69e7

Please sign in to comment.