[SW-186675] Update default configuration of 'allowlist'
Defined the default allowlist types to be empty, which allows quantization of all models.
Refactored the parse function into more dynamic, consistent code.

Change-Id: I6c8a14cb7ca6830927e5c5b7476e4b03335456aa
Tiefen-boop authored and Eran Geva committed Aug 4, 2024
1 parent 3f1d5c0 commit 55e1387
Showing 1 changed file with 21 additions and 66 deletions.
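
The heart of the change is a single string-to-enum helper replacing five if/elif ladders in parse. Below is a minimal, self-contained sketch of that pattern: the helper body is taken from the diff, while QuantMode's member values (0-3) are assumed, since only the member names appear in the replaced code.

from enum import Enum

def _get_enum_from_string(EnumClass, str, key):
    # Enum members are class attributes, so hasattr() checks membership by name.
    if not hasattr(EnumClass, str.upper()):
        raise ValueError(
            f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}")
    # Enum[name] looks a member up by its (upper-cased) name.
    return EnumClass[str.upper()]

class QuantMode(Enum):  # member names from the old parse code; values assumed
    NONE = 0
    QUANTIZE = 1
    MEASURE = 2
    SHAPE = 3

print(_get_enum_from_string(QuantMode, "quantize", "mode"))  # QuantMode.QUANTIZE
try:
    _get_enum_from_string(QuantMode, "bogus", "mode")
except ValueError as e:
    print(e)  # Invalid 'mode' value in custom config ('bogus'). Enter one of ['NONE', 'QUANTIZE', 'MEASURE', 'SHAPE']
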
@@ -44,6 +44,14 @@ class MeasureExclude(Flag):
     PARAMS = auto()
     ALL = auto()
 
+class SupportedFp8(Enum):
+    E4M3 = torch.float8_e4m3fn
+    E5M2 = torch.float8_e5m2
+
+class HpDtype(Enum):
+    BF16 = torch.bfloat16
+    FP16 = torch.float16
+    FP32 = torch.float32
 
 class ScaleMethod(Enum):
     MAX = 1
@@ -69,6 +77,13 @@ def set_hqt_config(mod, config):
     mod.__hqt_config__ = config
 
 
+def _get_enum_from_string(EnumClass, str, key):
+    if not hasattr(EnumClass, str.upper()):
+        raise ValueError(
+            f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}")
+    return EnumClass[str.upper()]
+
+
 @dataclass
 class Fp8cfg:
     cfg: Mapping[str, Any]
@@ -84,7 +99,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         }, # types and names to not be quantized
         "allowlist": {
             "names": [],
-            "types": ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"),
+            "types": (),
         }, # types and names to be quantized. Allowlist by names is not yet implemented
         "mode": QuantMode.QUANTIZE, # Quantize or Measure
         "scale_method": ScaleMethod.UNIT_SCALE, # Method to quantize with
@@ -104,79 +119,19 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
     # go over all user-defined keys from json, handle various cases
     for keys in custom_config:
         if keys == "mode":
-            if custom_config[keys] == "NONE":
-                custom_config[keys] = QuantMode.NONE
-            elif custom_config[keys] == "QUANTIZE":
-                custom_config[keys] = QuantMode.QUANTIZE
-            elif custom_config[keys] == "MEASURE":
-                custom_config[keys] = QuantMode.MEASURE
-            elif custom_config[keys] == "SHAPE":
-                custom_config[keys] = QuantMode.SHAPE
-            else:
-                raise ValueError("invalid mode in custom config. Enter Quantize or Measure")
+            custom_config[keys] = _get_enum_from_string(QuantMode, custom_config[keys], keys)
 
         if keys == "measure_exclude":
-            if custom_config[keys] == "NONE":
-                custom_config[keys] = MeasureExclude.NONE
-            elif custom_config[keys] == "OUTPUT":
-                custom_config[keys] = MeasureExclude.OUTPUT
-            elif custom_config[keys] == "INPUT":
-                custom_config[keys] = MeasureExclude.INPUT
-            elif custom_config[keys] == "ALL":
-                custom_config[keys] = MeasureExclude.ALL
-            else:
-                raise ValueError("invalid measure exclude value in custom config. Enter OUTPUT or NONE")
+            custom_config[keys] = _get_enum_from_string(MeasureExclude, custom_config[keys], keys)
 
         if keys == "fp8_config":
-            if custom_config[keys].lower() == "e4m3":
-                custom_config[keys] = torch.float8_e4m3fn
-
-            elif custom_config[keys].lower() == "e5m2":
-                custom_config[keys] = torch.float8_e5m2
-            else:
-                raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2")
+            custom_config[keys] = _get_enum_from_string(SupportedFp8, custom_config[keys], keys).value
 
         if keys == "hp_dtype":
-            if custom_config[keys].lower() == "bf16":
-                custom_config[keys] = torch.bfloat16
-            elif custom_config[keys].lower() == "fp16":
-                custom_config[keys] = torch.float16
-            elif custom_config[keys].lower() == "fp32":
-                custom_config[keys] = torch.float32
-            else:
-                raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32")
+            custom_config[keys] = _get_enum_from_string(HpDtype, custom_config[keys], keys).value
 
         if keys == "scale_method":
-            if custom_config[keys].lower() == "unit_scale":
-                custom_config[keys] = ScaleMethod.UNIT_SCALE
-            elif custom_config[keys].lower() == "max":
-                custom_config[keys] = ScaleMethod.MAX
-            elif custom_config[keys].lower() == "maxabs_hw":
-                custom_config[keys] = ScaleMethod.MAXABS_HW
-            elif custom_config[keys].lower() == "maxabs_pow2":
-                custom_config[keys] = ScaleMethod.MAXABS_POW2
-            elif custom_config[keys].lower() == "maxabs_hw_opt_weight":
-                custom_config[keys] = ScaleMethod.MAXABS_HW_OPT_WEIGHT
-            elif custom_config[keys].lower() == "maxabs_pow2_opt_weight":
-                custom_config[keys] = ScaleMethod.MAXABS_POW2_OPT_WEIGHT
-            elif custom_config[keys].lower() == "smoothquant_weights_output_channel_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2
-            elif custom_config[keys].lower() == "weaksmoothquant_weights_output_channel_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_opt_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2
-            elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_opt_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2
-            elif custom_config[keys].lower() == "smoothquant_opt":
-                custom_config[keys] = ScaleMethod.SMOOTHQUANT_OPT
-            else:
-                raise ValueError(
-                    f'Invalid fp8_config in custom config ({custom_config[keys]}). should be in ["max", "unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_per_channel_pow2", "smoothquant_opt"]'
-                )
+            custom_config[keys] = _get_enum_from_string(ScaleMethod, custom_config[keys], keys)
 
         if keys == "ignore_modules_wo_measures":
             custom_config[keys] = custom_config[keys].lower() == "true"
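
The other functional change is the allowlist default: "types" is now an empty tuple rather than ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"). Per the commit message, the empty default allows quantization of all models. A hedged sketch of that interpretation follows; _is_type_quantizable is a hypothetical name for illustration, not a function from this repository.

def _is_type_quantizable(type_name, allowlist_types):
    # Hypothetical filter: an empty allowlist imposes no type restriction,
    # so every module type is eligible for quantization; a non-empty
    # allowlist restricts quantization to the listed types.
    return len(allowlist_types) == 0 or type_name in allowlist_types

assert _is_type_quantizable("MyCustomLayer", ())                        # new default: all types eligible
assert not _is_type_quantizable("MyCustomLayer", ("torch.nn.Linear",))  # explicit allowlist still filters
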
