add hasattr check for torch fp8 dtype (#1985)
Signed-off-by: xin3he <[email protected]>
(cherry picked from commit 4dd49a4)
xin3he authored and chensuyue committed Aug 30, 2024
1 parent 0e88b25 commit 65235e4
Showing 1 changed file with 12 additions and 6 deletions.
neural_compressor/torch/algorithms/weight_only/utility.py (+12 −6)
@@ -93,12 +93,18 @@
 
 FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1}
 INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT}
-FP8_MAPPING = {
-    "fp8_e5m2": torch.float8_e5m2,
-    "fp8_e5m2fnuz": torch.float8_e5m2fnuz,
-    "fp8_e4m3fn": torch.float8_e4m3fn,
-    "fp8_e4m3fnuz": torch.float8_e4m3fnuz,
-}
+if hasattr(torch, "float8_e5m2") and hasattr(torch, "float8_e4m3fn"):
+    FP8_MAPPING = {
+        "fp8_e5m2": torch.float8_e5m2,
+        "fp8_e4m3fn": torch.float8_e4m3fn,
+    }
+if hasattr(torch, "float8_e5m2fnuz") and hasattr(torch, "float8_e4m3fnuz"):
+    FP8_MAPPING = {
+        "fp8_e5m2": torch.float8_e5m2,
+        "fp8_e4m3fn": torch.float8_e4m3fn,
+        "fp8_e5m2fnuz": torch.float8_e5m2fnuz,
+        "fp8_e4m3fnuz": torch.float8_e4m3fnuz,
+    }
 
 
 def quantize_4bit(tensor, quantile=1.0, dtype="nf4", return_int=False, **kwargs):
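The guard exists because older PyTorch builds expose only some (or none) of the float8 dtypes, so an unconditional reference such as torch.float8_e5m2fnuz raises AttributeError at import time. Below is a minimal, self-contained sketch of the same pattern; the cast_to_fp8 helper is hypothetical (it is not part of neural_compressor) and only illustrates how code can consume the guarded mapping safely.

    import torch

    # Register only the fp8 dtypes this torch build actually provides.
    FP8_MAPPING = {}
    if hasattr(torch, "float8_e5m2") and hasattr(torch, "float8_e4m3fn"):
        FP8_MAPPING.update(
            {"fp8_e5m2": torch.float8_e5m2, "fp8_e4m3fn": torch.float8_e4m3fn}
        )
    if hasattr(torch, "float8_e5m2fnuz") and hasattr(torch, "float8_e4m3fnuz"):
        FP8_MAPPING.update(
            {"fp8_e5m2fnuz": torch.float8_e5m2fnuz, "fp8_e4m3fnuz": torch.float8_e4m3fnuz}
        )

    def cast_to_fp8(tensor, dtype_name="fp8_e4m3fn"):
        # Hypothetical helper: fail with an explicit message when the requested
        # fp8 dtype is missing, instead of crashing at module import.
        if dtype_name not in FP8_MAPPING:
            raise RuntimeError(f"{dtype_name} is unavailable in torch {torch.__version__}")
        return tensor.to(FP8_MAPPING[dtype_name])

Unlike the committed code, the sketch initializes FP8_MAPPING to an empty dict, so importing the module on a torch build without fp8 support succeeds and the failure is deferred, with a clear message, to the point where an fp8 dtype is actually requested.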
