From 6177cdcd9620bdd5796a8fcc1b8c1376e487a4b5 Mon Sep 17 00:00:00 2001
From: Yusuo Hu
Date: Mon, 19 Sep 2022 15:16:29 -0700
Subject: [PATCH] Add INT8 support to fbgemm qcomm lib

Summary:
Add an INT8 codec to the fbgemm qcomm library, and define a QuantizationContext class to pass the row_dim and row_dim_quant parameters during encoding and decoding.

For INT8, the quantized output size is determined by the row dimension, which higher-level collective communication modules need to know. For that purpose, this diff adds a new calc_quantized_size function and implements it based on the INT8 codec logic.

Differential Revision: D39473386

fbshipit-source-id: 127b823f7b071ad01e02363ff41382bd3dcee6a4
---
 fbgemm_gpu/fbgemm_gpu/quantize_comm.py | 64 +++++++++++++++++++++++---
 fbgemm_gpu/test/quantize_comm_test.py  | 23 ++++++++--
 2 files changed, 77 insertions(+), 10 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
index d1a1cb0814..4d7a9380f2 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -24,6 +24,7 @@
     hfp8_to_fp32,
 )
 from fbgemm_gpu.split_embedding_configs import SparseType
+from pyre_extensions import none_throws
 from torch.autograd.profiler import record_function
 
 logger: logging.Logger = logging.getLogger()
@@ -31,11 +32,19 @@
 # FP8 configurations
 ebits, mbits, bias = 4, 3, 15
 max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
+ROW_DIM_DEFAULT = 32
+
+
+class QuantizationContext:
+    def __init__(self, row_dim: int = ROW_DIM_DEFAULT) -> None:
+        self.row_dim = row_dim
+        self.row_dim_quant: int = -1
 
 
 def _quantize_tensor(
     input_tensor: torch.Tensor,
     comm_precision: SparseType,
+    ctx: Optional[QuantizationContext] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         return input_tensor
@@ -45,6 +54,16 @@ def _quantize_tensor(
         return fp32_to_bf16_with_clamp(input_tensor)
     elif comm_precision == SparseType.FP8:
         return fp32_to_hfp8_with_clamp(input_tensor, ebits, mbits, bias)
+    elif comm_precision == SparseType.INT8:
+        ctx = none_throws(ctx)
+        row_dim = ctx.row_dim
+        input_2d = input_tensor.view((-1, row_dim)) if row_dim > 0 else input_tensor
+        input_2d_quant = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_2d)
+        row_dim_quant = input_2d_quant.shape[1]
+        input_quant_all2all = None
+        input_quant_all2all = input_2d_quant.view((-1))
+        ctx.row_dim_quant = row_dim_quant
+        return input_quant_all2all
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -52,6 +71,7 @@ def _quantize_tensor(
 def _dequantize_tensor(
     quantized_tensor: torch.Tensor,
     comm_precision: SparseType,
+    ctx: Optional[QuantizationContext] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -65,6 +85,14 @@
     elif comm_precision == SparseType.FP8:
         assert quantized_tensor.dtype == torch.uint8
         return hfp8_to_fp32(quantized_tensor, ebits, bias)
+    elif comm_precision == SparseType.INT8:
+        ctx = none_throws(ctx)
+        row_dim_quant = ctx.row_dim_quant
+        quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+        dequant_tensor = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
+            quantized_tensor_2d
+        )
+        return dequant_tensor.view(-1)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -75,6 +103,7 @@ def __init__(
         self,
         comm_precision: SparseType,
         loss_scale: Optional[float] = None,
+        row_dim: Optional[int] = None,
     ) -> None:
 
         if loss_scale is not None:
@@ -91,23 +120,46 @@ def __init__(
         self._comm_precision = comm_precision
         self._loss_scale = loss_scale
 
-    def encode(self, input_tensor: torch.Tensor) -> torch.Tensor:
+    def encode(
+        self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
+    ) -> torch.Tensor:
         if self._loss_scale is not None:
             input_tensor = self._loss_scale * input_tensor
         with record_function(
             f"## encoder {self._comm_precision} {self._loss_scale} ##"
         ):
-            return _quantize_tensor(input_tensor, self._comm_precision)
-
-    def decode(self, input_grad: torch.Tensor) -> torch.Tensor:
+            output = _quantize_tensor(
+                input_tensor,
+                self._comm_precision,
+                ctx,
+            )
+        return output
+
+    def decode(
+        self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
+    ) -> torch.Tensor:
         if self._loss_scale is not None:
-            input_grad = input_grad / self._loss_scale
+            input_tensor = input_tensor / self._loss_scale
         with record_function(
             f"## decoder {self._comm_precision} {self._loss_scale} ##"
         ):
-            dequantized_tensor = _dequantize_tensor(input_grad, self._comm_precision)
+            dequantized_tensor = _dequantize_tensor(
+                input_tensor, self._comm_precision, ctx
+            )
         return dequantized_tensor
 
+    def calc_quantized_size(
+        self, input_len: int, ctx: Optional[QuantizationContext] = None
+    ) -> int:
+        if self._comm_precision == SparseType.INT8:
+            ctx = none_throws(ctx)
+            assert input_len % ctx.row_dim == 0
+            nrows = input_len // ctx.row_dim
+            ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4
+            return nrows * ncols
+        else:
+            return input_len
+
     @property
     def quantized_dtype(self) -> torch.dtype:
         return self._comm_precision.as_dtype()
diff --git a/fbgemm_gpu/test/quantize_comm_test.py b/fbgemm_gpu/test/quantize_comm_test.py
index df024d645e..ea44850330 100644
--- a/fbgemm_gpu/test/quantize_comm_test.py
+++ b/fbgemm_gpu/test/quantize_comm_test.py
@@ -10,7 +10,7 @@
 import hypothesis.strategies as st
 import torch
 
-from fbgemm_gpu.quantize_comm import QuantizedCommCodec
+from fbgemm_gpu.quantize_comm import QuantizationContext, QuantizedCommCodec
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import assume, given, settings
 
@@ -28,6 +28,7 @@ class QuantizedCommCodecTest(unittest.TestCase):
                 (SparseType.BF16, 2.0),
                 (SparseType.FP8, None),
                 (SparseType.FP8, 3.0),
+                (SparseType.INT8, None),
             ]
         ),
         row_size=st.integers(4, 256),
@@ -43,18 +44,32 @@ def test_quantized_comm_codec(
     ) -> None:
         (comm_precision, loss_scale) = comm_precisions_loss_scale
+
         if comm_precision == SparseType.FP8:
             assume(col_size % 4 == 0)
 
         torch.manual_seed(rand_seed)
         shape = (row_size, col_size)
-
         quant_codec = QuantizedCommCodec(comm_precision, loss_scale)
+        ctx = QuantizationContext()
+        if comm_precision == SparseType.INT8:
+            assume(row_size * col_size % ctx.row_dim == 0)
 
         input_tensor = torch.rand(shape, requires_grad=True)
-        quant_tensor = quant_codec.encode(input_tensor)
-        output_tensor = quant_codec.decode(quant_tensor)
+        if comm_precision == SparseType.INT8:
+            input_tensor = input_tensor.view(-1)
+
+        quant_tensor = quant_codec.encode(input_tensor, ctx)
+
+        self.assertEqual(
+            quant_tensor.numel(),
+            quant_codec.calc_quantized_size(input_tensor.numel(), ctx),
+        )
+
+        output_tensor = quant_codec.decode(quant_tensor, ctx)
+        self.assertEqual(output_tensor.shape, input_tensor.shape)
 
         rtol = 0.005
         atol = 0.005
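For reviewers, a minimal round-trip sketch of the API added above; the concrete 4x32 input shape and the derived sizes are illustrative (they follow from the default row_dim of 32), not part of the diff:

    import torch

    from fbgemm_gpu.quantize_comm import QuantizationContext, QuantizedCommCodec
    from fbgemm_gpu.split_embedding_configs import SparseType

    # INT8 row-wise codec without loss scaling.
    codec = QuantizedCommCodec(SparseType.INT8, None)
    ctx = QuantizationContext()  # row_dim defaults to ROW_DIM_DEFAULT (32)

    # The INT8 path views the input as (-1, row_dim), so the flattened length
    # must be a multiple of ctx.row_dim.
    input_tensor = torch.rand(4, 32).view(-1)

    quant_tensor = codec.encode(input_tensor, ctx)
    # Each row of 32 floats becomes 32 int8 values plus 8 bytes of fp32
    # scale/bias, so calc_quantized_size returns 4 * 40 = 160 elements here.
    assert quant_tensor.numel() == codec.calc_quantized_size(input_tensor.numel(), ctx)

    output_tensor = codec.decode(quant_tensor, ctx)
    assert output_tensor.shape == input_tensor.shape

The updated test exercises the same round trip with hypothesis-generated shapes, checking both the quantized size and the decoded shape.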