Add INT8 support to fbgemm qcomm lib
Summary:
Add an INT8 codec to the fbgemm library. Also define a QuantizationContext class to pass the row_dim and row_dim_quant parameters during encoding and decoding.
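For reference, a minimal round-trip sketch of the new API (shapes and values are illustrative, and it assumes an fbgemm_gpu build that includes this diff):

import torch
from fbgemm_gpu.quantize_comm import QuantizationContext, QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(SparseType.INT8, loss_scale=None)
ctx = QuantizationContext(row_dim=32)  # row_dim must evenly divide the input length
input_tensor = torch.rand(4 * 32)      # the INT8 path expects a flat 1-D tensor

quantized = codec.encode(input_tensor, ctx)  # fills ctx.row_dim_quant as a side effect
assert quantized.numel() == codec.calc_quantized_size(input_tensor.numel(), ctx)
output = codec.decode(quantized, ctx)        # uses ctx.row_dim_quant to reshape
assert output.shape == input_tensor.shape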

For INT8, the quantized output size is determined by the row dimension, and higher-level collective communication modules need to know that size. To that end, this diff adds a new calc_quantized_size function and implements it based on the INT8 codec logic.
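As a concrete check of that logic, here is a small arithmetic sketch (it assumes the fused 8-bit row-wise layout used by the codec, where each quantized row is padded to a multiple of 4 bytes and carries an extra float scale and bias):

row_dim = 32
input_len = 4 * row_dim                  # 128 input floats
nrows = input_len // row_dim             # 4 rows
ncols = (row_dim + 3) // 4 * 4 + 2 * 4   # 32 payload bytes + 8 bytes of scale/bias = 40
assert nrows * ncols == 160              # quantized size reported by calc_quantized_size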

Differential Revision: D39473386

fbshipit-source-id: 127b823f7b071ad01e02363ff41382bd3dcee6a4
yusuo authored and facebook-github-bot committed Sep 19, 2022
1 parent 728cdb8 commit 6177cdc
Showing 2 changed files with 77 additions and 10 deletions.
64 changes: 58 additions & 6 deletions fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -24,18 +24,27 @@
    hfp8_to_fp32,
)
from fbgemm_gpu.split_embedding_configs import SparseType
from pyre_extensions import none_throws
from torch.autograd.profiler import record_function

logger: logging.Logger = logging.getLogger()

# FP8 configurations
ebits, mbits, bias = 4, 3, 15
max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
ROW_DIM_DEFAULT = 32


class QuantizationContext:
    def __init__(self, row_dim: int = ROW_DIM_DEFAULT) -> None:
        self.row_dim = row_dim
        self.row_dim_quant: int = -1


def _quantize_tensor(
    input_tensor: torch.Tensor,
    comm_precision: SparseType,
    ctx: Optional[QuantizationContext] = None,
) -> torch.Tensor:
    if comm_precision == SparseType.FP32:
        return input_tensor
@@ -45,13 +54,24 @@ def _quantize_tensor(
        return fp32_to_bf16_with_clamp(input_tensor)
    elif comm_precision == SparseType.FP8:
        return fp32_to_hfp8_with_clamp(input_tensor, ebits, mbits, bias)
    elif comm_precision == SparseType.INT8:
        ctx = none_throws(ctx)
        row_dim = ctx.row_dim
        input_2d = input_tensor.view((-1, row_dim)) if row_dim > 0 else input_tensor
        input_2d_quant = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_2d)
        row_dim_quant = input_2d_quant.shape[1]
        input_quant_all2all = input_2d_quant.view((-1))
        ctx.row_dim_quant = row_dim_quant
        return input_quant_all2all
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


def _dequantize_tensor(
    quantized_tensor: torch.Tensor,
    comm_precision: SparseType,
    ctx: Optional[QuantizationContext] = None,
) -> torch.Tensor:
    if comm_precision == SparseType.FP32:
        assert quantized_tensor.dtype == torch.float
@@ -65,6 +85,14 @@
    elif comm_precision == SparseType.FP8:
        assert quantized_tensor.dtype == torch.uint8
        return hfp8_to_fp32(quantized_tensor, ebits, bias)
    elif comm_precision == SparseType.INT8:
        ctx = none_throws(ctx)
        row_dim_quant = ctx.row_dim_quant
        quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
        dequant_tensor = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
            quantized_tensor_2d
        )
        return dequant_tensor.view(-1)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")

@@ -75,6 +103,7 @@ def __init__(
        self,
        comm_precision: SparseType,
        loss_scale: Optional[float] = None,
        row_dim: Optional[int] = None,
    ) -> None:

        if loss_scale is not None:
@@ -91,23 +120,46 @@
        self._comm_precision = comm_precision
        self._loss_scale = loss_scale

    def encode(
        self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
    ) -> torch.Tensor:
        if self._loss_scale is not None:
            input_tensor = self._loss_scale * input_tensor
        with record_function(
            f"## encoder {self._comm_precision} {self._loss_scale} ##"
        ):
            output = _quantize_tensor(
                input_tensor,
                self._comm_precision,
                ctx,
            )
            return output

    def decode(
        self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
    ) -> torch.Tensor:
        if self._loss_scale is not None:
            input_tensor = input_tensor / self._loss_scale
        with record_function(
            f"## decoder {self._comm_precision} {self._loss_scale} ##"
        ):
            dequantized_tensor = _dequantize_tensor(
                input_tensor, self._comm_precision, ctx
            )
        return dequantized_tensor

    def calc_quantized_size(
        self, input_len: int, ctx: Optional[QuantizationContext] = None
    ) -> int:
        if self._comm_precision == SparseType.INT8:
            ctx = none_throws(ctx)
            assert input_len % ctx.row_dim == 0
            nrows = input_len // ctx.row_dim
            ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4
            return nrows * ncols
        else:
            return input_len

    @property
    def quantized_dtype(self) -> torch.dtype:
        return self._comm_precision.as_dtype()
23 changes: 19 additions & 4 deletions fbgemm_gpu/test/quantize_comm_test.py
@@ -10,7 +10,7 @@

import hypothesis.strategies as st
import torch
from fbgemm_gpu.quantize_comm import QuantizationContext, QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType
from hypothesis import assume, given, settings

@@ -28,6 +28,7 @@ class QuantizedCommCodecTest(unittest.TestCase):
                (SparseType.BF16, 2.0),
                (SparseType.FP8, None),
                (SparseType.FP8, 3.0),
                (SparseType.INT8, None),
            ]
        ),
        row_size=st.integers(4, 256),
@@ -43,18 +44,32 @@ def test_quantized_comm_codec(
    ) -> None:

        (comm_precision, loss_scale) = comm_precisions_loss_scale

        if comm_precision == SparseType.FP8:
            assume(col_size % 4 == 0)

        torch.manual_seed(rand_seed)
        shape = (row_size, col_size)

        quant_codec = QuantizedCommCodec(comm_precision, loss_scale)

        ctx = QuantizationContext()
        if comm_precision == SparseType.INT8:
            assume(row_size * col_size % ctx.row_dim == 0)

        input_tensor = torch.rand(shape, requires_grad=True)

        if comm_precision == SparseType.INT8:
            input_tensor = input_tensor.view(-1)

        quant_tensor = quant_codec.encode(input_tensor, ctx)

        self.assertEqual(
            quant_tensor.numel(),
            quant_codec.calc_quantized_size(input_tensor.numel(), ctx),
        )

        output_tensor = quant_codec.decode(quant_tensor, ctx)
        self.assertEqual(output_tensor.shape, input_tensor.shape)

        rtol = 0.005
        atol = 0.005