Add Config Registry for autotune (#1543)
* add config registry for common & torch

Signed-off-by: Kaihui-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update torch quant algo priority

Signed-off-by: Kaihui-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add config registry for tf

Signed-off-by: Kaihui-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add get_cls_configs

Signed-off-by: Kaihui-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add typing check

Signed-off-by: Kaihui-intel <[email protected]>

---------

Signed-off-by: Kaihui-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Kaihui-intel and pre-commit-ci[bot] authored Jan 17, 2024
1 parent e951e7a commit 09eb5dd
Showing 5 changed files with 76 additions and 14 deletions.
71 changes: 62 additions & 9 deletions neural_compressor/common/base_config.py
@@ -40,30 +40,83 @@
 
 
 # Dictionary to store registered configurations
-registered_configs = {}
 
 
-def register_config(framework_name="None", algo_name=None):
+class ConfigRegistry:
+    registered_configs = {}
+
+    @classmethod
+    def register_config_impl(cls, framework_name="None", algo_name=None, priority=0):
+        """Register config decorator.
+        Registers the configuration classes for different algorithms within specific frameworks.
+        Usage example:
+            @ConfigRegistry.register_config_impl(framework_name=FRAMEWORK_NAME, algo_name=ExampleAlgorithm, priority=100)
+            class ExampleAlgorithmConfig:
+                # Configuration details for the ExampleAlgorithm
+        Args:
+            framework_name: the framework name. Defaults to "None".
+            algo_name: the algorithm name. Defaults to None.
+            priority: the priority of the configuration. A larger number indicates a higher priority,
+                which will be tried first at the auto-tune stage. Defaults to 0.
+        """
+
+        def decorator(config_cls):
+            cls.registered_configs.setdefault(framework_name, {})
+            cls.registered_configs[framework_name][algo_name] = {"priority": priority, "cls": config_cls}
+            return config_cls
+
+        return decorator
+
+    @classmethod
+    def get_all_configs(cls) -> Dict[str, Dict[str, Dict[str, object]]]:
+        """Get all registered configurations."""
+        return cls.registered_configs
+
+    @classmethod
+    def get_sorted_configs(cls) -> Dict[str, OrderedDict[str, Dict[str, object]]]:
+        """Get registered configurations sorted by priority."""
+        sorted_configs = OrderedDict()
+        for framework_name, algos in sorted(cls.registered_configs.items()):
+            sorted_configs[framework_name] = OrderedDict(
+                sorted(algos.items(), key=lambda x: x[1]["priority"], reverse=True)
+            )
+        return sorted_configs
+
+    @classmethod
+    def get_cls_configs(cls) -> Dict[str, Dict[str, object]]:
+        """Get registered configurations without priority."""
+        cls_configs = {}
+        for framework_name, algos in cls.registered_configs.items():
+            cls_configs[framework_name] = {}
+            for algo_name, config_data in algos.items():
+                cls_configs[framework_name][algo_name] = config_data["cls"]
+        return cls_configs
+
+
+config_registry = ConfigRegistry()
+
+
+def register_config(framework_name="None", algo_name=None, priority=0):
     """Register config decorator.
     Registers the configuration classes for different algorithms within specific frameworks.
     Usage example:
-        @register_config(framework_name="PyTorch", algo_name="ExampleAlgorithm")
+        @register_config(framework_name=FRAMEWORK_NAME, algo_name=ExampleAlgorithm, priority=100)
         class ExampleAlgorithmConfig:
            # Configuration details for the ExampleAlgorithm
     Args:
         framework_name: the framework name. Defaults to "None".
         algo_name: the algorithm name. Defaults to None.
+        priority: the priority of the configuration. A larger number indicates a higher priority,
+            which will be tried first at the auto-tune stage. Defaults to 0.
     """
 
-    def decorator(config_cls):
-        registered_configs.setdefault(framework_name, {})
-        registered_configs[framework_name][algo_name] = config_cls
-        return config_cls
-
-    return decorator
+    return config_registry.register_config_impl(framework_name=framework_name, algo_name=algo_name, priority=priority)
 
 
 class BaseConfig(ABC):
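To see the registry end to end, here is a minimal sketch; it assumes a build that includes this commit, and the "demo" framework name and both config classes are hypothetical:

from neural_compressor.common.base_config import config_registry, register_config


@register_config(framework_name="demo", algo_name="algo_a", priority=10)
class AlgoAConfig:
    pass


@register_config(framework_name="demo", algo_name="algo_b", priority=90)
class AlgoBConfig:
    pass


# get_all_configs keeps the {"priority": ..., "cls": ...} records as registered.
print(config_registry.get_all_configs()["demo"]["algo_a"])
# -> {'priority': 10, 'cls': <class '__main__.AlgoAConfig'>}

# get_sorted_configs orders algorithms by priority, highest first; this is the
# order the auto-tune stage is meant to try them in.
print(list(config_registry.get_sorted_configs()["demo"]))
# -> ['algo_b', 'algo_a']

# get_cls_configs drops the priorities and maps algorithm name to config class.
print(config_registry.get_cls_configs()["demo"]["algo_b"] is AlgoBConfig)
# -> True

Keeping the priority next to the class lets get_sorted_configs drive the auto-tune ordering, while get_cls_configs preserves the plain name-to-class shape that callers such as ComposableConfig.from_dict expect.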
3 changes: 2 additions & 1 deletion neural_compressor/tensorflow/quantization/config.py
@@ -22,7 +22,7 @@
 
 import tensorflow as tf
 
-from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs
+from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
 from neural_compressor.common.utility import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE, STATIC_QUANT
 
 FRAMEWORK_NAME = "keras"
@@ -150,6 +150,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 
 def get_all_registered_configs() -> Dict[str, BaseConfig]:
     """Get all registered configs for keras framework."""
+    registered_configs = config_registry.get_cls_configs()
     return registered_configs.get(FRAMEWORK_NAME, {})
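With this change the keras helper resolves classes through the shared registry rather than the old module-level dict. A hedged usage sketch, assuming TensorFlow and a build that includes this commit:

from neural_compressor.tensorflow.quantization.config import get_all_registered_configs

# Maps each algorithm name registered under FRAMEWORK_NAME = "keras" to its
# config class; priorities are already stripped by get_cls_configs.
keras_configs = get_all_registered_configs()
for algo_name, config_cls in keras_configs.items():
    print(algo_name, config_cls.__name__)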
8 changes: 5 additions & 3 deletions neural_compressor/torch/quantization/config.py
@@ -23,14 +23,15 @@
 
 import torch
 
-from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs
+from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
 from neural_compressor.common.utility import (
     DEFAULT_WHITE_LIST,
     FP8_QUANT,
     GPTQ,
     OP_NAME_OR_MODULE_TYPE,
     RTN_WEIGHT_ONLY_QUANT,
 )
+from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger
 
 FRAMEWORK_NAME = "torch"
@@ -59,7 +60,7 @@ class OperatorConfig(NamedTuple):
 ######################## RTN Config ###############################
 
 
-@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT)
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT, priority=PRIORITY_RTN)
 class RTNWeightQuantConfig(BaseConfig):
     """Config class for round-to-nearest weight-only quantization."""
 
@@ -185,7 +186,7 @@ def get_default_rtn_config() -> RTNWeightQuantConfig:
 
 
 ######################## GPTQ Config ###############################
-@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ)
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
 class GPTQConfig(BaseConfig):
     """Config class for GPTQ.
@@ -403,4 +404,5 @@ def get_default_fp8_qconfig() -> FP8QConfig:
 ##################### Algo Configs End ###################################
 
 def get_all_registered_configs() -> Dict[str, BaseConfig]:
+    registered_configs = config_registry.get_all_configs()
     return registered_configs.get(FRAMEWORK_NAME, {})
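Since PRIORITY_GPTQ (90) is larger than PRIORITY_RTN (80), GPTQ now precedes RTN in the auto-tune order. A sketch, assuming a build that includes this commit; importing the config module is what triggers the @register_config decorators:

import neural_compressor.torch.quantization.config  # noqa: F401  # registers RTN, GPTQ, ...

from neural_compressor.common.base_config import config_registry

# Sorted by priority, highest first: GPTQ is tried before RTN at auto-tune.
print(list(config_registry.get_sorted_configs()["torch"]))

Note that, unlike the keras helper, the torch get_all_registered_configs returns the full {"priority": ..., "cls": ...} records, since it calls get_all_configs rather than get_cls_configs.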
4 changes: 3 additions & 1 deletion neural_compressor/torch/quantization/quantize.py
@@ -17,7 +17,7 @@
 
 import torch
 
-from neural_compressor.common.base_config import BaseConfig, ComposableConfig, registered_configs
+from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
 from neural_compressor.common.logger import Logger
 from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
 from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info
@@ -48,6 +48,7 @@ def quantize(
         The quantized model.
     """
     q_model = model if inplace else copy.deepcopy(model)
+    registered_configs = config_registry.get_cls_configs()
     if isinstance(quant_config, dict):
         quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[FRAMEWORK_NAME])
         logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.")
@@ -88,6 +89,7 @@ def quantize_dynamic(
         The quantized model.
     """
     q_model = model if inplace else copy.deepcopy(model)
+    registered_configs = config_registry.get_cls_configs()
     if isinstance(quant_config, dict):
         quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[FRAMEWORK_NAME])
         logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.")
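Both entry points now resolve the name-to-class mapping at call time instead of importing a module-level dict, so configs registered after neural_compressor.common is imported are still picked up. A hedged sketch of calling quantize; the exact signature beyond model and quant_config is not shown in this diff, and the toy model is a placeholder:

import torch

from neural_compressor.torch.quantization.config import get_default_rtn_config
from neural_compressor.torch.quantization.quantize import quantize

model = torch.nn.Linear(8, 8)  # toy stand-in for a real network

# Config-object path: isinstance(quant_config, dict) is False, so the dict
# parsing branch is skipped entirely.
q_model = quantize(model, quant_config=get_default_rtn_config())

# Dict path: quantize() fetches config_registry.get_cls_configs()[FRAMEWORK_NAME]
# and hands it to ComposableConfig.from_dict to rebuild the config object; the
# dict keys are the registered algorithm names.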
4 changes: 4 additions & 0 deletions neural_compressor/torch/utils/constants.py
@@ -36,3 +36,7 @@
         "double_quant_group_size": 256,
     },
 }
+
+# Priorities for algorithms; a higher number indicates a higher priority.
+PRIORITY_RTN = 80
+PRIORITY_GPTQ = 90
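For intuition, the ordering these constants encode can be reproduced without neural_compressor at all; the algorithm-name strings below are illustrative placeholders:

from collections import OrderedDict

PRIORITY_RTN = 80
PRIORITY_GPTQ = 90

# Mimic the registry records for one framework: {algo_name: {"priority", "cls"}}.
records = {
    "rtn_weight_only_quant": {"priority": PRIORITY_RTN, "cls": object},
    "gptq": {"priority": PRIORITY_GPTQ, "cls": object},
}

# Same sort as ConfigRegistry.get_sorted_configs: larger priority first.
ordered = OrderedDict(sorted(records.items(), key=lambda x: x[1]["priority"], reverse=True))
print(list(ordered))  # ['gptq', 'rtn_weight_only_quant'] -> GPTQ tried first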
