Refactor based on recent torch tensor view support #1816

Closed (wants to merge 1 commit)
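For context, the primitive this refactor leans on is torch.Tensor.view(dtype), which reinterprets a tensor's bytes in place instead of round-tripping through numpy. A minimal sketch of the before/after pattern (assuming PyTorch >= 1.10, where dtype views across different element sizes are supported; names here are illustrative, not from the PR):

import numpy as np
import torch

w = torch.randn(4, 8)  # float32 weights

# Old pattern: round-trip through numpy (DtoH/HtoD copies for GPU tensors).
old = torch.tensor(w.cpu().numpy().view(np.uint8))

# New pattern: zero-copy reinterpretation on the same device.
new = w.view(torch.uint8)  # shape (4, 32), shares storage with w

assert torch.equal(old, new)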
fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py (5 additions, 54 deletions)

@@ -9,12 +9,14 @@

 import logging
 import math
-from typing import Optional, Tuple
+from typing import cast, Optional, Tuple

 import numpy as np
 import torch

 from fbgemm_gpu.split_embedding_configs import QuantizationConfig, SparseType
+
+from fbgemm_gpu.split_embedding_utils import FP8QuantizationConfig, quantize_embs
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
@@ -93,59 +95,8 @@ def _get_quantization_config(self, name):
     def _quantize_embs(
         self, weight: Tensor, weight_ty: SparseType
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        if weight_ty == SparseType.FP32:
-            q_weight = weight.float()
-            # FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
-            # Here it uses numpy and it will introduce DtoH/HtoD overhead.
-            res_weight = torch.tensor(
-                q_weight.cpu().numpy().view(np.uint8)
-            ).contiguous()
-            return (res_weight, None)
-
-        elif weight_ty == SparseType.FP16:
-            q_weight = weight.half()
-            res_weight = torch.tensor(
-                q_weight.cpu().numpy().view(np.uint8)
-            ).contiguous()
-            return (res_weight, None)
-
-        elif weight_ty == SparseType.FP8:
-            # Output tensor is already in uint8
-            q_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
-                weight.float(),
-                self._get_quantization_config("exponent_bits"),
-                self._get_quantization_config("exponent_bias"),
-                self._get_quantization_config("max_position"),
-            )
-            return (q_weight, None)
-
-        elif weight_ty == SparseType.INT8:
-            q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
-            res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
-            res_scale_shift = torch.tensor(
-                q_weight[:, -8:]
-                .contiguous()
-                .cpu()
-                .numpy()
-                .view(np.float32)
-                .astype(np.float16)
-                .view(np.uint8)
-            )  # [-4, -2]: scale; [-2:]: bias
-            return (res_weight, res_scale_shift)
-
-        elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
-            q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
-                weight,
-                bit_rate=weight_ty.bit_rate(),
-            )
-            res_weight = torch.tensor(q_weight[:, :-4].cpu().numpy().view(np.uint8))
-            res_scale_shift = torch.tensor(
-                q_weight[:, -4:].contiguous().cpu().numpy().view(np.uint8)
-            )  # [-4, -2]: scale; [-2:]: bias
-            return (res_weight, res_scale_shift)
-
-        else:
-            raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))
+        fp8_quant_config = cast(FP8QuantizationConfig, self.quantization_config)
+        return quantize_embs(weight, weight_ty, fp8_quant_config)

     def _process_split_embs(self, model: nn.Module) -> None:
         for name, child in model.named_children():
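Note on the converter change: rather than keeping a private copy of each quantization path, _quantize_embs now delegates to the shared quantize_embs helper, using typing.cast to narrow the Optional quantization config for the type checker.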
fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py (61 additions, 82 deletions)

@@ -320,20 +320,15 @@ def quantize_embs(
     weight_ty: SparseType,
     fp8_config: Optional[FP8QuantizationConfig] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-    assert (
-        weight.dtype == torch.float
-    ), "The input tensor for quantize_embs function needs to be float tensor"
     weight = weight.detach()
     if weight_ty == SparseType.FP32:
         q_weight = weight.float()
-        # FIXME: How to view the PyTorch Tensor as a different type (e.g., uint8)
-        # Here it uses numpy and it will introduce DtoH/HtoD overhead.
-        res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
+        res_weight = q_weight.view(torch.uint8)
         return (res_weight, None)

     elif weight_ty == SparseType.FP16:
         q_weight = weight.half()
-        res_weight = torch.tensor(q_weight.cpu().numpy().view(np.uint8)).contiguous()
+        res_weight = q_weight.view(torch.uint8)
         return (res_weight, None)

     elif weight_ty == SparseType.FP8:
@@ -352,15 +347,9 @@
         # for alignment if embedding dimension is not a multiple of 4:
         # https://fburl.com/code/z009xsy6
         q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
-        res_weight = torch.tensor(q_weight[:, :-8].cpu().numpy().view(np.uint8))
+        res_weight = q_weight[:, :-8].view(torch.uint8)
         res_scale_shift = torch.tensor(
-            q_weight[:, -8:]
-            .contiguous()
-            .cpu()
-            .numpy()
-            .view(np.float32)
-            .astype(np.float16)
-            .view(np.uint8)
+            q_weight[:, -8:].view(torch.float32).to(torch.float16).view(torch.uint8)
         )  # [-4, -2]: scale; [-2:]: bias
         return (res_weight, res_scale_shift)
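As a reading aid for the slicing above (inferred from this code, not from a layout spec): each fused 8-bit row stores the uint8 codes followed by an fp32 scale and fp32 bias in the last 8 bytes, which the new code re-views and narrows to fp16 without leaving the device. A self-contained sketch decoding a hand-built row of that shape:

import torch

scale, bias = 0.5, 1.0
tail = torch.tensor([scale, bias], dtype=torch.float32).view(torch.uint8)
row = torch.cat([torch.tensor([0, 1, 2, 3], dtype=torch.uint8), tail]).unsqueeze(0)

data = row[:, :-8]                    # quantized payload
sb = row[:, -8:].view(torch.float32)  # [:, 0] is scale, [:, 1] is bias
deq = data.to(torch.float32) * sb[:, :1] + sb[:, 1:]
# deq -> tensor([[1.0000, 1.5000, 2.0000, 2.5000]])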

@@ -371,9 +360,9 @@
             weight,
             bit_rate=weight_ty.bit_rate(),
         )
-        res_weight = torch.tensor(q_weight[:, :-4].cpu().numpy().view(np.uint8))
+        res_weight = q_weight[:, :-4].view(torch.uint8)
         res_scale_shift = torch.tensor(
-            q_weight[:, -4:].contiguous().cpu().numpy().view(np.uint8)
+            q_weight[:, -4:].view(torch.uint8)
         )  # [-4, -2]: scale; [-2:]: bias
         return (res_weight, res_scale_shift)
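The N-bit SBHalf variant stores its scale and bias as fp16, so the tail is 4 bytes rather than 8, which is why these slices use -4. A hedged sketch of that layout with a hand-built row (an assumption spelled out from the code; the real op may also pad rows for alignment):

import torch

tail = torch.tensor([0.25, 2.0], dtype=torch.float16).view(torch.uint8)  # fp16 scale, bias
row = torch.cat([torch.tensor([0x21, 0x43], dtype=torch.uint8), tail]).unsqueeze(0)

payload = row[:, :-4]                 # packed 4-bit codes
sb = row[:, -4:].view(torch.float16)  # [:, 0] is scale, [:, 1] is bias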

@@ -392,60 +381,54 @@ def dequantize_embs(
     assert (
         weights.dtype == torch.uint8
     ), "The input tensor for dequantize_embs function needs to be byte tensor"
-    np_weights = weights.contiguous().cpu().numpy()
+    th_weights = weights

     if scale_shift is not None:
-        np_scale_shift: np.ndarray = (
-            scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
-        )
+        th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)

     if weight_ty == SparseType.INT4:
-        (E, D_2) = np_weights.shape
+        (E, D_2) = th_weights.shape
         D = D_2 * 2

-        def comp(i: int) -> np.ndarray:
-            subs = np_weights.view(np.uint8) >> (i * 4)
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 4)
             sub_mask = subs & 0xF
-            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
                 -1, 1
-            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
-                np.float32
-            )
-            return result.astype(np.float32)
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)

         comps = [comp(i) for i in range(2)]
-        comps = np.stack(comps)
-        comps = comps.transpose(1, 2, 0)
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
         comps = comps.reshape(E, D)
         return to_device(torch.tensor(comps), use_cpu)

     elif weight_ty == SparseType.INT2:
-        (E, D_4) = np_weights.shape
+        (E, D_4) = th_weights.shape
         D = D_4 * 4

         # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
         # pyre-fixme[53]: Captured variable `weights` is not annotated.
-        def comp(i: int) -> np.ndarray:
-            subs = np_weights.view(np.uint8) >> (i * 2)
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 2)
             sub_mask = subs & 0x3
-            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
                 -1, 1
-            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
-                np.float32
-            )
-            return result.astype(np.float32)
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)

         comps = [comp(i) for i in range(4)]
-        comps = np.stack(comps)
-        comps = comps.transpose(1, 2, 0)
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
         comps = comps.reshape(E, D)
         return to_device(torch.tensor(comps), use_cpu)

     elif weight_ty == SparseType.INT8:
-        (E, D) = np_weights.shape
-        comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
-            -1, 1
-        ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
+        (E, D) = th_weights.shape
+        comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
+            torch.float32
+        ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
         return to_device(torch.tensor(comps), use_cpu)

     elif weight_ty == SparseType.FP8:
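For readers new to the comp helpers above: each stored byte packs multiple low-bit codes, and comp(i) extracts code i by shift-and-mask before stack/permute/reshape interleaves the results back into row order. A minimal sketch of the INT4 case (illustrative values only):

import torch

packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)  # nibbles 1, 2, 3, 4
nibbles = [(packed >> (4 * i)) & 0xF for i in range(2)]   # low nibbles, high nibbles
out = torch.stack(nibbles).permute(1, 2, 0).reshape(1, 4)
# out -> tensor([[1, 2, 3, 4]], dtype=torch.uint8)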
@@ -461,12 +444,12 @@ def comp(i: int) -> np.ndarray:

     elif weight_ty == SparseType.FP16:
         assert scale_shift is None
-        comps = np_weights.view(np.half)
+        comps = th_weights.view(torch.half)
         return to_device(torch.tensor(comps), use_cpu)

     elif weight_ty == SparseType.FP32:
         assert scale_shift is None
-        comps = np_weights.view(np.float32)
+        comps = th_weights.view(torch.float32)
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
         return to_device(torch.tensor(comps), use_cpu)

@@ -482,61 +465,57 @@ def fake_quantize_embs(
     assert (
         weights.dtype == torch.uint8
     ), "The input tensor for dequantize_embs function needs to be byte tensor"
-    np_weights = weights.contiguous().cpu().numpy()
+    th_weights = weights

     if scale_shift is not None:
-        np_scale_shift: np.ndarray = (
-            scale_shift.cpu().contiguous().numpy().view(np.float16).astype(np.float32)
+        th_scale_shift: torch.Tensor = (
+            scale_shift.contiguous().view(torch.float16).to(torch.float32)
         )

     if weight_ty == SparseType.INT4:
-        (E, D_2) = np_weights.shape
+        (E, D_2) = th_weights.shape
         D = D_2 * 2

-        def comp(i: int) -> np.ndarray:
-            subs = np_weights.view(np.uint8) >> (i * 4)
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 4)
             sub_mask = subs & 0xF
-            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
                 -1, 1
-            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
-                np.float32
-            )
-            return result.astype(np.float32)
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)

         comps = [comp(i) for i in range(2)]
-        comps = np.stack(comps)
-        comps = comps.transpose(1, 2, 0)
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
         comps = comps.reshape(E, D)
-        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
+        dequant_weights.copy_(to_device(comps, use_cpu))

     elif weight_ty == SparseType.INT2:
-        (E, D_4) = np_weights.shape
+        (E, D_4) = th_weights.shape
         D = D_4 * 4

         # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
         # pyre-fixme[53]: Captured variable `weights` is not annotated.
-        def comp(i: int) -> np.ndarray:
-            subs = np_weights.view(np.uint8) >> (i * 2)
+        def comp(i: int) -> torch.Tensor:
+            subs = th_weights.view(torch.uint8) >> (i * 2)
             sub_mask = subs & 0x3
-            result = sub_mask.astype(np.float32) * np_scale_shift[:, 0].reshape(
+            result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
                 -1, 1
-            ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(
-                np.float32
-            )
-            return result.astype(np.float32)
+            ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+            return result.to(torch.float32)

         comps = [comp(i) for i in range(4)]
-        comps = np.stack(comps)
-        comps = comps.transpose(1, 2, 0)
+        comps = torch.stack(comps)
+        comps = comps.permute(1, 2, 0)
         comps = comps.reshape(E, D)
-        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
+        dequant_weights.copy_(to_device(comps, use_cpu))

     elif weight_ty == SparseType.INT8:
-        (E, D) = np_weights.shape
-        comps = np_weights.astype(np.float32) * np_scale_shift[:, 0].reshape(
-            -1, 1
-        ).astype(np.float32) + np_scale_shift[:, 1].reshape(-1, 1).astype(np.float32)
-        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
+        (E, D) = th_weights.shape
+        comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
+            torch.float32
+        ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+        dequant_weights.copy_(to_device(comps, use_cpu))

     elif weight_ty == SparseType.FP8:
         assert fp8_config is not None
@@ -556,13 +535,13 @@ def comp(i: int) -> np.ndarray:
             fp8_config.get("exponent_bits"),
             fp8_config.get("exponent_bias"),
         )
-        dequant_weights.copy_(to_device(torch.tensor(comps), use_cpu))
+        dequant_weights.copy_(to_device(comps, use_cpu))

     elif weight_ty == SparseType.FP16:
         assert scale_shift is None
-        comps = dequant_weights.detach().half().cpu().numpy().view(np.uint8)
-        weights.copy_(torch.tensor(comps))
+        comps = dequant_weights.detach().half().view(torch.uint8)
+        weights.copy_(comps)
     elif weight_ty == SparseType.FP32:
         assert scale_shift is None
-        comps = dequant_weights.detach().float().cpu().numpy().view(np.uint8)
-        weights.copy_(torch.tensor(comps))
+        comps = dequant_weights.detach().float().view(torch.uint8)
+        weights.copy_(comps)
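One behavioral detail worth noting in these FP16/FP32 branches: fake_quantize_embs writes the re-encoded bytes back into weights in place via copy_, so the uint8 buffer and the dequantized tensor must describe the same number of bytes. A small sketch of that round trip (hypothetical shapes, not from the PR):

import torch

buf = torch.empty(2, 8, dtype=torch.uint8)  # 2 rows x 4 fp16 values = 8 bytes/row
vals = torch.randn(2, 4)

buf.copy_(vals.half().view(torch.uint8))    # write quantized bytes in place
restored = buf.view(torch.float16).float()  # the fake-quantized values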