Refactor based on recent torch tensor view support (#1816)
Summary:
Pull Request resolved: #1816

torch.Tensor.view recently added reinterpret_cast-style dtype support:

https://pytorch.org/docs/stable/generated/torch.Tensor.view.html

Previously we had to rely on NumPy, which introduced extra H2D/D2H overhead.

Reviewed By: sryap

Differential Revision: D46585543

fbshipit-source-id: 65bf589faa4034804f895eada4352af6e6d8a1a4
jianyuh authored and facebook-github-bot committed Jun 10, 2023
1 parent d1c4a6f commit 1476a20
Showing 2 changed files with 66 additions and 136 deletions.
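
For context, a minimal sketch (not part of the commit; illustrative shapes and CPU tensors) contrasting the old NumPy round-trip with the new Tensor.view(dtype) reinterpretation:

import numpy as np
import torch

w = torch.randn(4, 8)  # illustrative float32 weight tensor (CPU)

# Old path: copy to host memory, reinterpret via NumPy, copy back into a tensor.
old_bytes = torch.tensor(w.cpu().numpy().view(np.uint8)).contiguous()

# New path: reinterpret the same storage in place, no NumPy round-trip.
new_bytes = w.view(torch.uint8)

assert torch.equal(old_bytes, new_bytes)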
59 changes: 5 additions & 54 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py
@@ -9,12 +9,14 @@

import logging
import math
from typing import Optional, Tuple
from typing import cast, Optional, Tuple

import numpy as np
import torch

from fbgemm_gpu.split_embedding_configs import QuantizationConfig, SparseType

from fbgemm_gpu.split_embedding_utils import FP8QuantizationConfig, quantize_embs
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
IntNBitTableBatchedEmbeddingBagsCodegen,
@@ -93,59 +95,8 @@ def _get_quantization_config(self, name):
def _quantize_embs(
self, weight: Tensor, weight_ty: SparseType
) -> Tuple[Tensor, Optional[Tensor]]:
if weight_ty == SparseType.FP32:
q_weight = weight.float()
# FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
# Here it uses numpy and it will introduce DtoH/HtoD overhead.
res_weight = torch.tensor(
q_weight.cpu().numpy().view(np.uint8)
).contiguous()
return (res_weight, None)

elif weight_ty == SparseType.FP16:
q_weight = weight.half()
res_weight = torch.tensor(
q_weight.cpu().numpy().view(np.uint8)
).contiguous()
return (res_weight, None)

elif weight_ty == SparseType.FP8:
# Output tensor is already in uint8
q_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
weight.float(),
self._get_quantization_config("exponent_bits"),
self._get_quantization_config("exponent_bias"),
self._get_quantization_config("max_position"),
)
return (q_weight, None)

elif weight_ty == SparseType.INT8:
q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
res_scale_shift = torch.tensor(
q_weight[:, -8:]
.contiguous()
.cpu()
.numpy()
.view(np.float32)
.astype(np.float16)
.view(np.uint8)
) # [-4, -2]: scale; [-2:]: bias
return (res_weight, res_scale_shift)

elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
weight,
bit_rate=weight_ty.bit_rate(),
)
res_weight = torch.tensor(q_weight[:, :-4].cpu().numpy().view(np.uint8))
res_scale_shift = torch.tensor(
q_weight[:, -4:].contiguous().cpu().numpy().view(np.uint8)
) # [-4, -2]: scale; [-2:]: bias
return (res_weight, res_scale_shift)

else:
raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))
fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
return quantize_embs(weight, weight_ty, fp8_quant_config)

def _process_split_embs(self, model: nn.Module) -> None:
for name, child in model.named_children():
143 changes: 61 additions & 82 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
@@ -320,20 +320,15 @@ def quantize_embs(
weight_ty: SparseType,
fp8_config: Optional[FP8QuantizationConfig] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (
weight.dtype == torch.float
), "The input tensor for quantize_embs function needs to be float tensor"
weight = weight.detach()
if weight_ty == SparseType.FP32:
q_weight = weight.float()
# FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
# Here it uses numpy and it will introduce DtoH/HtoD overhead.
res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
res_weight = q_weight.view(torch.uint8)
return (res_weight, None)

elif weight_ty == SparseType.FP16:
q_weight = weight.half()
res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
res_weight = q_weight.view(torch.uint8)
return (res_weight, None)

elif weight_ty == SparseType.FP8:
@@ -352,15 +347,9 @@
# for alignment if embedding dimension is not a multiple of 4:
# https://fburl.com/code/z009xsy6
q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
res_weight = q_weight[:, :-8].view(torch.uint8)
res_scale_shift = torch.tensor(
q_weight[:, -8:]
.contiguous()
.cpu()
.numpy()
.view(np.float32)
.astype(np.float16)
.view(np.uint8)
q_weight[:, -8:].view(torch.float32).to(torch.float16).view(torch.uint8)
) # [-4, -2]: scale; [-2:]: bias
return (res_weight, res_scale_shift)

@@ -371,9 +360,9 @@
weight,
bit_rate=weight_ty.bit_rate(),
)
res_weight = torch.tensor(q_weight[:, :-4].cpu().numpy().view(np.uint8))
res_weight = q_weight[:, :-4].view(torch.uint8)
res_scale_shift = torch.tensor(
q_weight[:, -4:].contiguous().cpu().numpy().view(np.uint8)
q_weight[:, -4:].view(torch.uint8)
) # [-4, -2]: scale; [-2:]: bias
return (res_weight, res_scale_shift)

@@ -392,60 +381,54 @@ def dequantize_embs(
assert (
weights.dtype == torch.uint8
), "The input tensor for dequantize_embs function needs to be byte tensor"
np_weights = weights.contiguous().cpu().numpy()
th_weights = weights

if scale_shift is not None:
np_scale_shift: np.ndarray = (
scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
)
th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)

if weight_ty == SparseType.INT4:
(E, D_2) = np_weights.shape
(E, D_2) = th_weights.shape
D = D_2 * 2

def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 4)
def comp(i: int) -> torch.Tensor:
subs = th_weights.view(torch.uint8) >> (i * 4)
sub_mask = subs & 0xF
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)
).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
return result.to(torch.float32)

comps = [comp(i) for i in range(2)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = torch.stack(comps)
comps = comps.permute(1, 2, 0)
comps = comps.reshape(E, D)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.INT2:
(E, D_4) = np_weights.shape
(E, D_4) = th_weights.shape
D = D_4 * 4

# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
# pyre-fixme[53]: Captured variable `weights` is not annotated.
def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 2)
def comp(i: int) -> torch.Tensor:
subs = th_weights.view(torch.uint8) >> (i * 2)
sub_mask = subs & 0x3
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)
).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
return result.to(torch.float32)

comps = [comp(i) for i in range(4)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = torch.stack(comps)
comps = comps.permute(1, 2, 0)
comps = comps.reshape(E, D)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.INT8:
(E, D) = np_weights.shape
comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
(E, D) = th_weights.shape
comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
torch.float32
) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.FP8:
@@ -461,12 +444,12 @@ def comp(i: int) -> np.ndarray:

elif weight_ty == SparseType.FP16:
assert scale_shift is None
comps = np_weights.view(np.half)
comps = th_weights.view(torch.half)
return to_device(torch.tensor(comps), use_cpu)

elif weight_ty == SparseType.FP32:
assert scale_shift is None
comps = np_weights.view(np.float32)
comps = th_weights.view(torch.float32)
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
return to_device(torch.tensor(comps), use_cpu)

@@ -482,61 +465,57 @@ def fake_quantize_embs(
assert (
weights.dtype == torch.uint8
), "The input tensor for dequantize_embs function needs to be byte tensor"
np_weights = weights.contiguous().cpu().numpy()
th_weights = weights

if scale_shift is not None:
np_scale_shift: np.ndarray = (
scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
th_scale_shift: torch.Tensor = (
scale_shift.contiguous().view(torch.float16).to(torch.float32)
)

if weight_ty == SparseType.INT4:
(E, D_2) = np_weights.shape
(E, D_2) = th_weights.shape
D = D_2 * 2

def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 4)
def comp(i: int) -> torch.Tensor:
subs = th_weights.view(torch.uint8) >> (i * 4)
sub_mask = subs & 0xF
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)
).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
return result.to(torch.float32)

comps = [comp(i) for i in range(2)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = torch.stack(comps)
comps = comps.permute(1, 2, 0)
comps = comps.reshape(E, D)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
dequant_weights.copy_(to_device(comps, use_cpu))

elif weight_ty == SparseType.INT2:
(E, D_4) = np_weights.shape
(E, D_4) = th_weights.shape
D = D_4 * 4

# pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
# pyre-fixme[53]: Captured variable `weights` is not annotated.
def comp(i: int) -> np.ndarray:
subs = np_weights.view(np.uint8) >> (i * 2)
def comp(i: int) -> torch.Tensor:
subs = th_weights.view(torch.uint8) >> (i * 2)
sub_mask = subs & 0x3
result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
np.float32
)
return result.astype(np.float32)
).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
return result.to(torch.float32)

comps = [comp(i) for i in range(4)]
comps = np.stack(comps)
comps = comps.transpose(1, 2, 0)
comps = torch.stack(comps)
comps = comps.permute(1, 2, 0)
comps = comps.reshape(E, D)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
dequant_weights.copy_(to_device(comps, use_cpu))

elif weight_ty == SparseType.INT8:
(E, D) = np_weights.shape
comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
-1, 1
).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
(E, D) = th_weights.shape
comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
torch.float32
) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
dequant_weights.copy_(to_device(comps, use_cpu))

elif weight_ty == SparseType.FP8:
assert fp8_config is not None
@@ -556,13 +535,13 @@ def comp(i: int) -> np.ndarray:
fp8_config.get("exponent_bits"),
fp8_config.get("exponent_bias"),
)
dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
dequant_weights.copy_(to_device(comps, use_cpu))

elif weight_ty == SparseType.FP16:
assert scale_shift is None
comps = dequant_weights.detach().half().cpu().numpy().view(np.uint8)
weights.copy_(torch.tensor(comps))
comps = dequant_weights.detach().half().view(torch.uint8)
weights.copy_(comps)
elif weight_ty == SparseType.FP32:
assert scale_shift is None
comps = dequant_weights.detach().float().cpu().numpy().view(np.uint8)
weights.copy_(torch.tensor(comps))
comps = dequant_weights.detach().float().view(torch.uint8)
weights.copy_(comps)
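
As a usage note, the same dtype-view trick covers the fused rowwise tail handling above; a minimal sketch, assuming each row ends in a float32 (scale, bias) pair as the INT8 path relies on (values here are hypothetical):

import torch

# Illustrative fused row: 16 uint8 payload bytes followed by a float32 scale
# and a float32 bias packed into the last 8 bytes.
tail = torch.tensor([0.05, -1.25], dtype=torch.float32).view(torch.uint8)
row = torch.cat([torch.arange(16, dtype=torch.uint8), tail]).unsqueeze(0)

payload = row[:, :-8]                          # quantized values, stays uint8
scale_shift = row[:, -8:].view(torch.float32)  # reinterpret the tail bytes
half_bytes = scale_shift.to(torch.float16).view(torch.uint8)

print(scale_shift)       # -> roughly tensor([[0.0500, -1.2500]])
print(half_bytes.shape)  # torch.Size([1, 4]); first 2 bytes scale, last 2 bias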
