diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py index 1cf343e1a22..b64146a75db 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py @@ -44,6 +44,14 @@ class MeasureExclude(Flag): PARAMS = auto() ALL = auto() +class SupportedFp8(Enum): + E4M3 = torch.float8_e4m3fn + E5M2 = torch.float8_e5m2 + +class HpDtype(Enum): + BF16 = torch.bfloat16 + FP16 = torch.float16 + FP32 = torch.float32 class ScaleMethod(Enum): MAX = 1 @@ -69,6 +77,13 @@ def set_hqt_config(mod, config): mod.__hqt_config__ = config +def _get_enum_from_string(EnumClass, str, key): + if not hasattr(EnumClass, str.upper()): + raise ValueError( + f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}") + return EnumClass[str.upper()] + + @dataclass class Fp8cfg: cfg: Mapping[str, Any] @@ -84,7 +99,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg: }, # types and names to not be quantized "allowlist": { "names": [], - "types": ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"), + "types": (), }, # types and names to be quantized. Allowlist by names is not yet implemented "mode": QuantMode.QUANTIZE, # Quantize or Measure "scale_method": ScaleMethod.UNIT_SCALE, # Method to quantize with @@ -104,79 +119,19 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg: # go over all user-defined keys from json, handle various cases for keys in custom_config: if keys == "mode": - if custom_config[keys] == "NONE": - custom_config[keys] = QuantMode.NONE - elif custom_config[keys] == "QUANTIZE": - custom_config[keys] = QuantMode.QUANTIZE - elif custom_config[keys] == "MEASURE": - custom_config[keys] = QuantMode.MEASURE - elif custom_config[keys] == "SHAPE": - custom_config[keys] = QuantMode.SHAPE - else: - raise ValueError("invalid mode in custom config. Enter Quantize or Measure") + custom_config[keys] = _get_enum_from_string(QuantMode, custom_config[keys], keys) if keys == "measure_exclude": - if custom_config[keys] == "NONE": - custom_config[keys] = MeasureExclude.NONE - elif custom_config[keys] == "OUTPUT": - custom_config[keys] = MeasureExclude.OUTPUT - elif custom_config[keys] == "INPUT": - custom_config[keys] = MeasureExclude.INPUT - elif custom_config[keys] == "ALL": - custom_config[keys] = MeasureExclude.ALL - else: - raise ValueError("invalid measure exclude value in custom config. Enter OUTPUT or NONE") + custom_config[keys] = _get_enum_from_string(MeasureExclude, custom_config[keys], keys) if keys == "fp8_config": - if custom_config[keys].lower() == "e4m3": - custom_config[keys] = torch.float8_e4m3fn - - elif custom_config[keys].lower() == "e5m2": - custom_config[keys] = torch.float8_e5m2 - else: - raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2") + custom_config[keys] = _get_enum_from_string(SupportedFp8, custom_config[keys], keys).value if keys == "hp_dtype": - if custom_config[keys].lower() == "bf16": - custom_config[keys] = torch.bfloat16 - elif custom_config[keys].lower() == "fp16": - custom_config[keys] = torch.float16 - elif custom_config[keys].lower() == "fp32": - custom_config[keys] = torch.float32 - else: - raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32") + custom_config[keys] = _get_enum_from_string(HpDtype, custom_config[keys], keys).value if keys == "scale_method": - if custom_config[keys].lower() == "unit_scale": - custom_config[keys] = ScaleMethod.UNIT_SCALE - elif custom_config[keys].lower() == "max": - custom_config[keys] = ScaleMethod.MAX - elif custom_config[keys].lower() == "maxabs_hw": - custom_config[keys] = ScaleMethod.MAXABS_HW - elif custom_config[keys].lower() == "maxabs_pow2": - custom_config[keys] = ScaleMethod.MAXABS_POW2 - elif custom_config[keys].lower() == "maxabs_hw_opt_weight": - custom_config[keys] = ScaleMethod.MAXABS_HW_OPT_WEIGHT - elif custom_config[keys].lower() == "maxabs_pow2_opt_weight": - custom_config[keys] = ScaleMethod.MAXABS_POW2_OPT_WEIGHT - elif custom_config[keys].lower() == "smoothquant_weights_output_channel_maxabs_pow2": - custom_config[keys] = ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 - elif custom_config[keys].lower() == "weaksmoothquant_weights_output_channel_maxabs_pow2": - custom_config[keys] = ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 - elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_maxabs_pow2": - custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2 - elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_opt_pow2": - custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2 - elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_maxabs_pow2": - custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2 - elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_opt_pow2": - custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2 - elif custom_config[keys].lower() == "smoothquant_opt": - custom_config[keys] = ScaleMethod.SMOOTHQUANT_OPT - else: - raise ValueError( - f'Invalid fp8_config in custom config ({custom_config[keys]}). should be in ["max", "unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_per_channel_pow2", "smoothquant_opt"]' - ) + custom_config[keys] = _get_enum_from_string(ScaleMethod, custom_config[keys], keys) if keys == "ignore_modules_wo_measures": custom_config[keys] = custom_config[keys].lower() == "true"