diff --git a/neural_compressor/common/utils/constants.py b/neural_compressor/common/utils/constants.py
index 615c5f26ffd..c92f916e6f7 100644
--- a/neural_compressor/common/utils/constants.py
+++ b/neural_compressor/common/utils/constants.py
@@ -29,6 +29,7 @@
 COMPOSABLE_CONFIG = "composable_config"
 RTN = "rtn"
 STATIC_QUANT = "static_quant"
+SMOOTH_QUANT = "smooth_quant"
 GPTQ = "gptq"
 FP8_QUANT = "fp8_quant"
 
diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py
index 81f131ca114..5fe8d73cc8c 100644
--- a/neural_compressor/torch/__init__.py
+++ b/neural_compressor/torch/__init__.py
@@ -21,6 +21,10 @@
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
+    StaticQuantConfig,
+    get_default_static_config,
+    SmoothQuantConfig,
+    get_default_sq_config,
 )
 
 from neural_compressor.common.base_tuning import TuningConfig
diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py
index c78a7b0552e..b287c5ae2d6 100644
--- a/neural_compressor/torch/quantization/__init__.py
+++ b/neural_compressor/torch/quantization/__init__.py
@@ -18,4 +18,8 @@
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
+    StaticQuantConfig,
+    get_default_static_config,
+    SmoothQuantConfig,
+    get_default_sq_config,
 )
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index fef6bf97c19..86c5b2d458e 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -24,7 +24,15 @@
 import torch
 
 from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
-from neural_compressor.common.utils import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
+from neural_compressor.common.utils import (
+    DEFAULT_WHITE_LIST,
+    FP8_QUANT,
+    GPTQ,
+    OP_NAME_OR_MODULE_TYPE,
+    RTN,
+    SMOOTH_QUANT,
+    STATIC_QUANT,
+)
 from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger
 
@@ -282,6 +290,191 @@ def get_default_gptq_config() -> GPTQConfig:
     return GPTQConfig()
 
 
+######################## Static Quant Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
+class StaticQuantConfig(BaseConfig):
+    """Config class for static quantization."""
+
+    name = STATIC_QUANT
+    params_list = [
+        "w_dtype",
+        "w_sym",
+        "w_granularity",
+        "w_algo",
+        "act_dtype",
+        "act_sym",
+        "act_granularity",
+        "act_algo",
+    ]
+    supported_configs: List[OperatorConfig] = []
+
+    def __init__(
+        self,
+        w_dtype: str = "int8",
+        w_sym: bool = True,
+        w_granularity: str = "per_channel",
+        w_algo: str = "minmax",
+        act_dtype: str = "uint8",
+        act_sym: bool = False,
+        act_granularity: str = "per_tensor",
+        act_algo: str = "kl",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init Static Quant Configs."""
+        super().__init__(white_list=white_list)
+        self.w_dtype = w_dtype
+        self.w_sym = w_sym
+        self.w_granularity = w_granularity
+        self.w_algo = w_algo
+        self.act_dtype = act_dtype
+        self.act_sym = act_sym
+        self.act_granularity = act_granularity
+        self.act_algo = act_algo
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        # TODO(Yi)
+        linear_static_config = StaticQuantConfig()
+        operators = [torch.nn.Linear]
+        supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
+        white_list = (torch.nn.Linear,)
+        filter_result = []
+        for op_name, module in model.named_modules():
+            if isinstance(module, white_list):
+                pair = (op_name, type(module).__name__)
+                filter_result.append(pair)
+        logger.debug(f"Get model info: {filter_result}")
+        return filter_result
+
+
+# TODO(Yi) run `register_supported_configs` for all registered config.
+StaticQuantConfig.register_supported_configs()
+
+
+def get_default_static_config() -> StaticQuantConfig:
+    """Generate the default static quant config.
+
+    Returns:
+        the default static quant config.
+    """
+    return StaticQuantConfig()
+
+
+######################## Smooth Quant Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
+class SmoothQuantConfig(BaseConfig):
+    """Config class for smooth quantization."""
+
+    name = SMOOTH_QUANT
+    params_list = [
+        "w_dtype",
+        "w_sym",
+        "w_granularity",
+        "w_algo",
+        "act_dtype",
+        "act_sym",
+        "act_granularity",
+        "act_algo",
+        "alpha",
+        "folding",
+        "scale_sharing",
+        "auto_alpha_args",
+    ]
+    supported_configs: List[OperatorConfig] = []
+
+    def __init__(
+        self,
+        w_dtype: str = "int8",
+        w_sym: bool = True,
+        w_granularity: str = "per_channel",
+        w_algo: str = "minmax",
+        act_dtype: str = "uint8",
+        act_sym: bool = False,
+        act_granularity: str = "per_tensor",
+        act_algo: str = "kl",
+        alpha: float = 0.5,
+        folding: bool = False,
+        # below for autotune
+        scale_sharing: bool = False,
+        init_alpha: float = 0.5,
+        alpha_min: float = 0.0,
+        alpha_max: float = 1.0,
+        alpha_step: float = 0.1,
+        shared_criterion: str = "max",
+        enable_blockwise_loss: bool = False,
+        auto_alpha_args: dict = None,
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init SmoothQuant Configs."""
+        super().__init__(white_list=white_list)
+        self.w_dtype = w_dtype
+        self.w_sym = w_sym
+        self.w_granularity = w_granularity
+        self.w_algo = w_algo
+        self.act_dtype = act_dtype
+        self.act_sym = act_sym
+        self.act_granularity = act_granularity
+        self.act_algo = act_algo
+        self.alpha = alpha
+        self.folding = folding
+        # below for autotune
+        self.scale_sharing = scale_sharing
+        self.init_alpha = init_alpha
+        self.alpha_min = alpha_min
+        self.alpha_max = alpha_max
+        self.alpha_step = alpha_step
+        self.shared_criterion = shared_criterion
+        self.enable_blockwise_loss = enable_blockwise_loss
+        self.auto_alpha_args = {
+            "init_alpha": self.init_alpha,
+            "alpha_min": self.alpha_min,
+            "alpha_max": self.alpha_max,
+            "alpha_step": self.alpha_step,
+            "shared_criterion": self.shared_criterion,
+            "enable_blockwise_loss": self.enable_blockwise_loss,
+        }
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        # TODO(Yi)
+        linear_sq_config = SmoothQuantConfig()
+        operators = [torch.nn.Linear]
+        supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
+        white_list = (torch.nn.Linear,)
+        filter_result = []
+        for op_name, module in model.named_modules():
+            if isinstance(module, white_list):
+                pair = (op_name, type(module).__name__)
+                filter_result.append(pair)
+        logger.debug(f"Get model info: {filter_result}")
+        return filter_result
+
+
+# TODO(Yi) run `register_supported_configs` for all registered config.
+SmoothQuantConfig.register_supported_configs()
+
+
+def get_default_sq_config() -> SmoothQuantConfig:
+    """Generate the default smoothquant config.
+
+    Returns:
+        the default smoothquant config.
+    """
+    return SmoothQuantConfig()
+
+
 ######################## FP8 Config ###############################
 if is_hpex_avaliable():
 
diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py
index 0e74fa56300..ff1012dc1ed 100644
--- a/test/3x/torch/test_config.py
+++ b/test/3x/torch/test_config.py
@@ -321,6 +321,22 @@ def test_gptq_config(self):
         gptq_config2 = GPTQConfig.from_dict(quant_config_dict["gptq"])
         self.assertEqual(gptq_config1.to_dict(), gptq_config2.to_dict())
 
+    def test_static_quant_config(self):
+        from neural_compressor.torch.quantization import StaticQuantConfig
+
+        static_config1 = StaticQuantConfig(w_dtype="int8", act_sym=True, act_algo="minmax")
+        quant_config_dict = {"static": {"w_dtype": "int8", "act_sym": True, "act_algo": "minmax"}}
+        static_config2 = StaticQuantConfig.from_dict(quant_config_dict["static"])
+        self.assertEqual(static_config1.to_dict(), static_config2.to_dict())
+
+    def test_smooth_quant_config(self):
+        from neural_compressor.torch.quantization import SmoothQuantConfig
+
+        sq_config1 = SmoothQuantConfig(alpha=0.8, folding=True)
+        quant_config_dict = {"sq": {"alpha": 0.8, "folding": True}}
+        sq_config2 = SmoothQuantConfig.from_dict(quant_config_dict["sq"])
+        self.assertEqual(sq_config1.to_dict(), sq_config2.to_dict())
+
 
 class TestQuantConfigForAutotune(unittest.TestCase):
     def test_expand_config(self):
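
For reference (not part of the patch): a usage sketch of the new config classes, using only the constructors, the `from_dict`/`to_dict` round-trip, and the default helpers that the added tests exercise; unlisted parameters keep the defaults shown in `config.py` above.

# Usage sketch; mirrors the added tests.
from neural_compressor.torch.quantization import (
    SmoothQuantConfig,
    StaticQuantConfig,
    get_default_sq_config,
    get_default_static_config,
)

# Default configs via the new helpers.
static_config = get_default_static_config()  # equivalent to StaticQuantConfig()
sq_config = get_default_sq_config()          # equivalent to SmoothQuantConfig()

# from_dict accepts a partial dict; unspecified params keep their defaults.
static_config1 = StaticQuantConfig(w_dtype="int8", act_sym=True, act_algo="minmax")
static_config2 = StaticQuantConfig.from_dict({"w_dtype": "int8", "act_sym": True, "act_algo": "minmax"})
assert static_config1.to_dict() == static_config2.to_dict()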
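
The alpha-autotune knobs are folded into the `auto_alpha_args` dict inside `SmoothQuantConfig.__init__` (see the patch above); a sketch with illustrative values:

# Sketch of how the alpha-tuning knobs are assembled into auto_alpha_args;
# the knob values here are illustrative, the remaining entries are defaults.
from neural_compressor.torch.quantization import SmoothQuantConfig

cfg = SmoothQuantConfig(alpha_min=0.3, alpha_max=0.7, alpha_step=0.05)
assert cfg.auto_alpha_args == {
    "init_alpha": 0.5,
    "alpha_min": 0.3,
    "alpha_max": 0.7,
    "alpha_step": 0.05,
    "shared_criterion": "max",
    "enable_blockwise_loss": False,
}

Note that, as written, the `auto_alpha_args` constructor argument is overwritten by the dict assembled from the individual knobs, so passing it directly appears to have no effect.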