diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
index 9e64b1b44b..0691a69128 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
@@ -8,7 +8,10 @@
 import numpy as np
 import torch

-from fbgemm_gpu.split_embedding_configs import SparseType  # usort:skip
+from fbgemm_gpu.split_embedding_configs import (
+    FP8QuantizationConfig,
+    SparseType,
+)  # usort:skip

 # pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
 from numpy.random import default_rng
@@ -33,8 +36,8 @@ def get_device() -> torch.device:


 def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
-    # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, Embedding,
-    #  EmbeddingBag]`.
+    # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor,
+    #  torch.nn.EmbeddingBag]`.
     return t.cpu() if use_cpu else t.cuda()


@@ -228,8 +231,14 @@ def generate_requests(


 def quantize_embs(
-    weight: torch.Tensor, weight_ty: SparseType
+    weight: torch.Tensor,
+    weight_ty: SparseType,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    assert (
+        weight.dtype == torch.float
+    ), "The input tensor for the quantize_embs function needs to be a float tensor"
+    weight = weight.detach()
     if weight_ty == SparseType.FP32:
         q_weight = weight.float()
         # FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
@@ -242,7 +251,21 @@
         res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
         return (res_weight, None)

+    elif weight_ty == SparseType.FP8:
+        assert fp8_config is not None
+        # Quantize FP32 to HFP8
+        res_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
+            weight.float(),
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+            fp8_config.get("max_position"),
+        )
+        return (res_weight, None)
+
     elif weight_ty == SparseType.INT8:
+        # Note that FloatToFused8BitRowwiseQuantized might add padding for
+        # alignment if the embedding dimension is not a multiple of 4:
+        # https://fburl.com/code/z009xsy6
         q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
         res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
         res_scale_shift = torch.tensor(
@@ -257,6 +280,8 @@
         return (res_weight, res_scale_shift)

     elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
+        # Note that the FP32 -> INT4/INT2 conversion op below might add padding
+        # for alignment: https://fburl.com/code/xx9kkduf
         q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
             weight,
             bit_rate=weight_ty.bit_rate(),
@@ -269,3 +294,189 @@

     else:
         raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))
+
+
+def dequantize_embs(
+    weights: torch.Tensor,
+    scale_shift: Optional[torch.Tensor],
+    weight_ty: SparseType,
+    use_cpu: bool,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
+) -> torch.Tensor:
+    assert (
+        weights.dtype == torch.uint8
+    ), "The input tensor for the dequantize_embs function needs to be a byte tensor"
+    np_weights = weights.contiguous().cpu().numpy()
+
+    if scale_shift is not None:
+        np_scale_shift: np.ndarray = (
+            scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
+        )
+
+    if weight_ty == SparseType.INT4:
+        (E, D_2) = np_weights.shape
+        D = D_2 * 2
+
+        def comp(i: int) -> np.ndarray:
+            subs = np_weights.view(np.uint8) >> (i * 4)
+            sub_mask = subs & 0xF
+            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+                -1, 1
+            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
+                np.float32
+            )
+            return result.astype(np.float32)
+
+        comps = [comp(i) for i in range(2)]
+        comps = np.stack(comps)
+        comps = comps.transpose(1, 2, 0)
+        comps = comps.reshape(E, D)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.INT2:
+        (E, D_4) = np_weights.shape
+        D = D_4 * 4
+
+        # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
+        # pyre-fixme[53]: Captured variable `weights` is not annotated.
+        def comp(i: int) -> np.ndarray:
+            subs = np_weights.view(np.uint8) >> (i * 2)
+            sub_mask = subs & 0x3
+            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+                -1, 1
+            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
+                np.float32
+            )
+            return result.astype(np.float32)
+
+        comps = [comp(i) for i in range(4)]
+        comps = np.stack(comps)
+        comps = comps.transpose(1, 2, 0)
+        comps = comps.reshape(E, D)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.INT8:
+        (E, D) = np_weights.shape
+        comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
+            -1, 1
+        ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.FP8:
+        assert fp8_config is not None
+        assert scale_shift is None
+        # Dequantize HFP8 to FP32
+        comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
+            weights,
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+        )
+        return to_device(comps, use_cpu)
+
+    elif weight_ty == SparseType.FP16:
+        assert scale_shift is None
+        comps = np_weights.view(np.half)
+        return to_device(torch.tensor(comps), use_cpu)
+
+    elif weight_ty == SparseType.FP32:
+        assert scale_shift is None
+        comps = np_weights.view(np.float32)
+        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
+        return to_device(torch.tensor(comps), use_cpu)
+
+
+def fake_quantize_embs(
+    weights: torch.Tensor,
+    scale_shift: Optional[torch.Tensor],
+    dequant_weights: torch.Tensor,
+    weight_ty: SparseType,
+    use_cpu: bool,
+    fp8_config: Optional[FP8QuantizationConfig] = None,
+) -> None:
+    assert (
+        weights.dtype == torch.uint8
+    ), "The input tensor for the fake_quantize_embs function needs to be a byte tensor"
+    np_weights = weights.contiguous().cpu().numpy()
+
+    if scale_shift is not None:
+        np_scale_shift: np.ndarray = (
+            scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
+        )
+
+    if weight_ty == SparseType.INT4:
+        (E, D_2) = np_weights.shape
+        D = D_2 * 2
+
+        def comp(i: int) -> np.ndarray:
+            subs = np_weights.view(np.uint8) >> (i * 4)
+            sub_mask = subs & 0xF
+            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+                -1, 1
+            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
+                np.float32
+            )
+            return result.astype(np.float32)
+
+        comps = [comp(i) for i in range(2)]
+        comps = np.stack(comps)
+        comps = comps.transpose(1, 2, 0)
+        comps = comps.reshape(E, D)
+        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
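Note for reviewers: the INT4/INT2/INT8 branches above all decode packed sub-byte (or byte) values and then apply a per-row scale and shift. A minimal numpy sketch of the INT4 case, outside of FBGEMM, using hypothetical names (packed, scale, shift) and plain float32 scale/shift in place of the fused float16 pairs:

import numpy as np

E, D = 4, 8  # rows, logical embedding dimension
rng = np.random.default_rng(0)
packed = rng.integers(0, 256, size=(E, D // 2), dtype=np.uint8)  # two 4-bit values per byte
scale = rng.random((E, 1), dtype=np.float32)
shift = rng.random((E, 1), dtype=np.float32)

def unpack_nibble(i: int) -> np.ndarray:
    # i = 0 selects the low nibble of each byte, i = 1 the high nibble.
    sub = (packed >> (i * 4)) & 0xF
    return sub.astype(np.float32) * scale + shift

# Stack the two nibble planes and interleave them back into an (E, D) matrix,
# mirroring the stack/transpose/reshape sequence in dequantize_embs.
deq = np.stack([unpack_nibble(i) for i in range(2)]).transpose(1, 2, 0).reshape(E, D)
assert deq.shape == (E, D)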
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
index 7b51266eba..f8c95306da 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -1731,7 +1731,7 @@ def __init__(
             if not weight_ty.is_float():
                 assert (
                     dim % (8 / weight_ty.bit_rate()) == 0
-                ), "For quantized types we need to at least pack at byte granularity"
+                ), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"

         def max_ty_D(ty: SparseType) -> int:
             return max(
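The sharpened assertion message above now reports dim and weight_ty on failure. The constraint itself: a row of dim values at bit_rate bits each must fill whole bytes, i.e. dim must be a multiple of 8 / bit_rate. A quick sketch of the arithmetic (bit_rate values as returned by SparseType.bit_rate(); the helper name is hypothetical):

def packs_at_byte_granularity(dim: int, bit_rate: int) -> bool:
    # 8 // bit_rate values share one byte: 2 for INT4, 4 for INT2, 1 for INT8.
    return dim % (8 // bit_rate) == 0

assert packs_at_byte_granularity(128, 4)       # INT4: 64 bytes per row
assert not packs_at_byte_granularity(129, 4)   # an odd dim leaves a half-filled byte
assert packs_at_byte_granularity(129, 8)       # INT8 packs any dim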
    elif weight_ty == SparseType.INT2:
+        (E, D_4) = np_weights.shape
+        D = D_4 * 4
+
+        # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
+        # pyre-fixme[53]: Captured variable `weights` is not annotated.
+        def comp(i: int) -> np.ndarray:
+            subs = np_weights.view(np.uint8) >> (i * 2)
+            sub_mask = subs & 0x3
+            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+                -1, 1
+            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
+                np.float32
+            )
+            return result.astype(np.float32)
+
+        comps = [comp(i) for i in range(4)]
+        comps = np.stack(comps)
+        comps = comps.transpose(1, 2, 0)
+        comps = comps.reshape(E, D)
+        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
+
+    elif weight_ty == SparseType.INT8:
+        (E, D) = np_weights.shape
+        comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
+            -1, 1
+        ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
+        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
+
+    elif weight_ty == SparseType.FP8:
+        assert fp8_config is not None
+        assert scale_shift is None
+        # Quantize FP32 to HFP8
+        comps = torch.ops.fbgemm.FloatToHFP8Quantized(
+            dequant_weights.detach().float(),
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+            fp8_config.get("max_position"),
+        )
+        weights.copy_(comps)
+
+        # Dequantize HFP8 to FP32
+        comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
+            comps,
+            fp8_config.get("exponent_bits"),
+            fp8_config.get("exponent_bias"),
+        )
+        dequant_weights.copy_(to_device(comps, use_cpu))
+
+    elif weight_ty == SparseType.FP16:
+        assert scale_shift is None
+        comps = dequant_weights.detach().half().cpu().numpy().view(np.uint8)
+        weights.copy_(torch.tensor(comps))
+    elif weight_ty == SparseType.FP32:
+        assert scale_shift is None
+        comps = dequant_weights.detach().float().cpu().numpy().view(np.uint8)
+        weights.copy_(torch.tensor(comps))
diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
index 1d66a7d749..3383ffc078 100644
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -22,6 +22,7 @@
 from fbgemm_gpu.split_embedding_configs import FP8QuantizationConfig
 from fbgemm_gpu.split_embedding_utils import (
     b_indices,
+    fake_quantize_embs,
     generate_requests,
     get_table_batched_offsets_from_dense,
     quantize_embs,
@@ -3280,95 +3281,15 @@ def execute_nbit_forward_(
                 np.stack([scales, shifts], axis=1).astype(np.float16).view(np.uint8)
             )

-        for t in range(T):
-            (weights, scale_shift) = cc.split_embedding_weights()[t]
-            np_weights = weights.contiguous().cpu().numpy()
-
-            if scale_shift is not None:
-                scale_shift: np.ndarray = (
-                    scale_shift.cpu()
-                    .contiguous()
-                    .numpy()
-                    .view(np.float16)
-                    .astype(np.float32)
-                )
-
-            if weights_ty_list[t] == SparseType.INT4:
-                (E, D_2) = np_weights.shape
-                D = D_2 * 2
-
-                def comp(i: int) -> np.ndarray:
-                    subs = np_weights.view(np.uint8) >> (i * 4)
-                    sub_mask = subs & 0xF
-                    result = sub_mask.astype(np.float32) * scale_shift[:, 0].reshape(
-                        -1, 1
-                    ).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
-                        np.float32
-                    )
-                    return result.astype(np.float32)
-
-                comps = [comp(i) for i in range(2)]
-                comps = np.stack(comps)
-                comps = comps.transpose(1, 2, 0)
-                comps = comps.reshape(E, D)
-                bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))
-
-            elif weights_ty_list[t] == SparseType.INT2:
-                (E, D_4) = np_weights.shape
-                D = D_4 * 4
-
-                # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
-                # pyre-fixme[53]: Captured variable `weights` is not annotated.
-                def comp(i: int) -> np.ndarray:
-                    subs = np_weights.view(np.uint8) >> (i * 2)
-                    sub_mask = subs & 0x3
-                    result = sub_mask.astype(np.float32) * scale_shift[:, 0].reshape(
-                        -1, 1
-                    ).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
-                        np.float32
-                    )
-                    return result.astype(np.float32)
-
-                comps = [comp(i) for i in range(4)]
-                comps = np.stack(comps)
-                comps = comps.transpose(1, 2, 0)
-                comps = comps.reshape(E, D)
-                bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))
-
-            elif weights_ty_list[t] == SparseType.INT8:
-                (E, D) = np_weights.shape
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
-                comps = np_weights.astype(np.float32) * scale_shift[:, 0].reshape(
-                    -1, 1
-                ).astype(np.float32) + scale_shift[:, 1].reshape(-1, 1).astype(
-                    np.float32
-                )
-                bs[t].weight.detach().copy_(to_device(torch.tensor(comps), use_cpu))
-
-            elif weights_ty_list[t] == SparseType.FP8:
-                # Quantize FP32 to HPF8
-                comps = torch.ops.fbgemm.FloatToHFP8Quantized(
-                    bs[t].weight.detach().float(),
-                    fp8_config.get("exponent_bits"),
-                    fp8_config.get("exponent_bias"),
-                    fp8_config.get("max_position"),
-                )
-                weights.copy_(comps)
-
-                # Dequantize HPF8 to FP32
-                comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
-                    comps,
-                    fp8_config.get("exponent_bits"),
-                    fp8_config.get("exponent_bias"),
-                )
-                bs[t].weight.data.copy_(comps)
-
-            elif weights_ty_list[t] == SparseType.FP16:
-                comps = bs[t].weight.detach().half().cpu().numpy().view(np.uint8)
-                weights.copy_(torch.tensor(comps))
-            elif weights_ty_list[t] == SparseType.FP32:
-                comps = bs[t].weight.detach().float().cpu().numpy().view(np.uint8)
-                weights.copy_(torch.tensor(comps))
+            fake_quantize_embs(
+                weights,
+                scale_shift,
+                bs[t].weight.detach(),
+                weights_ty_list[t],
+                use_cpu=False,
+                # pyre-fixme[61]: `fp8_config` is undefined, or not always defined.
+                fp8_config=fp8_config if has_fp8_weight else None,
+            )

         if not use_cpu:
             fc2 = (
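The test refactor above round-trips the reference weights through the quantized representation ("fake quantization"), so the reference embedding bag computes on exactly the values the quantized table stores. A conceptual numpy sketch of the INT8 rowwise case, simplified relative to the actual fused op (which packs float16 scale/shift into each row); all names here are hypothetical:

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((3, 4)).astype(np.float32)

# Rowwise affine quantization: scale maps the row's [min, max] onto [0, 255].
mins = w.min(axis=1, keepdims=True)
scales = np.maximum((w.max(axis=1, keepdims=True) - mins) / 255.0, 1e-8)
q = np.clip(np.round((w - mins) / scales), 0, 255).astype(np.uint8)

# Dequantize back; w_fake carries the same quantization error as q, which is
# what fake_quantize_embs writes into the reference weights.
w_fake = q.astype(np.float32) * scales + mins
assert np.abs(w - w_fake).max() <= scales.max() / 2 + 1e-6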