Rename RTNWeightOnlyConfig to RTNConfig #1551

Merged: 3 commits, merged on Jan 19, 2024
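
In short: the 3.x PyTorch weight-only RTN config class and its registry constant get shorter names. A minimal before/after sketch of user code, inferred from the diffs below (the weight_bits value is illustrative):

    # Before this PR, 3.x user code imported the long name:
    #   from neural_compressor.torch import RTNWeightQuantConfig
    #   quant_config = RTNWeightQuantConfig(weight_bits=4)

    # After this PR, the same configuration uses the shorter name:
    from neural_compressor.torch import RTNConfig

    quant_config = RTNConfig(weight_bits=4)  # weight_bits=4 is an illustrative choice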
Changes from all commits
@@ -230,7 +230,7 @@ def get_user_model():

     # 3.x api
     if args.approach == 'weight_only':
-        from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize
+        from neural_compressor.torch import RTNConfig, GPTQConfig, quantize
         from neural_compressor.torch.utils.utility import get_double_quant_config
         weight_sym = True if args.woq_scheme == "sym" else False
         double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)
@@ -243,9 +243,9 @@ def get_user_model():
                     "enable_mse_search": args.woq_enable_mse_search,
                 }
             )
-            quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict)
+            quant_config = RTNConfig.from_dict(double_quant_config_dict)
         else:
-            quant_config = RTNWeightQuantConfig(
+            quant_config = RTNConfig(
                 weight_dtype=args.woq_dtype,
                 weight_bits=args.woq_bits,
                 weight_group_size=args.woq_group_size,
@@ -257,7 +257,7 @@ def get_user_model():
                 double_quant_sym=args.double_quant_sym,
                 double_quant_group_size=args.double_quant_group_size,
             )
-        quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32"))
+        quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
         user_model = quantize(
             model=user_model, quant_config=quant_config
         )
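
Condensed, the updated example script does the following. This is a sketch, with a toy model and literal values standing in for the script's user_model and args.woq_* options:

    import torch

    from neural_compressor.torch import RTNConfig, quantize

    # Toy stand-in; the example script loads a causal LM whose output layer is named "lm_head".
    model = torch.nn.Sequential(torch.nn.Linear(8, 8))

    # Literal values stand in for args.woq_dtype / args.woq_bits / args.woq_group_size.
    quant_config = RTNConfig(weight_dtype="int", weight_bits=4, weight_group_size=32)
    # The script pins lm_head to fp32 so the output projection stays unquantized.
    quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
    qmodel = quantize(model=model, quant_config=quant_config)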
neural_compressor/common/utility.py (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@
 # config name
 BASE_CONFIG = "base_config"
 COMPOSABLE_CONFIG = "composable_config"
-RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
+RTN = "rtn"
 STATIC_QUANT = "static_quant"
 GPTQ = "gptq"
 FP8_QUANT = "fp8_quant"
neural_compressor/tensorflow/utils.py (2 changes: 1 addition & 1 deletion)

@@ -35,7 +35,7 @@ def register_algo(name):

     Usage example:
         @register_algo(name=example_algo)
-        def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+        def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
             ...
     Args:
         name (str): The name under which the algorithm function will be registered.
neural_compressor/torch/__init__.py (4 changes: 2 additions & 2 deletions)

@@ -17,11 +17,11 @@

 from neural_compressor.torch.quantization import (
     quantize,
-    RTNWeightQuantConfig,
+    RTNConfig,
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
 )

 from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.autotune import autotune, get_default_tune_config
+from neural_compressor.torch.quantization.autotune import autotune, get_default_tune_config
neural_compressor/torch/algorithms/weight_only/rtn.py (4 changes: 2 additions & 2 deletions)

@@ -580,10 +580,10 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"):
     return int_weight


-from neural_compressor.torch.quantization.config import RTNWeightQuantConfig
+from neural_compressor.torch.quantization.config import RTNConfig


-def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
     # TODO (Yi) remove it
     enable_full_range = quant_config.enable_full_range
     enable_mse_search = quant_config.enable_mse_search
neural_compressor/torch/algorithms/weight_only_algos.py (8 changes: 4 additions & 4 deletions)

@@ -18,17 +18,17 @@
 import torch

 from neural_compressor.common.logger import Logger
-from neural_compressor.common.utility import GPTQ, RTN_WEIGHT_ONLY_QUANT
-from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
+from neural_compressor.common.utility import GPTQ, RTN
+from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
 from neural_compressor.torch.utils.utility import fetch_module, register_algo, set_module

 logger = Logger().get_logger()


 ###################### RTN Algo Entry ##################################
-@register_algo(name=RTN_WEIGHT_ONLY_QUANT)
+@register_algo(name=RTN)
 def rtn_quantize_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNWeightQuantConfig], *args, **kwargs
+    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNConfig], *args, **kwargs
 ) -> torch.nn.Module:
     """The main entry to apply rtn quantization."""
     from .weight_only.rtn import apply_rtn_on_single_module
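
The decorator change above shows how algorithm entries are looked up: register_algo keys them by the string constants from neural_compressor.common.utility, so renaming RTN_WEIGHT_ONLY_QUANT to RTN also renames the registry key to "rtn". A sketch of the registration pattern, following the docstring example in the utils modules; the function name and body here are hypothetical:

    import torch

    from neural_compressor.torch.quantization.config import RTNConfig
    from neural_compressor.torch.utils.utility import register_algo

    # Hypothetical entry; the real RTN entry is rtn_quantize_entry,
    # registered under the RTN constant ("rtn").
    @register_algo(name="example_algo")
    def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
        # A real entry would rewrite the model's weights round-to-nearest here.
        return model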
neural_compressor/torch/quantization/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@

 from neural_compressor.torch.quantization.quantize import quantize, quantize_dynamic
 from neural_compressor.torch.quantization.config import (
-    RTNWeightQuantConfig,
+    RTNConfig,
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
@@ -20,7 +20,7 @@
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.common.logger import Logger
 from neural_compressor.torch import quantize
-from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
+from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig

 logger = Logger().get_logger()

@@ -33,7 +33,7 @@

 def get_default_tune_config() -> TuningConfig:
     # TODO use the registered default tuning config in the next PR
-    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNWeightQuantConfig(weight_bits=[4, 8])])
+    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])])


 def autotune(
neural_compressor/torch/quantization/config.py (24 changes: 9 additions & 15 deletions)

@@ -24,13 +24,7 @@
 import torch

 from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
-from neural_compressor.common.utility import (
-    DEFAULT_WHITE_LIST,
-    FP8_QUANT,
-    GPTQ,
-    OP_NAME_OR_MODULE_TYPE,
-    RTN_WEIGHT_ONLY_QUANT,
-)
+from neural_compressor.common.utility import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
 from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

@@ -60,8 +54,8 @@ class OperatorConfig(NamedTuple):
 ######################## RNT Config ###############################


-@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT, priority=PRIORITY_RTN)
-class RTNWeightQuantConfig(BaseConfig):
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
+class RTNConfig(BaseConfig):
     """Config class for round-to-nearest weight-only quantization."""

     supported_configs: List[OperatorConfig] = []

@@ -80,7 +74,7 @@ class RTNWeightQuantConfig(BaseConfig):
         "double_quant_sym",
         "double_quant_group_size",
     ]
-    name = RTN_WEIGHT_ONLY_QUANT
+    name = RTN

     def __init__(
         self,

@@ -137,12 +131,12 @@ def to_dict(self):

     @classmethod
     def from_dict(cls, config_dict):
-        return super(RTNWeightQuantConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
+        return super(RTNConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)

     @classmethod
     def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs = []
-        linear_rtn_config = RTNWeightQuantConfig(
+        linear_rtn_config = RTNConfig(
             weight_dtype=["int", "int8", "int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"],
             weight_bits=[4, 1, 2, 3, 5, 6, 7, 8],
             weight_group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024],

@@ -173,16 +167,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:

 # TODO(Yi) run `register_supported_configs` for all registered config.
-RTNWeightQuantConfig.register_supported_configs()
+RTNConfig.register_supported_configs()


-def get_default_rtn_config() -> RTNWeightQuantConfig:
+def get_default_rtn_config() -> RTNConfig:
     """Generate the default rtn config.

     Returns:
         the default rtn config.
     """
-    return RTNWeightQuantConfig()
+    return RTNConfig()


 ######################## GPTQ Config ###############################
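
For callers, the renamed class keeps the construction paths exercised elsewhere in this PR: the default factory, keyword construction, and from_dict. A brief sketch; the dict contents are illustrative, whereas the example script builds its dict via get_double_quant_config:

    from neural_compressor.torch import RTNConfig, get_default_rtn_config

    default_config = get_default_rtn_config()  # equivalent to RTNConfig()

    # Round-trip through a plain dict, mirroring the double-quant path in the example script.
    config_dict = {"weight_bits": 4, "weight_group_size": 32}  # illustrative contents
    quant_config = RTNConfig.from_dict(config_dict)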
neural_compressor/torch/utils/utility.py (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@ def register_algo(name):

     Usage example:
         @register_algo(name=example_algo)
-        def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+        def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
             ...

     Args:
test/3x/torch/quantization/weight_only/test_rtn.py (18 changes: 9 additions & 9 deletions)

@@ -60,7 +60,7 @@ def _apply_rtn(self, quant_config):
         return qmodel

     def test_rtn(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

         # some tests were skipped to accelerate the CI
         rnt_options = {
@@ -76,7 +76,7 @@ def test_rtn(self):
         }
         from itertools import product

-        keys = RTNWeightQuantConfig.params_list
+        keys = RTNConfig.params_list
         for value in product(*rnt_options.values()):
             d = dict(zip(keys, value))
             if (d["weight_dtype"] == "int" and d["weight_bits"] != 8) or (
@@ -85,26 +85,26 @@ def test_rtn(self):
                 or (d["return_int"] and (d["group_dim"] != 1 or d["weight_bits"] != 8))
             ):
                 continue
-            quant_config = RTNWeightQuantConfig(**d)
+            quant_config = RTNConfig(**d)
             self._apply_rtn(quant_config)

     def test_rtn_return_type(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

         for return_int in [True, False]:
-            quant_config = RTNWeightQuantConfig(return_int=return_int)
+            quant_config = RTNConfig(return_int=return_int)
             qmodel = self._apply_rtn(quant_config)

     def test_rtn_mse_search(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

-        quant_config = RTNWeightQuantConfig(enable_mse_search=True)
+        quant_config = RTNConfig(enable_mse_search=True)
         qmodel = self._apply_rtn(quant_config)

     def test_rtn_recover(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

-        quant_config = RTNWeightQuantConfig(return_int=True)
+        quant_config = RTNConfig(return_int=True)
         qmodel = self._apply_rtn(quant_config)
         input = torch.randn(4, 8)
         # test forward
test/3x/torch/test_autotune.py (12 changes: 6 additions & 6 deletions)

@@ -62,12 +62,12 @@ def setUp(self):
     def test_autotune_api(self):
         logger.info("test_autotune_api")
         from neural_compressor.common.base_tuning import evaluator
-        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune

         def eval_acc_fn(model) -> float:
             return 1.0

-        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
         best_model = autotune(
             model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
         )
@@ -78,7 +78,7 @@ def eval_acc_fn(model) -> float:
     def test_autotune_api_2(self):
         logger.info("test_autotune_api")
         from neural_compressor.common.base_tuning import evaluator
-        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune

         def eval_acc_fn(model) -> float:
             return 1.0
@@ -94,17 +94,17 @@ def eval_perf_fn(model) -> float:
             },
         ]

-        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_fns)
         self.assertIsNotNone(best_model)
         self.assertEqual(len(evaluator.eval_fn_registry), 2)

     @reset_tuning_target
     def test_autotune_not_eval_func(self):
         logger.info("test_autotune_api")
-        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune

-        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)

         # Use assertRaises to check that an AssertionError is raised
         with self.assertRaises(AssertionError) as context:
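
The tests above reach the renamed class through the autotune path. Distilled from the first test, the call shape is as follows; build_simple_torch_model and the constant-score eval function are stand-ins taken from the test file:

    from neural_compressor.torch import RTNConfig, TuningConfig, autotune

    def eval_acc_fn(model) -> float:
        return 1.0  # stand-in score; a real evaluation would measure accuracy

    custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
    best_model = autotune(
        model=build_simple_torch_model(),  # helper defined in the test file
        tune_config=custom_tune_config,
        eval_fns=[{"eval_fn": eval_acc_fn}],
    )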