Commit 57ed613

support fp8 cast WOQ (#1746)
Signed-off-by: xin3he <[email protected]>
xin3he authored Apr 26, 2024
1 parent 522cfe3 commit 57ed613
Showing 4 changed files with 41 additions and 3 deletions.
8 changes: 7 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -24,7 +24,7 @@
 from neural_compressor.torch.utils import get_device, logger, set_module
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
-from .utility import quant_tensor, search_clip
+from .utility import cast_fp8, quant_tensor, search_clip
 
 
 @torch.no_grad()
@@ -100,6 +100,12 @@ def rtn_quantize(
         dtype = weight_config[name].get("dtype", "int")
         if dtype == "fp32":
             continue
+        ### FP8 cast part
+        if dtype in ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"]:
+            logger.debug("Cast module {} to FP8 using qdq mode, no scaling".format(name))
+            m.weight = cast_fp8(m.weight, dtype, use_qdq=True)
+            continue
+        ####
         logger.debug("Apply RTN on module %s.", name)
         bits = weight_config[name].get("bits", 4)
         group_size = weight_config[name]["group_size"]
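
For context, a minimal standalone sketch of what the new branch in rtn_quantize does: for the four fp8_* dtypes, regular RTN is skipped and the weight is cast to FP8 and back in place ("qdq mode, no scaling"). This sketch uses plain PyTorch only; the toy model and the chosen FP8 format are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

# Toy stand-in for the per-module loop in rtn_quantize (illustrative only).
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
fp8_dtype = torch.float8_e4m3fn  # requires a PyTorch build that ships FP8 dtypes

with torch.no_grad():
    for name, m in model.named_modules():
        if not isinstance(m, nn.Linear):
            continue
        orig_dtype = m.weight.dtype
        # Quantize-dequantize: values are snapped to the FP8 grid,
        # but the parameter keeps its original dtype.
        m.weight.copy_(m.weight.to(fp8_dtype).to(orig_dtype))
        print(name, m.weight.dtype)  # still the original float dtype
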
17 changes: 17 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -80,6 +80,12 @@
 
 FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1}
 INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT}
+FP8_MAPPING = {
+    "fp8_e5m2": torch.float8_e5m2,
+    "fp8_e5m2fnuz": torch.float8_e5m2fnuz,
+    "fp8_e4m3fn": torch.float8_e4m3fn,
+    "fp8_e4m3fnuz": torch.float8_e4m3fnuz,
+}
 
 
 def quantize_4bit(tensor, quantile=1.0, dtype="nf4", return_int=False, **kwargs):
@@ -121,6 +127,17 @@ def quantize_4bit(tensor, quantile=1.0, dtype="nf4", return_int=False, **kwargs)
     return tensor.mul_(scale)
 
 
+def cast_fp8(tensor, dtype="fp8_e4m3fn", use_qdq=True):
+    torch_dtype = FP8_MAPPING[dtype]
+    if not use_qdq:  # pragma: no cover
+        return tensor.to(torch_dtype)
+    else:
+        orig_dtype = tensor.dtype
+        fp8_tensor = tensor.to(torch_dtype)
+        tensor.copy_(fp8_tensor.to(orig_dtype))
+        return tensor
+
+
 def qdq_weight_asym(weight, bits=4, quantile=1.0, return_int=False, **kwargs):
     """Quant and dequant tensor with asym schema.
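
A quick demonstration of the QDQ semantics of the cast_fp8 helper added above: with use_qdq=True the tensor keeps its original dtype, but every element is rounded to the nearest value representable in the chosen FP8 format. The sample tensor and the chosen format are illustrative; the import path follows the file location shown in this diff.

import torch

from neural_compressor.torch.algorithms.weight_only.utility import cast_fp8

w = torch.randn(4, 4, dtype=torch.float32)
w_qdq = cast_fp8(w.clone(), dtype="fp8_e4m3fn", use_qdq=True)

assert w_qdq.dtype == torch.float32  # dtype is preserved (qdq, not a real dtype change)
# Every element already sits on the FP8 grid, so a second cast round-trip is a no-op.
assert torch.equal(w_qdq, w_qdq.to(torch.float8_e4m3fn).to(torch.float32))
print((w - w_qdq).abs().max())  # rounding error introduced by the FP8 cast
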
14 changes: 13 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -160,7 +160,19 @@ def __init__(
     def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs = []
         linear_rtn_config = RTNConfig(
-            dtype=["int", "int8", "int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"],
+            dtype=[
+                "int",
+                "int8",
+                "int4",
+                "nf4",
+                "fp4",
+                "fp4_e2m1_bnb",
+                "fp4_e2m1",
+                "fp8_e5m2",
+                "fp8_e5m2fnuz",
+                "fp8_e4m3fn",
+                "fp8_e4m3fnuz",
+            ],
             bits=[4, 1, 2, 3, 5, 6, 7, 8],
             use_sym=[True, False],
             group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024],
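
A hedged usage sketch of the new config values: only the fp8_* dtype strings come from this diff; the prepare/convert entry points and the toy model are assumptions about the surrounding 3.x API, not shown here.

import torch
import torch.nn as nn

# Assumed 3.x entry points (not part of this diff).
from neural_compressor.torch.quantization import RTNConfig, convert, prepare

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))

# Any of the four new values works here:
# "fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"
quant_config = RTNConfig(dtype="fp8_e4m3fn")

model = prepare(model, quant_config)  # assumed prepare/convert flow
model = convert(model)
out = model(torch.randn(2, 32))
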
5 changes: 4 additions & 1 deletion test/3x/torch/quantization/weight_only/test_rtn.py
@@ -155,7 +155,10 @@ def test_export_compressed_model(self, dtype):
             out1, out2
         ), "Exporting compressed model should have the same output as quantized model. Please double check"
 
-    @pytest.mark.parametrize("dtype", ["int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"])
+    @pytest.mark.parametrize(
+        "dtype",
+        ["int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"],
+    )
     def test_dtype_params(self, dtype):
         model = copy.deepcopy(self.tiny_gptj)
         quant_config = RTNConfig(
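
In the same spirit as the parametrized test above, a minimal standalone check that exercises cast_fp8 directly rather than the full RTN flow; it is not the repository's actual test body, and it assumes a PyTorch build where all four FP8 dtypes support CPU casts.

import pytest
import torch

from neural_compressor.torch.algorithms.weight_only.utility import FP8_MAPPING, cast_fp8


@pytest.mark.parametrize("dtype", ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"])
def test_cast_fp8_qdq_keeps_dtype_and_grid(dtype):
    w = torch.randn(8, 8)
    w_qdq = cast_fp8(w.clone(), dtype=dtype, use_qdq=True)
    assert w_qdq.dtype == torch.float32
    # Values should already lie on the grid of the chosen FP8 format.
    torch_dtype = FP8_MAPPING[dtype]
    assert torch.equal(w_qdq, w_qdq.to(torch_dtype).to(torch.float32))
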
