Commit 941fed3
Rename RTNWeightOnlyConfig to RTNConfig (#1551)
* Rename RTNWeightOnlyConfig to RTNConfig

Signed-off-by: xin3he <[email protected]>
xin3he authored Jan 19, 2024
1 parent c565e96 commit 941fed3
Showing 13 changed files with 89 additions and 97 deletions.
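(The class renamed in the diffs below is RTNWeightQuantConfig, although the commit title says RTNWeightOnlyConfig; the new name is RTNConfig either way.) For downstream code the change is a pure rename, so migration is a one-line edit at the import site. A minimal sketch (constructor arguments are unchanged by this commit):

    # Before this commit:
    #   from neural_compressor.torch import RTNWeightQuantConfig
    # After this commit:
    from neural_compressor.torch import RTNConfig

    quant_config = RTNConfig()  # same keyword arguments as before; only the class name changed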
@@ -230,7 +230,7 @@ def get_user_model():
 # 3.x api
 if args.approach == 'weight_only':
-    from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize
+    from neural_compressor.torch import RTNConfig, GPTQConfig, quantize
     from neural_compressor.torch.utils.utility import get_double_quant_config
     weight_sym = True if args.woq_scheme == "sym" else False
     double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)
@@ -243,9 +243,9 @@ def get_user_model():
                 "enable_mse_search": args.woq_enable_mse_search,
             }
         )
-        quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict)
+        quant_config = RTNConfig.from_dict(double_quant_config_dict)
     else:
-        quant_config = RTNWeightQuantConfig(
+        quant_config = RTNConfig(
             weight_dtype=args.woq_dtype,
             weight_bits=args.woq_bits,
             weight_group_size=args.woq_group_size,
@@ -257,7 +257,7 @@ def get_user_model():
             double_quant_sym=args.double_quant_sym,
             double_quant_group_size=args.double_quant_group_size,
         )
-    quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32"))
+    quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
     user_model = quantize(
         model=user_model, quant_config=quant_config
     )
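Outside the argparse harness, the weight-only path above reduces to the following sketch; the toy model and literal option values are illustrative stand-ins, not values from the example script:

    import torch
    from neural_compressor.torch import RTNConfig, quantize

    # stand-in for the causal LM the example script loads
    user_model = torch.nn.Sequential(torch.nn.Linear(8, 8))

    # 4-bit round-to-nearest weight-only quantization
    quant_config = RTNConfig(weight_dtype="int", weight_bits=4, weight_group_size=32)
    # keep lm_head in fp32, as the script does (a no-op on this toy model, which has no lm_head)
    quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
    user_model = quantize(model=user_model, quant_config=quant_config)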
2 changes: 1 addition & 1 deletion neural_compressor/common/utility.py
@@ -27,7 +27,7 @@
 # config name
 BASE_CONFIG = "base_config"
 COMPOSABLE_CONFIG = "composable_config"
-RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
+RTN = "rtn"
 STATIC_QUANT = "static_quant"
 GPTQ = "gptq"
 FP8_QUANT = "fp8_quant"
2 changes: 1 addition & 1 deletion neural_compressor/tensorflow/utils.py
@@ -35,7 +35,7 @@ def register_algo(name):
     Usage example:
         @register_algo(name=example_algo)
-        def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+        def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
             ...
     Args:
         name (str): The name under which the algorithm function will be registered.
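For context, register_algo is an ordinary decorator-based registry; a generic sketch of the pattern (not the library's verbatim implementation):

    from typing import Callable, Dict

    algos_mapping: Dict[str, Callable] = {}

    def register_algo(name: str):
        """Store the decorated function in the registry under `name`."""
        def decorator(algo: Callable) -> Callable:
            algos_mapping[name] = algo
            return algo
        return decorator

    @register_algo(name="rtn")
    def example_algo(model, quant_config):
        ...  # quantize `model` according to `quant_config` and return it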
4 changes: 2 additions & 2 deletions neural_compressor/torch/__init__.py
@@ -17,11 +17,11 @@

 from neural_compressor.torch.quantization import (
     quantize,
-    RTNWeightQuantConfig,
+    RTNConfig,
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
 )

 from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.autotune import autotune, get_default_tune_config
+from neural_compressor.torch.quantization.autotune import autotune, get_default_tune_config
4 changes: 2 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -580,10 +580,10 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"):
     return int_weight


-from neural_compressor.torch.quantization.config import RTNWeightQuantConfig
+from neural_compressor.torch.quantization.config import RTNConfig


-def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
     # TODO (Yi) remove it
     enable_full_range = quant_config.enable_full_range
     enable_mse_search = quant_config.enable_mse_search
8 changes: 4 additions & 4 deletions neural_compressor/torch/algorithms/weight_only_algos.py
@@ -18,17 +18,17 @@
 import torch

 from neural_compressor.common.logger import Logger
-from neural_compressor.common.utility import GPTQ, RTN_WEIGHT_ONLY_QUANT
-from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
+from neural_compressor.common.utility import GPTQ, RTN
+from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
 from neural_compressor.torch.utils.utility import fetch_module, register_algo, set_module

 logger = Logger().get_logger()


 ###################### RTN Algo Entry ##################################
-@register_algo(name=RTN_WEIGHT_ONLY_QUANT)
+@register_algo(name=RTN)
 def rtn_quantize_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNWeightQuantConfig], *args, **kwargs
+    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNConfig], *args, **kwargs
 ) -> torch.nn.Module:
     """The main entry to apply rtn quantization."""
     from .weight_only.rtn import apply_rtn_on_single_module
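The body of rtn_quantize_entry is truncated in this view. Based on the helpers it imports (fetch_module, set_module, apply_rtn_on_single_module), a plausible sketch of the dispatch loop, not the verbatim implementation:

    def rtn_quantize_entry(model, configs_mapping, *args, **kwargs):
        """Apply RTN to every module matched by configs_mapping."""
        from .weight_only.rtn import apply_rtn_on_single_module
        for (op_name, op_type), quant_config in configs_mapping.items():
            module = fetch_module(model, op_name)  # look up the submodule by its dotted name
            module = apply_rtn_on_single_module(module, quant_config)
            set_module(model, op_name, module)  # swap the quantized module back in
        return model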
2 changes: 1 addition & 1 deletion neural_compressor/torch/quantization/__init__.py
@@ -14,7 +14,7 @@

 from neural_compressor.torch.quantization.quantize import quantize, quantize_dynamic
 from neural_compressor.torch.quantization.config import (
-    RTNWeightQuantConfig,
+    RTNConfig,
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
@@ -20,7 +20,7 @@
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.common.logger import Logger
 from neural_compressor.torch import quantize
-from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
+from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig

 logger = Logger().get_logger()

@@ -33,7 +33,7 @@

 def get_default_tune_config() -> TuningConfig:
     # TODO use the registered default tuning config in the next PR
-    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNWeightQuantConfig(weight_bits=[4, 8])])
+    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])])


 def autotune(
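The tuning entry points above are exercised roughly as follows (mirroring test/3x/torch/test_autotune.py at the bottom of this commit; the model and evaluation function are stubs):

    import torch
    from neural_compressor.torch import RTNConfig, TuningConfig, autotune

    def eval_acc_fn(model) -> float:
        return 1.0  # stub metric; a real function would score the candidate model

    my_model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # placeholder model
    custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
    best_model = autotune(model=my_model, tune_config=custom_tune_config,
                          eval_fns=[{"eval_fn": eval_acc_fn}])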
24 changes: 9 additions & 15 deletions neural_compressor/torch/quantization/config.py
@@ -24,13 +24,7 @@
 import torch

 from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
-from neural_compressor.common.utility import (
-    DEFAULT_WHITE_LIST,
-    FP8_QUANT,
-    GPTQ,
-    OP_NAME_OR_MODULE_TYPE,
-    RTN_WEIGHT_ONLY_QUANT,
-)
+from neural_compressor.common.utility import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
 from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

@@ -60,8 +54,8 @@ class OperatorConfig(NamedTuple):
 ######################## RNT Config ###############################


-@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT, priority=PRIORITY_RTN)
-class RTNWeightQuantConfig(BaseConfig):
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
+class RTNConfig(BaseConfig):
     """Config class for round-to-nearest weight-only quantization."""

     supported_configs: List[OperatorConfig] = []
@@ -80,7 +74,7 @@ class RTNWeightQuantConfig(BaseConfig):
         "double_quant_sym",
         "double_quant_group_size",
     ]
-    name = RTN_WEIGHT_ONLY_QUANT
+    name = RTN

     def __init__(
         self,
@@ -137,12 +131,12 @@ def to_dict(self):

     @classmethod
     def from_dict(cls, config_dict):
-        return super(RTNWeightQuantConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
+        return super(RTNConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)

     @classmethod
     def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs = []
-        linear_rtn_config = RTNWeightQuantConfig(
+        linear_rtn_config = RTNConfig(
             weight_dtype=["int", "int8", "int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"],
             weight_bits=[4, 1, 2, 3, 5, 6, 7, 8],
             weight_group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024],
@@ -173,16 +167,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:


 # TODO(Yi) run `register_supported_configs` for all registered config.
-RTNWeightQuantConfig.register_supported_configs()
+RTNConfig.register_supported_configs()


-def get_default_rtn_config() -> RTNWeightQuantConfig:
+def get_default_rtn_config() -> RTNConfig:
     """Generate the default rtn config.

     Returns:
         the default rtn config.
     """
-    return RTNWeightQuantConfig()
+    return RTNConfig()


 ######################## GPTQ Config ###############################
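As the from_dict/to_dict pair above suggests, a config can be spelled either with keyword arguments or as a plain dict; a small equivalence sketch (values are illustrative, and the round-trip equality is an assumption about BaseConfig, not verified here):

    from neural_compressor.torch import RTNConfig

    a = RTNConfig(weight_dtype="nf4", weight_bits=4, weight_group_size=32)
    b = RTNConfig.from_dict({"weight_dtype": "nf4", "weight_bits": 4, "weight_group_size": 32})
    assert a.to_dict() == b.to_dict()  # both spellings should describe the same config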
2 changes: 1 addition & 1 deletion neural_compressor/torch/utils/utility.py
@@ -33,7 +33,7 @@ def register_algo(name):

     Usage example:
         @register_algo(name=example_algo)
-        def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
+        def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
             ...

     Args:
18 changes: 9 additions & 9 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -60,7 +60,7 @@ def _apply_rtn(self, quant_config):
         return qmodel

     def test_rtn(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

         # some tests were skipped to accelerate the CI
         rnt_options = {
@@ -76,7 +76,7 @@ def test_rtn(self):
         }
         from itertools import product

-        keys = RTNWeightQuantConfig.params_list
+        keys = RTNConfig.params_list
         for value in product(*rnt_options.values()):
             d = dict(zip(keys, value))
             if (d["weight_dtype"] == "int" and d["weight_bits"] != 8) or (
@@ -85,26 +85,26 @@ def test_rtn(self):
                 or (d["return_int"] and (d["group_dim"] != 1 or d["weight_bits"] != 8))
             ):
                 continue
-            quant_config = RTNWeightQuantConfig(**d)
+            quant_config = RTNConfig(**d)
             self._apply_rtn(quant_config)

     def test_rtn_return_type(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

         for return_int in [True, False]:
-            quant_config = RTNWeightQuantConfig(return_int=return_int)
+            quant_config = RTNConfig(return_int=return_int)
             qmodel = self._apply_rtn(quant_config)

     def test_rtn_mse_search(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

-        quant_config = RTNWeightQuantConfig(enable_mse_search=True)
+        quant_config = RTNConfig(enable_mse_search=True)
         qmodel = self._apply_rtn(quant_config)

     def test_rtn_recover(self):
-        from neural_compressor.torch import RTNWeightQuantConfig
+        from neural_compressor.torch import RTNConfig

-        quant_config = RTNWeightQuantConfig(return_int=True)
+        quant_config = RTNConfig(return_int=True)
         qmodel = self._apply_rtn(quant_config)
         input = torch.randn(4, 8)
         # test forward
12 changes: 6 additions & 6 deletions test/3x/torch/test_autotune.py
@@ -62,12 +62,12 @@ def setUp(self):
     def test_autotune_api(self):
         logger.info("test_autotune_api")
         from neural_compressor.common.base_tuning import evaluator
-        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune

         def eval_acc_fn(model) -> float:
             return 1.0

-        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
         best_model = autotune(
             model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
         )
@@ -78,7 +78,7 @@ def eval_acc_fn(model) -> float:
     def test_autotune_api_2(self):
         logger.info("test_autotune_api")
         from neural_compressor.common.base_tuning import evaluator
-        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune

         def eval_acc_fn(model) -> float:
             return 1.0
@@ -94,17 +94,17 @@ def eval_perf_fn(model) -> float:
             },
         ]

-        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_fns)
         self.assertIsNotNone(best_model)
         self.assertEqual(len(evaluator.eval_fn_registry), 2)

     @reset_tuning_target
     def test_autotune_not_eval_func(self):
         logger.info("test_autotune_api")
-        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune

-        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)

         # Use assertRaises to check that an AssertionError is raised
         with self.assertRaises(AssertionError) as context: