diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 41aeaca7ab3..44283e48fa6 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -40,30 +40,83 @@
 # Dictionary to store registered configurations
-registered_configs = {}
-def register_config(framework_name="None", algo_name=None):
+class ConfigRegistry:
+    registered_configs = {}
+
+    @classmethod
+    def register_config_impl(cls, framework_name="None", algo_name=None, priority=0):
+        """Register config decorator.
+
+        Registers the configuration classes for different algorithms within specific frameworks.
+
+        Usage example:
+            @ConfigRegistry.register_config_impl(framework_name=FRAMEWORK_NAME, algo_name=ExampleAlgorithm, priority=100)
+            class ExampleAlgorithmConfig:
+                # Configuration details for the ExampleAlgorithm
+
+        Args:
+            framework_name: the framework name. Defaults to "None".
+            algo_name: the algorithm name. Defaults to None.
+            priority: the priority of the configuration. A larger number indicates a higher priority,
+                which will be tried first at the auto-tune stage. Defaults to 0.
+        """
+
+        def decorator(config_cls):
+            cls.registered_configs.setdefault(framework_name, {})
+            cls.registered_configs[framework_name][algo_name] = {"priority": priority, "cls": config_cls}
+            return config_cls
+
+        return decorator
+
+    @classmethod
+    def get_all_configs(cls) -> Dict[str, Dict[str, Dict[str, object]]]:
+        """Get all registered configurations."""
+        return cls.registered_configs
+
+    @classmethod
+    def get_sorted_configs(cls) -> Dict[str, OrderedDict[str, Dict[str, object]]]:
+        """Get registered configurations sorted by priority."""
+        sorted_configs = OrderedDict()
+        for framework_name, algos in sorted(cls.registered_configs.items()):
+            sorted_configs[framework_name] = OrderedDict(
+                sorted(algos.items(), key=lambda x: x[1]["priority"], reverse=True)
+            )
+        return sorted_configs
+
+    @classmethod
+    def get_cls_configs(cls) -> Dict[str, Dict[str, object]]:
+        """Get registered configurations without priority."""
+        cls_configs = {}
+        for framework_name, algos in cls.registered_configs.items():
+            cls_configs[framework_name] = {}
+            for algo_name, config_data in algos.items():
+                cls_configs[framework_name][algo_name] = config_data["cls"]
+        return cls_configs
+
+
+config_registry = ConfigRegistry()
+
+
+def register_config(framework_name="None", algo_name=None, priority=0):
     """Register config decorator.

     Registers the configuration classes for different algorithms within specific frameworks.

     Usage example:
-        @register_config(framework_name="PyTorch", algo_name="ExampleAlgorithm")
+        @register_config(framework_name=FRAMEWORK_NAME, algo_name=ExampleAlgorithm, priority=100)
         class ExampleAlgorithmConfig:
             # Configuration details for the ExampleAlgorithm

     Args:
         framework_name: the framework name. Defaults to "None".
         algo_name: the algorithm name. Defaults to None.
+        priority: the priority of the configuration. A larger number indicates a higher priority,
+            which will be tried first at the auto-tune stage. Defaults to 0.
""" - def decorator(config_cls): - registered_configs.setdefault(framework_name, {}) - registered_configs[framework_name][algo_name] = config_cls - return config_cls - - return decorator + return config_registry.register_config_impl(framework_name=framework_name, algo_name=algo_name, priority=priority) class BaseConfig(ABC): diff --git a/neural_compressor/tensorflow/quantization/config.py b/neural_compressor/tensorflow/quantization/config.py index 7577120c4df..72bb1a3bc72 100644 --- a/neural_compressor/tensorflow/quantization/config.py +++ b/neural_compressor/tensorflow/quantization/config.py @@ -22,7 +22,7 @@ import tensorflow as tf -from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs +from neural_compressor.common.base_config import BaseConfig, config_registry, register_config from neural_compressor.common.utility import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE, STATIC_QUANT FRAMEWORK_NAME = "keras" @@ -150,6 +150,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_all_registered_configs() -> Dict[str, BaseConfig]: """Get all registered configs for keras framework.""" + registered_configs = config_registry.get_cls_configs() return registered_configs.get(FRAMEWORK_NAME, {}) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 3a6cbaca6ac..11dc239d050 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -23,7 +23,7 @@ import torch -from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs +from neural_compressor.common.base_config import BaseConfig, config_registry, register_config from neural_compressor.common.utility import ( DEFAULT_WHITE_LIST, FP8_QUANT, @@ -31,6 +31,7 @@ OP_NAME_OR_MODULE_TYPE, RTN_WEIGHT_ONLY_QUANT, ) +from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger FRAMEWORK_NAME = "torch" @@ -59,7 +60,7 @@ class OperatorConfig(NamedTuple): ######################## RNT Config ############################### -@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT) +@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT, priority=PRIORITY_RTN) class RTNWeightQuantConfig(BaseConfig): """Config class for round-to-nearest weight-only quantization.""" @@ -185,7 +186,7 @@ def get_default_rtn_config() -> RTNWeightQuantConfig: ######################## GPTQ Config ############################### -@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ) +@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ) class GPTQConfig(BaseConfig): """Config class for GPTQ. 
@@ -403,4 +404,5 @@ def get_default_fp8_qconfig() -> FP8QConfig:

 ##################### Algo Configs End ###################################
 def get_all_registered_configs() -> Dict[str, BaseConfig]:
+    registered_configs = config_registry.get_all_configs()
     return registered_configs.get(FRAMEWORK_NAME, {})
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index 2a1f1ca599b..0157a42fd47 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -17,7 +17,7 @@

 import torch

-from neural_compressor.common.base_config import BaseConfig, ComposableConfig, registered_configs
+from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
 from neural_compressor.common.logger import Logger
 from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
 from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info
@@ -48,6 +48,7 @@ def quantize(
         The quantized model.
     """
     q_model = model if inplace else copy.deepcopy(model)
+    registered_configs = config_registry.get_cls_configs()
     if isinstance(quant_config, dict):
         quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[FRAMEWORK_NAME])
         logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.")
@@ -88,6 +89,7 @@ def quantize_dynamic(
         The quantized model.
     """
     q_model = model if inplace else copy.deepcopy(model)
+    registered_configs = config_registry.get_cls_configs()
     if isinstance(quant_config, dict):
         quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[FRAMEWORK_NAME])
         logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.")
diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py
index 65c499f8716..74990a5bf42 100644
--- a/neural_compressor/torch/utils/constants.py
+++ b/neural_compressor/torch/utils/constants.py
@@ -36,3 +36,7 @@
         "double_quant_group_size": 256,
     },
 }
+
+# Priorities for algorithms: a higher number indicates a higher priority.
+PRIORITY_RTN = 80
+PRIORITY_GPTQ = 90
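
For readers skimming the patch, the following is a minimal standalone sketch of the registry-plus-priority mechanism the new ConfigRegistry implements: register a config class under a framework and algorithm name with a priority, then recover either a priority-sorted view (what get_sorted_configs feeds the auto-tune stage) or a plain name-to-class map (what get_cls_configs hands to ComposableConfig.from_dict). Everything below is illustrative only: the module-level registry dict, DummyRTNConfig, DummyGPTQConfig, and the string names "rtn_weight_only_quant" and "gptq" are hypothetical stand-ins, not names from the patch.

from collections import OrderedDict

# Standalone mirror of ConfigRegistry's bookkeeping; not the library API.
registry = {}

def register_config(framework_name="None", algo_name=None, priority=0):
    """Store {"priority": ..., "cls": ...} under registry[framework][algo]."""
    def decorator(config_cls):
        registry.setdefault(framework_name, {})
        registry[framework_name][algo_name] = {"priority": priority, "cls": config_cls}
        return config_cls
    return decorator

@register_config(framework_name="torch", algo_name="rtn_weight_only_quant", priority=80)
class DummyRTNConfig:
    pass

@register_config(framework_name="torch", algo_name="gptq", priority=90)
class DummyGPTQConfig:
    pass

# Mirrors get_sorted_configs(): highest priority first, i.e. the order the
# auto-tune stage would try the algorithms -- GPTQ (90) ahead of RTN (80).
sorted_algos = OrderedDict(
    sorted(registry["torch"].items(), key=lambda x: x[1]["priority"], reverse=True)
)
print(list(sorted_algos))  # ['gptq', 'rtn_weight_only_quant']

# Mirrors get_cls_configs(): drop the priority and keep only the config
# classes, the shape quantize() passes to ComposableConfig.from_dict.
cls_configs = {algo: data["cls"] for algo, data in registry["torch"].items()}
assert cls_configs["gptq"] is DummyGPTQConfig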