Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt 3.x config for static quant and smooth quant #1568

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions neural_compressor/common/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
COMPOSABLE_CONFIG = "composable_config"
RTN = "rtn"
STATIC_QUANT = "static_quant"
SMOOTH_QUANT = "smooth_quant"
GPTQ = "gptq"
FP8_QUANT = "fp8_quant"

Expand Down
4 changes: 4 additions & 0 deletions neural_compressor/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
get_default_rtn_config,
GPTQConfig,
get_default_gptq_config,
StaticQuantConfig,
get_default_static_config,
SmoothQuantConfig,
get_default_sq_config,
)

from neural_compressor.common.base_tuning import TuningConfig
Expand Down
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@
get_default_rtn_config,
GPTQConfig,
get_default_gptq_config,
StaticQuantConfig,
get_default_static_config,
SmoothQuantConfig,
get_default_sq_config,
)
195 changes: 194 additions & 1 deletion neural_compressor/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
import torch

from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
from neural_compressor.common.utils import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
from neural_compressor.common.utils import (
DEFAULT_WHITE_LIST,
FP8_QUANT,
GPTQ,
OP_NAME_OR_MODULE_TYPE,
RTN,
SMOOTH_QUANT,
STATIC_QUANT,
)
from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

Expand Down Expand Up @@ -311,6 +319,191 @@ def get_default_gptq_config() -> GPTQConfig:
return GPTQConfig()


######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
class StaticQuantConfig(BaseConfig):
"""Config class for static quantization."""

name = STATIC_QUANT
params_list = [
"w_dtype",
"w_sym",
"w_granularity",
"w_algo",
"act_dtype",
"act_sym",
"act_granularity",
"act_algo",
]
supported_configs: List[OperatorConfig] = []

def __init__(
self,
w_dtype: str = "int8",
w_sym: bool = True,
w_granularity: str = "per_channel",
w_algo: str = "minmax",
act_dtype: str = "uint8",
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "minmax",
violetch24 marked this conversation as resolved.
Show resolved Hide resolved
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init Static Quant Configs."""
super().__init__(white_list=white_list)
self.w_dtype = w_dtype
self.w_sym = w_sym
self.w_granularity = w_granularity
self.w_algo = w_algo
self.act_dtype = act_dtype
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self._post_init()

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_static_config = StaticQuantConfig()
operators = [torch.nn.Linear, torch.nn.functional.linear]
violetch24 marked this conversation as resolved.
Show resolved Hide resolved
supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
return filter_result


# TODO(Yi) run `register_supported_configs` for all registered config.
StaticQuantConfig.register_supported_configs()


def get_default_static_config() -> StaticQuantConfig:
"""Generate the default static quant config.

Returns:
the default static quant config.
"""
return StaticQuantConfig()


######################## Smooth Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
class SmoothQuantConfig(BaseConfig):
"""Config class for smooth quantization."""

name = SMOOTH_QUANT
params_list = [
"w_dtype",
"w_sym",
"w_granularity",
"w_algo",
"act_dtype",
"act_sym",
"act_granularity",
"act_algo",
"alpha",
"folding",
"scale_sharing",
"auto_alpha_args",
]
supported_configs: List[OperatorConfig] = []

def __init__(
self,
w_dtype: str = "int8",
w_sym: bool = True,
w_granularity: str = "per_channel",
w_algo: str = "minmax",
act_dtype: str = "uint8",
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "minmax",
alpha: float = 0.5,
folding: bool = False,
# below for autotune
scale_sharing: bool = False,
init_alpha: float = 0.5,
alpha_min: float = 0.0,
alpha_max: float = 1.0,
alpha_step: float = 0.1,
shared_criterion: str = "max",
enable_blockwise_loss: bool = False,
auto_alpha_args: dict = None,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init SmoothQuant Configs."""
super().__init__(white_list=white_list)
self.w_dtype = w_dtype
self.w_sym = w_sym
self.w_granularity = w_granularity
self.w_algo = w_algo
self.act_dtype = act_dtype
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self.alpha = alpha
self.folding = folding
# below for autotune
self.scale_sharing = scale_sharing
self.init_alpha = init_alpha
self.alpha_min = alpha_min
self.alpha_max = alpha_max
self.alpha_step = alpha_step
self.shared_criterion = shared_criterion
self.enable_blockwise_loss = enable_blockwise_loss
self.auto_alpha_args = {
"init_alpha": self.init_alpha,
"alpha_min": self.alpha_min,
"alpha_max": self.alpha_max,
"alpha_step": self.alpha_step,
"shared_criterion": self.shared_criterion,
"enable_blockwise_loss": self.enable_blockwise_loss,
}
self._post_init()

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_sq_config = SmoothQuantConfig()
operators = [torch.nn.Linear]
supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
return filter_result


# TODO(Yi) run `register_supported_configs` for all registered config.
SmoothQuantConfig.register_supported_configs()


def get_default_sq_config() -> SmoothQuantConfig:
"""Generate the default smoothquant config.

Returns:
the default smoothquant config.
"""
return SmoothQuantConfig()


######################## FP8 Config ###############################
if is_hpex_avaliable():

Expand Down
Loading