diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py index 06818699f5a..096e7d588c4 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py @@ -230,8 +230,8 @@ def get_user_model(): # 3.x api if args.approach == 'weight_only': - from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize - from neural_compressor.torch.utils.utility import get_double_quant_config + from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize + from neural_compressor.torch.utils import get_double_quant_config weight_sym = True if args.woq_scheme == "sym" else False double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym) @@ -243,9 +243,9 @@ def get_user_model(): "enable_mse_search": args.woq_enable_mse_search, } ) - quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict) + quant_config = RTNConfig.from_dict(double_quant_config_dict) else: - quant_config = RTNWeightQuantConfig( + quant_config = RTNConfig( weight_dtype=args.woq_dtype, weight_bits=args.woq_bits, weight_group_size=args.woq_group_size, @@ -257,7 +257,7 @@ def get_user_model(): double_quant_sym=args.double_quant_sym, double_quant_group_size=args.double_quant_group_size, ) - quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32")) + quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32")) user_model = quantize( model=user_model, quant_config=quant_config ) diff --git a/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py b/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py index bd73fddd94d..3c47c6f1bbb 100644 --- a/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py @@ -127,7 +127,6 @@ def __init__( dtype=self.float_type, ).to(device), ) - self.scales = self.scales.T self.register_buffer( "qweight", torch.zeros( @@ -135,7 +134,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qweight = self.qweight.T self.register_buffer( "qzeros", torch.zeros( @@ -143,7 +141,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qzeros = self.qzeros.T self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: self.compression_dtype = compression_dtype @@ -193,6 +190,10 @@ def __init__( self.bias = None def pack(self, int_weight, scale, zp, bias): + if self.use_optimum_format: + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow @@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias): assert scale.shape == self.scales.shape, "Scale shape is mismatched." 
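Aside, before the next hunk: a minimal sketch of the renamed 3.x weight-only API that the example script above now imports (RTNConfig plus quantize). The toy model and argument values are placeholders, and the imports assume a neural-compressor build that matches this patch.

import torch
from neural_compressor.torch.quantization import RTNConfig, quantize

toy_model = torch.nn.Sequential(torch.nn.Linear(64, 64))           # placeholder model
quant_config = RTNConfig(weight_dtype="int", weight_bits=4, weight_group_size=32)
quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))   # keep lm_head in fp32
q_model = quantize(model=toy_model, quant_config=quant_config)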
self.scales = scale.type(self.float_type).to(self.device) if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.T - self.qweight = self.qweight.T + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() origin_shape = int_weight.shape target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." @@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.T + self.qweight = self.qweight.t_().contiguous() if zp is not None: zp = zp.to(self.device) if self.use_optimum_format: zp -= 1 if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - self.qzeros = self.qzeros.T + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() assert hasattr(self, "qzeros"), "zp is not set when initializing." target_shape = self.qzeros.shape for j in range(target_shape[1]): @@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.T + self.qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() def recover(self): logger.debug(f"Recovering {self} weight") - scales = self.scales.T if self.use_optimum_format else self.scales - qweight = self.qweight.T if self.use_optimum_format else self.qweight + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) @@ -264,8 +265,8 @@ def recover(self): # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T - qweight = qweight.T + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() origin_shape = weight.shape target_shape = qweight.shape for j in range(target_shape[1]): @@ -280,7 +281,7 @@ def recover(self): tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T + weight = weight.t_().contiguous() if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) for k, v in self.int2float_mapping.items(): @@ -290,10 +291,10 @@ def recover(self): if hasattr(self, "qzeros"): zp_dtype = self.compression_dtype # to avoid overflow when weight-zp zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - qzeros = qzeros.T + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() origin_shape = zp.shape target_shape = qzeros.shape for j in range(target_shape[1]): @@ -307,7 +308,7 @@ def recover(self): tmp &= mask zp[:, index] 
= tmp.type(zp_dtype) if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T + zp = zp.t_().contiguous() if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py index 7376992b9fb..9a3cd6361f2 100644 --- a/neural_compressor/adaptor/torch_utils/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py @@ -327,9 +327,9 @@ def __init__( def pack(self, int_weight, scale, zp, bias, g_idx=None): if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow @@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.T - self.qweight = self.qweight.T + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() origin_shape = int_weight.shape target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." @@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.T + self.qweight = self.qweight.t_().contiguous() if zp is not None: zp = zp.to(self.device) if self.use_optimum_format: zp -= 1 if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - self.qzeros = self.qzeros.T + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() assert hasattr(self, "qzeros"), "zp is not set when initializing." 
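The pack()/recover() hunks above and below all revolve around one trick: several low-bit integers are packed into a single int32 column with shifts and bitwise OR. A standalone toy version of that loop (illustrative shapes, not INC code):

import torch

bits = 4
n_pack = 32 // bits                                    # 8 values per int32
int_weight = torch.randint(0, 2**bits, (16, 32), dtype=torch.int32)
qweight = torch.zeros(16, 32 // n_pack, dtype=torch.int32)
for j in range(qweight.shape[1]):
    tmp = int_weight[:, j * n_pack:(j + 1) * n_pack].clone()
    for e in range(n_pack):
        tmp[:, e] = tmp[:, e] << (bits * e)            # shift each value into its slot
        qweight[:, j] |= tmp[:, e]                     # OR it into the packed column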
target_shape = self.qzeros.shape for j in range(target_shape[1]): @@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.T + self.qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() def recover(self): logger.debug(f"Recovering {self} weight") - scales = self.scales.T if self.use_optimum_format else self.scales - qweight = self.qweight.T if self.use_optimum_format else self.qweight + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) @@ -411,8 +411,8 @@ def recover(self): # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T - qweight = qweight.T + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() origin_shape = weight.shape target_shape = qweight.shape for j in range(target_shape[1]): @@ -427,7 +427,7 @@ def recover(self): tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T + weight = weight.t_().contiguous() if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) for k, v in self.int2float_mapping.items(): @@ -437,10 +437,10 @@ def recover(self): if hasattr(self, "qzeros"): zp_dtype = self.compression_dtype # to avoid overflow when weight-zp zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - qzeros = qzeros.T + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() origin_shape = zp.shape target_shape = qzeros.shape for j in range(target_shape[1]): @@ -454,7 +454,7 @@ def recover(self): tmp &= mask zp[:, index] = tmp.type(zp_dtype) if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T + zp = zp.t_().contiguous() if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index b029d137e41..0263ca1c845 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -429,7 +429,7 @@ def rtn_quantize( if num_bits <= 0: logger.info(f"Skip {name}") continue - weight = m.weight.T if group_dim == 0 else m.weight + weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight if enable_mse_search: quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range) if return_int: @@ -447,8 +447,8 @@ def rtn_quantize( ) if group_dim == 0: weight.transpose_(0, 1) - scale = scale.T if group_dim == 0 else scale - zp = zp.T if group_dim == 0 and zp 
is not None else zp + scale = scale.t_().contiguous() if group_dim == 0 else scale + zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp new_module = WeightOnlyLinear( m.in_features, m.out_features, @@ -651,18 +651,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1): if zp is not None: zp = zp.to(device) if group_size == -1: - return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp) + return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_() int_weight = torch.zeros(weight.shape).to(device) leng = weight.shape[1] // group_size tail_flag = False if weight.shape[1] % group_size == 0 else True for i in range(leng): - int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1) + int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) if zp is not None: - int_weight_tmp += zp[:, i].unsqueeze(1) - int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp) + int_weight_tmp.add_(zp[:, i].unsqueeze(1)) + int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_()) if tail_flag: - int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1) + int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1)) if zp is not None: - int_weight_tmp += zp[:, -1].unsqueeze(1) - int_weight[:, leng * group_size :] = torch.round(int_weight_tmp) + int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) + int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_()) return int_weight diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 44283e48fa6..72019b95aa0 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -180,7 +180,7 @@ def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig: self.local_config[operator_name] = config return self - def to_dict(self, params_list=[], operator2str=None): + def to_dict(self): result = {} global_config = self.get_params_dict() if bool(self.local_config): @@ -200,7 +200,7 @@ def get_params_dict(self): return result @classmethod - def from_dict(cls, config_dict, str2operator=None): + def from_dict(cls, config_dict): """Construct config from a dict. Args: diff --git a/neural_compressor/common/utility.py b/neural_compressor/common/utility.py index 7761a173d7d..42f6e445b9a 100644 --- a/neural_compressor/common/utility.py +++ b/neural_compressor/common/utility.py @@ -27,7 +27,7 @@ # config name BASE_CONFIG = "base_config" COMPOSABLE_CONFIG = "composable_config" -RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant" +RTN = "rtn" STATIC_QUANT = "static_quant" GPTQ = "gptq" FP8_QUANT = "fp8_quant" diff --git a/neural_compressor/tensorflow/utils.py b/neural_compressor/tensorflow/utils.py index 6f65f79fbc1..4497c1e9a7a 100644 --- a/neural_compressor/tensorflow/utils.py +++ b/neural_compressor/tensorflow/utils.py @@ -35,7 +35,7 @@ def register_algo(name): Usage example: @register_algo(name=example_algo) - def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: + def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module: ... Args: name (str): The name under which the algorithm function will be registered. 
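For the register_algo docstring touched above, a simplified stand-in for the registration pattern it describes (not the INC implementation):

algo_registry = {}

def register_algo(name):
    def decorator(algo_func):
        algo_registry[name] = algo_func                # store the entry under its name
        return algo_func
    return decorator

@register_algo(name="rtn")
def example_algo(model, quant_config):
    return model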
diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py index 57cfe472297..8989ae9d722 100644 --- a/neural_compressor/torch/__init__.py +++ b/neural_compressor/torch/__init__.py @@ -11,16 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from neural_compressor.torch.utils.utility import register_algo -from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry - -from neural_compressor.torch.quantization import ( - quantize, - RTNWeightQuantConfig, - get_default_rtn_config, - GPTQConfig, - get_default_gptq_config, -) - -from neural_compressor.torch.tune import autotune, TuningConfig, get_default_tune_config diff --git a/neural_compressor/torch/algorithms/__init__.py b/neural_compressor/torch/algorithms/__init__.py index ebb6e56ae35..707d36e9a4d 100644 --- a/neural_compressor/torch/algorithms/__init__.py +++ b/neural_compressor/torch/algorithms/__init__.py @@ -13,5 +13,7 @@ # limitations under the License. -from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry -from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry +from .weight_only import ( + rtn_quantize, + gptq_quantize, +) diff --git a/neural_compressor/torch/algorithms/weight_only/README.md b/neural_compressor/torch/algorithms/weight_only/README.md new file mode 100644 index 00000000000..d04d78a7d2b --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/README.md @@ -0,0 +1 @@ +# Demo of algorithm usage w/o INC diff --git a/neural_compressor/torch/algorithms/weight_only/__init__.py b/neural_compressor/torch/algorithms/weight_only/__init__.py index 8989ae9d722..ac8feca4f40 100644 --- a/neural_compressor/torch/algorithms/weight_only/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .utility import * +from .rtn import rtn_quantize +from .gptq import gptq_quantize diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 8c624761c49..47c9f7ce607 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -15,8 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
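Taken together, these __init__ changes split the import surface in two: the config-driven entry points stay under neural_compressor.torch.quantization, while the raw algorithm functions are re-exported from the new weight_only package. A hedged sketch of the resulting imports, subject to the final package layout:

# High-level, config-driven API used in the example script:
from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
# Low-level algorithm functions re-exported by the new __init__ files:
from neural_compressor.torch.algorithms.weight_only import rtn_quantize, gptq_quantize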
-# Copied from neural_compressor/adaptor/torch_utils/gptq.py - import gc import math import random @@ -30,10 +28,7 @@ import transformers from tqdm import tqdm -from neural_compressor.common.logger import Logger - -logger = Logger().get_logger() - +from ...utils import logger DEBUG = False @@ -196,14 +191,12 @@ def __init__( self, model, weight_config={}, + dataloader=None, nsamples=128, - dataloader_len=10, use_max_length=True, pad_max_length=2048, device=None, layer_wise=False, - *args, - **kwargs, ): """ Args: @@ -233,7 +226,6 @@ def __init__( # weight config self.weight_config = weight_config # default settings, check configs - self.wdtype_default = "int" self.wbits_default = 4 self.group_size_default = 128 self.block_size_default = 128 @@ -242,10 +234,6 @@ def __init__( self.act_order_default = False self.perchannel_default = True self.mse_default = False - self.double_quant_dtype_default = "fp32" - self.double_quant_bits_default = 4 - self.double_quant_group_size_default = 128 - self.double_quant_sym_default = False self.check_layer_config() # device @@ -259,26 +247,148 @@ def __init__( # dataloader self.use_max_length = use_max_length self.pad_max_length = pad_max_length - self.dataloader_original = None + self.dataloader_original = dataloader self.dataloader = [] - self.dataloader_len = dataloader_len self.nsamples = nsamples - self.args = args - self.kwargs = kwargs - self.run_fn = self.kwargs.get("run_fn", None) - self.run_args = self.kwargs.get("run_args", None) - self.dataloader_len = dataloader_len - # compare 2.x, use run_fn to calibration - # self.prepare_dataloader() - self._post_init() - - def _post_init(self): - self.cache_key_arguments = { - "i": 0 - } # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.) - # Note that the first elements in cache_positional_arguments is main input: hidden_states - self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm) - self.is_ready = True + self.prepare_dataloader() + + def prepare_dataloader(self): + if self.use_max_length: + # (Recommend) only take sequence whose length exceeds self.pad_max_length, + # which preserves calibration's tokens are all valid + # This is GPTQ official dataloader implementation + self.obtain_first_n_samples_fulllength() + else: + # general selection, no padding, not GPTQ original implementation. + self.obtain_first_n_samples() + try: + self.cache_key_arguments = { + "i": 0 + } # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.) + # Note that the first elements in cache_positional_arguments is main input: hidden_states + self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm) + self.is_ready = True + except: + logger.warning("GPTQ Quantizer initialization failed!") + pass + + def obtain_first_n_samples(self, seed=0): + """Get first nsample data as the real calibration dataset.""" + self.dataloader.clear() + random.seed(seed) + for batch in self.dataloader_original: + # process data, depends on its data type. 
+ if len(self.dataloader) == self.nsamples: + logger.info(f"Successfully collect {self.nsamples} calibration samples.") + break + # list, tuple + if isinstance(batch, list) or isinstance(batch, tuple): + if batch[0].shape[-1] > self.pad_max_length: + i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1) + j = i + self.pad_max_length + batch_final = [] + for item in batch: + if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: + batch_final.append(item[:, i:j]) + else: + batch_final.append(item) + else: + batch_final = batch[:] + # dict + elif isinstance(batch, dict): + try: + length = batch["input_ids"].shape[-1] + except: + logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") + continue + batch_final = {} + if length > self.pad_max_length: + i = random.randint(0, length - self.pad_max_length - 1) + j = i + self.pad_max_length + # may have to slice every sequence related data + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch_final[key] = batch[key][:, i:j] # slice on sequence length dim + else: + batch_final[key] = batch[key] + else: + batch_final = batch + # tensor + else: + if batch.shape[-1] > self.pad_max_length: + i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1) + j = i + self.pad_max_length + batch_final = batch[:, i:j] + else: + batch_final = batch + self.dataloader.append(batch_final) + + if len(self.dataloader) < self.nsamples: + logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.") + + def obtain_first_n_samples_fulllength(self, seed=0): + self.dataloader.clear() + random.seed(seed) + unified_length = self.pad_max_length + for batch in self.dataloader_original: + if len(self.dataloader) == self.nsamples: + logger.info(f"Successfully collect {self.nsamples} calibration samples.") + break + # list & tuple, gpt-j-6b mlperf, etc. 
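Aside, before the per-type branches of obtain_first_n_samples_fulllength: both sample-collection methods reduce an over-long sequence to a random contiguous window of pad_max_length tokens. A toy distillation of that slicing, with placeholder tensors:

import random
import torch

pad_max_length = 2048
input_ids = torch.randint(0, 32000, (1, 4096))         # one over-long calibration sequence
if input_ids.shape[-1] > pad_max_length:
    i = random.randint(0, input_ids.shape[-1] - pad_max_length - 1)
    input_ids = input_ids[:, i:i + pad_max_length]     # keep a random 2048-token window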
+ if isinstance(batch, list) or isinstance(batch, tuple): + if batch[0].shape[-1] == unified_length: + batch_final = batch[:] + elif batch[0].shape[-1] > unified_length: + i = random.randint(0, batch[0].shape[-1] - unified_length - 1) + j = i + unified_length + batch_final = [] + for item in batch: + if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: + batch_final.append(item[:, i:j]) + else: + batch_final.append(item) + else: + # not match max length, not include in target dataset + continue + # dict + elif isinstance(batch, dict): + try: + length = batch["input_ids"].shape[-1] + except: + logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") + continue + batch_final = {} + if length == self.pad_max_length: + batch_final = batch + elif length > self.pad_max_length: + i = random.randint(0, length - self.pad_max_length - 1) + j = i + self.pad_max_length + # may have to slice every sequence related data + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch_final[key] = batch[key][:, i:j] # slice on sequence length dim with same position + else: + batch_final[key] = batch[key] + else: + # not match max length, not include in target dataset + continue + # tensor + else: + if batch.shape[-1] == unified_length: + batch_final = batch + elif batch.shape[-1] > unified_length: + i = random.randint(0, batch.shape[-1] - unified_length - 1) + j = i + unified_length + batch_final = batch[:, i:j] + else: + # not match max length, not include in target dataset + continue + self.dataloader.append(batch_final) + if len(self.dataloader) < self.nsamples: # pragma: no cover + logger.warning( + f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \ + but only {len(self.dataloader)} samples are found. Please use smaller 'self.pad_max_length' value." 
+ ) def get_full_layer_name(self, sub_layer_name, block_idx): transformer_name = self.gptq_related_blocks["transformers_name"] @@ -290,7 +400,6 @@ def check_layer_config(self): tmp_weight_config = {} for name, module in self.model.named_modules(): tmp_weight_config[name] = {} - tmp_weight_config[name]["wdtype"] = self.weight_config.get("wdtype", self.wdtype_default) tmp_weight_config[name]["wbits"] = self.weight_config.get("wbits", self.wbits_default) tmp_weight_config[name]["group_size"] = self.weight_config.get("group_size", self.group_size_default) tmp_weight_config[name]["block_size"] = self.weight_config.get("block_size", self.group_size_default) @@ -299,22 +408,9 @@ def check_layer_config(self): tmp_weight_config[name]["act_order"] = self.weight_config.get("act_order", self.act_order_default) tmp_weight_config[name]["perchannel"] = self.weight_config.get("perchannel", self.perchannel_default) tmp_weight_config[name]["mse"] = self.weight_config.get("mse", self.mse_default) - tmp_weight_config[name]["double_quant_dtype"] = self.weight_config.get( - "double_quant_dtype", self.double_quant_dtype_default - ) - tmp_weight_config[name]["double_quant_bits"] = self.weight_config.get( - "double_quant_bits", self.double_quant_bits_default - ) - tmp_weight_config[name]["double_quant_group_size"] = self.weight_config.get( - "double_quant_group_size", self.double_quant_group_size_default - ) - tmp_weight_config[name]["double_quant_sym"] = self.weight_config.get( - "double_quant_sym", self.double_quant_sym_default - ) self.weight_config = tmp_weight_config else: for layer_name, config in self.weight_config.items(): - self.weight_config[layer_name]["wdtype"] = config.get("wdtype", self.wdtype_default) self.weight_config[layer_name]["wbits"] = config.get("wbits", self.wbits_default) self.weight_config[layer_name]["group_size"] = config.get("group_size", self.group_size_default) self.weight_config[layer_name]["block_size"] = config.get("block_size", self.group_size_default) @@ -323,18 +419,6 @@ def check_layer_config(self): self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default) self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default) self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default) - self.weight_config[layer_name]["double_quant_dtype"] = config.get( - "double_quant_dtype", self.double_quant_dtype_default - ) - self.weight_config[layer_name]["double_quant_bits"] = config.get( - "double_quant_bits", self.double_quant_bits_default - ) - self.weight_config[layer_name]["double_quant_group_size"] = config.get( - "double_quant_group_size", self.double_quant_group_size_default - ) - self.weight_config[layer_name]["double_quant_sym"] = config.get( - "double_quant_sym", self.double_quant_sym_default - ) def get_layer_config(self, layer_name): """Obtain config for one layer, since GPTQ supports layer-wise config.""" @@ -400,24 +484,18 @@ def forward(layer, *args, **kwargs): # Step3: run forward to obtain calibration datasets logger.info("Collecting calibration inputs...") - logger.info("Collecting calibration inputs by running the run_fn provided by user.") - if self.run_args: - self.run_fn(self.model, self.run_args) - else: - self.run_fn(self.model) - - # for batch in tqdm(self.dataloader): - # if not self.layer_wise: - # batch = move_input_to_device(batch, self.device) - # try: - # if isinstance(batch, tuple) or isinstance(batch, list): - # self.model(batch[0]) - # elif isinstance(batch, dict): - # 
self.model(**batch) - # else: - # self.model(batch) - # except ValueError: - # pass + for batch in tqdm(self.dataloader): + if not self.layer_wise: + batch = move_input_to_device(batch, self.device) + try: + if isinstance(batch, tuple) or isinstance(batch, list): + self.model(batch[0]) + elif isinstance(batch, dict): + self.model(**batch) + else: + self.model(batch) + except ValueError: + pass # output inp data shape logger.info("All calibration data's shape =>") # check all hidden_states shape @@ -471,8 +549,11 @@ def execute_quantization(self, means=None, stds=None, model_path=None): tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - # if we do not apply layer-wise feature, we still place the entire block on the GPU - transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + if not self.layer_wise: + # if we do not apply layer-wise feature, we still place the entire block on the GPU + transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + else: + transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -495,10 +576,21 @@ def execute_quantization(self, means=None, stds=None, model_path=None): # ) full_layer_name = self.get_full_layer_name(layer_name, block_idx) weight_config_this_layer = self.get_layer_config(full_layer_name) - W = sub_layers[layer_name].weight.data.clone() + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() - gptq_for_this_block[layer_name].quantizer.configure(weight_config_this_layer) + gptq_for_this_block[layer_name].quantizer.configure( + weight_config_this_layer["wbits"], + weight_config_this_layer["perchannel"], + weight_config_this_layer["sym"], + weight_config_this_layer["mse"], + ) # Step 2.3: modify forward functions to hook inputs data (used in gptq execution) def add_batch(_name): @@ -511,7 +603,7 @@ def tmp(_, inp, out): for layer_name in sub_layers: handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) idx = self.cache_key_arguments.pop("i") - for j in range(self.dataloader_len): + for j in range(len(self.dataloader)): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) out = transformer_block(*cache_positional_batch, **cache_keyword_batch) @@ -526,7 +618,13 @@ def tmp(_, inp, out): # ) weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) logger.info(f"Quantizing layer {layer_name}") - W = sub_layers[layer_name].weight.data.clone() + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( W, 
blocksize=weight_config_this_layer["block_size"], @@ -534,7 +632,30 @@ def tmp(_, inp, out): groupsize=weight_config_this_layer["group_size"], act_order=weight_config_this_layer["act_order"], ) - sub_layers[layer_name].weight.data = Q + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import ( + LWQ_WORKSPACE, + clean_module_weight, + load_value, + set_module_tensor_to_device, + ) + + sub_layer = sub_layers[layer_name] + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + for n, p in sub_layer.named_parameters(): + param_name = full_layer_name + "." + n + if n == "weight": + set_module_tensor_to_device(self.model, param_name, self.device, Q) + else: + value = load_value(self.model, param_name, model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + # sub_layer.weight.data = Q + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") + clean_module_weight(sub_layer) + del Q + gc.collect() + else: + sub_layers[layer_name].weight.data = Q gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} if not weight_config_this_layer["sym"]: gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp @@ -547,14 +668,17 @@ def tmp(_, inp, out): # Step 2.5: replace output data with quantized weights outs = [] idx = self.cache_key_arguments.pop("i") - for j in range(self.dataloader_len): + for j in range(len(self.dataloader)): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) out = transformer_block(*cache_positional_batch, **cache_keyword_batch) out = self.track_hidden_states(out) outs.append(out) self.cache_key_arguments["i"] = idx - self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + if self.layer_wise: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block + else: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() del gptq_for_this_block torch.cuda.empty_cache() # iteratively replace the input with output, thus layerwise quantization can continue. 
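Before the next hunk: the calibration capture above relies on a standard forward-hook pattern (add_batch plus handles). A self-contained toy version of that pattern, with illustrative names rather than INC APIs:

import torch

layer = torch.nn.Linear(8, 8)
captured = []

def add_batch(module, inputs, output):
    captured.append(inputs[0].detach())                # record the layer's input batch

handle = layer.register_forward_hook(add_batch)
layer(torch.randn(2, 8))                               # forward pass populates `captured`
handle.remove()                                        # drop the hook after calibration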
@@ -682,9 +806,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F scale.append(self.quantizer.scale) zero.append(self.quantizer.zero) - q = self.quantizer.quantize( - w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq - ).flatten() + q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten() Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d**2 @@ -727,6 +849,9 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F return scale, zero, Q def free(self): + if DEBUG: + self.inp1 = None + self.out1 = None self.H = None self.Losses = None self.Trace = None @@ -740,13 +865,11 @@ def __init__(self, shape=1): self.register_buffer("scale", torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape)) - def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False): - for k, v in weight_config_this_layer.items(): - setattr(self, k, v) - self.maxq = torch.tensor(2**self.wbits - 1) - self.scheme = "sym" if self.sym else "asym" - self.double_quant = self.double_quant_dtype != "fp32" - self.double_quant_scheme = "sym" if self.double_quant_sym else "asym" + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8, trits=False): + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse self.norm = norm self.grid = grid self.maxshrink = maxshrink @@ -756,30 +879,7 @@ def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, def find_params(self, x, weight=False): dev = x.device self.maxq = self.maxq.to(dev) - # NF4 FP4 - if self.wdtype != "int": - from .rtn import quant_weight - - _, scale, zero = quant_weight( - x, - self.wbits, - self.group_size, - scheme=self.scheme, - data_type=self.wdtype, - quantile=1.0, - return_int=True, - full_range=False, - double_quant=self.double_quant, - double_quant_dtype=self.double_quant_dtype, - double_quant_bits=self.double_quant_bits, - double_quant_scheme=self.double_quant_scheme, - double_quant_group_size=self.double_quant_group_size, - double_quant_return_int=False, - ) - self.scale = scale - self.zero = torch.zeros_like(scale) - return - # INT + shape = x.shape if self.perchannel: if weight: @@ -826,7 +926,7 @@ def find_params(self, x, weight=False): xmax1 = p * xmax scale1 = (xmax1 - xmin1) / self.maxq zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = self.quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) q -= x q.abs_() q.pow_(self.norm) @@ -848,23 +948,6 @@ def find_params(self, x, weight=False): shape = [-1] + [1] * (len(shape) - 1) self.scale = self.scale.reshape(shape) self.zero = self.zero.reshape(shape) - - if self.double_quant: - from .rtn import quant_weight - - orig_scale_shape = self.scale.shape - self.scale = self.scale.reshape(1, -1) - self.scale = quant_weight( - self.scale, - self.double_quant_bits, - self.double_quant_group_size, - scheme=self.double_quant_scheme, - data_type=self.double_quant_dtype, - quantile=1.0, - return_int=False, - full_range=False, - ) - self.scale = self.scale.reshape(orig_scale_shape) return if len(shape) == 4: self.scale = self.scale.reshape((1, -1, 1, 1)) @@ -876,247 +959,39 @@ def find_params(self, x, weight=False): self.scale = self.scale.unsqueeze(0) self.zero = self.zero.unsqueeze(0) - def quantize(self, x, scale, zero, maxq): - """Do 
quantization.""" - if self.wdtype != "int": - from .rtn import quantize_4bit + # def quantize(self, x): + # if self.ready(): + # return quantize(x, self.scale, self.zero, self.maxq) + # return x - return quantize_4bit(x, data_type=self.wdtype, scale=scale) - else: - if maxq < 0: - return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero - q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) - return scale * (q - zero) + # def enabled(self): + # return self.maxq > 0 def ready(self): return torch.all(self.scale != 0) -# TODO (Yi) remove it after unifying the algo config parser -from typing import Callable, Dict, Tuple - -from neural_compressor.torch.quantization.config import GPTQConfig - - -def gptq_config_mapping(configs_mapping: Dict[Tuple[str, Callable], GPTQConfig]): - # convert GPTQ_CONFIG to gptq_quantize's weight config - # convert tune_cfg to gptq_quantize's weight config - # for layer_wise quant mode - # TODO (Yi) uncomment it when port layer-wise - # if recipe_cfgs.get("layer_wise_quant", False): - # layer_wise = True - # from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks - - # os.makedirs(LWQ_WORKSPACE, exist_ok=True) - # # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) - # model_path = model.path - # assert model_path, "model_path should not be None." - # model_path = _get_path(model_path) - # lwq_handles = register_weight_hooks( - # model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE - # ) - - weight_config = {} - for (op_name, op_type), op_config in configs_mapping.items(): - if op_config.weight_dtype == "fp32": - continue - else: - weight_config[op_name] = { - "wdtype": op_config.weight_dtype, - "wbits": op_config.weight_bits, - "group_size": op_config.weight_group_size, - "sym": op_config.weight_sym, - "percdamp": op_config.percdamp, - "act_order": op_config.act_order, - "block_size": op_config.block_size, - "mse": op_config.enable_mse_search, - "double_quant_dtype": op_config.double_quant_dtype, - "double_quant_bits": op_config.double_quant_bits, - "double_quant_group_size": op_config.double_quant_group_size, - "double_quant_sym": op_config.double_quant_sym, - } - nsamples = op_config.nsamples - dataloader_len = op_config.dataloader_len - use_max_length = op_config.use_max_length - pad_max_length = op_config.pad_max_length - device = op_config.device - - if use_max_length and op_config.pad_max_length == 2048: - logger.warning( - "You choose to use unified sequence length for calibration, \ - but you have not set length value. Default sequence length is 2048 and this might cause inference error!" 
- ) - - return weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len - - -def apply_gptq_quantize(model, configs_mapping, *args, **kwargs): - """Apply gptq.""" +def gptq_quantize( + model, + weight_config={}, + dataloader=None, + nsamples=128, + use_max_length=True, + pad_max_length=2048, + device=None, + layer_wise=False, + model_path=None, +): + """Run weight-only quantization with.""" # TODO: unify weight_config keys, add docstring, and support default config - weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len = gptq_config_mapping( - configs_mapping - ) assert isinstance(model, torch.nn.Module), "only support torch module" - # TODO (Yi) disable layer-wise and model_path first - layer_wise = False - model_path = None + if layer_wise: + assert model_path is not None, "model_path should not be None when use layer_wise mode" + from .gptq import GPTQuantizer gptq_quantizer = GPTQuantizer( - model, - weight_config, - nsamples, - dataloader_len, - use_max_length, - pad_max_length, - device, - layer_wise=layer_wise, - *args, - **kwargs, + model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise ) fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) - logger.info("GPTQ quantization done.") + logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config - - -class DataloaderPreprocessor: - def __init__(self, dataloader_original, use_max_length=False, pad_max_length=2048, nsamples=128) -> None: - self.dataloader_original = dataloader_original - self.use_max_length = use_max_length - self.pad_max_length = pad_max_length - self.nsamples = nsamples - self.dataloader = [] - self.is_ready = False - - def get_prepared_dataloader(self): - if not self.is_ready: - self.prepare_dataloader() - return self.dataloader - - def prepare_dataloader(self): - if self.use_max_length: - # (Recommend) only take sequence whose length exceeds self.pad_max_length, - # which preserves calibration's tokens are all valid - # This is GPTQ official dataloader implementation - self.obtain_first_n_samples_fulllength() - else: - # general selection, no padding, not GPTQ original implementation. - self.obtain_first_n_samples() - self.is_ready = True - - def obtain_first_n_samples(self, seed=0): - """Get first nsample data as the real calibration dataset.""" - self.dataloader.clear() - random.seed(seed) - for batch in self.dataloader_original: - # process data, depends on its data type. 
- if len(self.dataloader) == self.nsamples: - logger.info(f"Successfully collect {self.nsamples} calibration samples.") - break - # list, tuple - if isinstance(batch, list) or isinstance(batch, tuple): - if batch[0].shape[-1] > self.pad_max_length: - i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1) - j = i + self.pad_max_length - batch_final = [] - for item in batch: - if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: - batch_final.append(item[:, i:j]) - else: - batch_final.append(item) - else: - batch_final = batch[:] - # dict - elif isinstance(batch, dict): - try: - length = batch["input_ids"].shape[-1] - except: - logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") - continue - batch_final = {} - if length > self.pad_max_length: - i = random.randint(0, length - self.pad_max_length - 1) - j = i + self.pad_max_length - # may have to slice every sequence related data - for key in batch.keys(): - if isinstance(batch[key], torch.Tensor): - batch_final[key] = batch[key][:, i:j] # slice on sequence length dim - else: - batch_final[key] = batch[key] - else: - batch_final = batch - # tensor - else: - if batch.shape[-1] > self.pad_max_length: - i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1) - j = i + self.pad_max_length - batch_final = batch[:, i:j] - else: - batch_final = batch - self.dataloader.append(batch_final) - - if len(self.dataloader) < self.nsamples: - logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.") - - def obtain_first_n_samples_fulllength(self, seed=0): - self.dataloader.clear() - random.seed(seed) - unified_length = self.pad_max_length - for batch in self.dataloader_original: - if len(self.dataloader) == self.nsamples: - logger.info(f"Successfully collect {self.nsamples} calibration samples.") - break - # list & tuple, gpt-j-6b mlperf, etc. 
- if isinstance(batch, list) or isinstance(batch, tuple): - if batch[0].shape[-1] == unified_length: - batch_final = batch[:] - elif batch[0].shape[-1] > unified_length: - i = random.randint(0, batch[0].shape[-1] - unified_length - 1) - j = i + unified_length - batch_final = [] - for item in batch: - if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: - batch_final.append(item[:, i:j]) - else: - batch_final.append(item) - else: - # not match max length, not include in target dataset - continue - # dict - elif isinstance(batch, dict): - try: - length = batch["input_ids"].shape[-1] - except: - logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") - continue - batch_final = {} - if length == self.pad_max_length: - batch_final = batch - elif length > self.pad_max_length: - i = random.randint(0, length - self.pad_max_length - 1) - j = i + self.pad_max_length - # may have to slice every sequence related data - for key in batch.keys(): - if isinstance(batch[key], torch.Tensor): - batch_final[key] = batch[key][:, i:j] # slice on sequence length dim with same position - else: - batch_final[key] = batch[key] - else: - # not match max length, not include in target dataset - continue - # tensor - else: - if batch.shape[-1] == unified_length: - batch_final = batch - elif batch.shape[-1] > unified_length: - i = random.randint(0, batch.shape[-1] - unified_length - 1) - j = i + unified_length - batch_final = batch[:, i:j] - else: - # not match max length, not include in target dataset - continue - self.dataloader.append(batch_final) - if len(self.dataloader) < self.nsamples: # pragma: no cover - logger.warning( - f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \ - but only {len(self.dataloader)} samples are found. Please use smaller 'self.pad_max_length' value." - ) diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py index 1f5949946c3..0b875882927 100644 --- a/neural_compressor/torch/algorithms/weight_only/rtn.py +++ b/neural_compressor/torch/algorithms/weight_only/rtn.py @@ -22,404 +22,34 @@ import torch from torch.nn import functional as F -from neural_compressor.common.logger import DEBUG, Logger, level -from neural_compressor.torch.utils.utility import set_module +from neural_compressor.torch.utils import logger, set_module -logger = Logger().get_logger() - - -NF4 = [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, -] -FP4_BNB = [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0] -FP4_E2M1 = [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.0625, 0, 0.0625, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] - -# the order is the same as float list, bit value range is [-7, 7] -# 1111 = -1, 1110 = -2, 1101= -3, ... 
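Aside on the NF4/FP4 constants being removed here (presumably relocated to the new weight_only utility module alongside quant_tensor): lookup quantization snaps each scaled value to the nearest entry of a fixed 16-value codebook. A compact toy version, not the INC implementation:

import torch

NF4 = torch.tensor([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
                    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0])

def nf4_qdq(weight):
    scale = weight.abs().amax(dim=1, keepdim=True)             # map each row into [-1, 1]
    idx = ((weight / scale).unsqueeze(-1) - NF4).abs().argmin(dim=-1)
    return NF4[idx] * scale                                     # fake-quantized weight

w_qdq = nf4_qdq(torch.randn(4, 8))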
- -NF4_BIT = [7, 1, 2, 3, 4, 5, 6, 0, -8, -7, -6, -5, -4, -3, -2, -1] -FP4_BNB_BIT = [-5, -6, -3, -4, -1, -2, -7, 0, 1, 6, 7, 4, 5, 2, 3] -FP4_E2M1_BIT = [-1, -2, -3, -4, -5, -6, -7, 0, 1, 2, 3, 4, 5, 6, 7] - -FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1} -INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT} - - -def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False, **kwargs): - """Quantize tensor to NF4/FP4 data type. - - Args: - tensor: input tensor - quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): data type. Defaults to 'nf4'. - return_int (bool, optional): whether return int data. Defaults to False. - - Returns: - q_tensor: fake quantized tensor - """ - assert data_type in FLOAT_MAPPING, "unexpected data type." - allow_data = FLOAT_MAPPING[data_type] - allow_data_bit = INT_MAPPING[data_type] - # get scale and update tensor - if "scale" in kwargs: - scale = kwargs["scale"] - else: - scale = tensor.abs().max(1)[0] * quantile / max(allow_data) - scale.unsqueeze_(dim=-1) - tensor = tensor / scale - mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)] - q_tensor = torch.zeros_like(tensor) - for i in range(len(allow_data)): - data = allow_data_bit[i] if return_int else allow_data[i] - if i == 0: - q_tensor += torch.where(tensor <= mid_data[i], data, 0) - elif i == len(allow_data) - 1: - q_tensor += torch.where(tensor > mid_data[i - 1], data, 0) - else: - q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0) - double_quant = kwargs.get("double_quant", False) - if return_int or double_quant: - return q_tensor, scale, None - return q_tensor * scale - - -def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False, **kwargs): - """Quant and dequant tensor with asym schema. - - Args: - weight: input weight - num_bits (int, optional): num_bits. Defaults to 4. - quantile (float, optional): percentile of clip. Defaults to 1.0. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - - Returns: - output: qdq weight - """ - maxq = torch.tensor(2**num_bits - 1) - zeros = torch.zeros(weight.shape[0], device=weight.device) - wmin = torch.minimum(weight.min(1)[0], zeros) - wmax = torch.maximum(weight.max(1)[0], zeros) - wmin = wmin * quantile - wmax = wmax * quantile - tmp = (wmin == 0) & (wmax == 0) - wmin[tmp] = -1 - wmax[tmp] = +1 - scale = (wmax - wmin) / maxq - zp = torch.round(-wmin / scale) - scale.unsqueeze_(dim=-1) - zp.unsqueeze_(dim=-1) - q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq) - double_quant = kwargs.get("double_quant", False) - if return_int or double_quant: - return q, scale, zp - return scale * (q - zp) - - -def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False, **kwargs): - """Quant and dequant tensor with sym schema. - - Args: - weight : input weight - num_bits (int, optional): num_bits. Defaults to 4. - quantile (float, optional): percentile of clip. Defaults to 1.0. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - full_range (bool, optional): Choose sym range whether use -2**(bits-1). - For example: 4 bit - scale = amax / 8 if full_range else amax / 7 - If True, scale = -scale if abs(min)> abs(max) else scale - Defaults to False. 
- - Returns: - output: qdq weight - """ - # assert num_bits > 1, "symmetric scheme only supports num_bits > 1" - maxq = torch.tensor(2 ** (num_bits - 1) - 1).to(weight.device) - minq = torch.tensor(-(2 ** (num_bits - 1))).to(weight.device) - if num_bits == 1: # pragma: no cover - maxq = torch.tensor(2 ** (num_bits - 1)) - minq = torch.tensor(2 ** (num_bits - 1) - 1) - max_val = torch.max(weight, 1)[0] - min_val = torch.min(weight, 1)[0] - flip_flag = torch.abs(max_val) > torch.abs(min_val) - wmax = torch.max(torch.abs(max_val), torch.abs(min_val)) - wmax = wmax * quantile - tmp = wmax == 0 - wmax[tmp] = +1 - if full_range: - # use -8, 8 to make sure amax is not changed after fake quant - scale = wmax / (-minq) - tmp = scale * flip_flag.int() - scale -= 2 * tmp # set negetive scale with flip_flag - else: - scale = wmax / maxq - scale.unsqueeze_(dim=-1) - q = torch.clamp(torch.round(weight / scale), minq, maxq) - double_quant = kwargs.get("double_quant", False) - if return_int or double_quant: - return q, scale, None - return scale * q - - -def qdq_weight_actor( - weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False, **kwargs -): - """Quant and dequant tensor per channel. - - Args: - weight : input weight - num_bits (int, optional): num_bits. Defaults to 4. - quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - full_range (bool, optional): Choose sym range whether use -2**(bits-1). - - Returns: - output: qdq weight - """ - assert num_bits > 0, "num_bits should be larger than 0" - - if data_type in FLOAT_MAPPING.keys(): - return quantize_4bit(weight, quantile=quantile, data_type=data_type, return_int=return_int, **kwargs) - if scheme == "sym": - return qdq_weight_sym(weight, num_bits, quantile, return_int, full_range, **kwargs) - else: - return qdq_weight_asym(weight, num_bits, quantile, return_int, **kwargs) - - -def quant_weight( - weight, - num_bits=4, - group_size=-1, - scheme="asym", - quantile=1.0, - data_type="int", - return_int=False, - full_range=False, - **kwargs, -): - """Quant and dequant tensor with group size. - - Args: - weight: input weight - num_bits (int, optional): num_bits. Defaults to 4. - group_size (int, optional): how many elements share one scale/zp. Defaults to -1. - scheme (str, optional): sym or asym. Defaults to "asym". - quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - full_range (bool, optional): Choose sym range whether use -2**(bits-1). - - Returns: - output: qdq weight. 
- """ - double_quant = kwargs.get("double_quant", False) - if num_bits <= 0: # pragma: no cover - return weight - # case 1, group size = -1 - if group_size == -1 or weight.shape[1] < group_size: - group_size = weight.shape[1] - # case 2, reshape based on group size - orig_shape = weight.shape - if weight.shape[1] % group_size == 0: - weight = weight.reshape(-1, group_size) - weight = qdq_weight_actor( - weight, - num_bits, - scheme=scheme, - quantile=quantile, - return_int=return_int, - full_range=full_range, - data_type=data_type, - **kwargs, - ) - if return_int or double_quant: - weight, scale, zp = weight - weight = weight.reshape(orig_shape) - scale = scale.reshape(orig_shape[0], -1) - if zp is not None: - zp = zp.reshape(orig_shape[0], -1) - q_state = weight, scale, zp - else: - return weight.reshape(orig_shape) - else: - # case 3, process left part split by group size - split_index = weight.shape[1] // group_size * group_size - weight1 = weight[:, :split_index] - weight1 = weight1.reshape(-1, group_size) - weight1 = qdq_weight_actor( - weight1, - num_bits, - scheme=scheme, - quantile=quantile, - return_int=return_int, - full_range=full_range, - data_type=data_type, - **kwargs, - ) - if return_int or double_quant: - weight1, scale1, zp1 = weight1 - scale1 = scale1.reshape(orig_shape[0], -1) - if zp1 is not None: - zp1 = zp1.reshape(orig_shape[0], -1) - weight1 = weight1.reshape(orig_shape[0], split_index) - weight2 = weight[:, split_index:] - weight2 = qdq_weight_actor( - weight2, - num_bits, - scheme=scheme, - data_type=data_type, - quantile=quantile, - return_int=return_int, - full_range=full_range, - **kwargs, - ) - if return_int or double_quant: - weight2, scale2, zp2 = weight2 - weight = torch.cat([weight1, weight2], dim=1) - scale = torch.cat([scale1, scale2], dim=1) - zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1) - q_state = (weight, scale, zp) - else: - weight = torch.cat([weight1, weight2], dim=1) - return weight - if double_quant: - weight, scale, zp = q_state - double_quant_dtype = kwargs.get("double_quant_dtype", "fp32") - double_quant_num_bits = kwargs.get("double_quant_num_bits", 8) - double_quant_scheme = kwargs.get("double_quant_scheme", "sym") - double_quant_group_size = kwargs.get("double_quant_group_size", 256) - double_quant_return_int = kwargs.get("double_quant_return_int", return_int) - # process scale - orig_scale_shape = scale.shape - scale = scale.reshape(1, -1) - scale = quant_weight( - scale, - double_quant_num_bits, - double_quant_group_size, - scheme=double_quant_scheme, - quantile=1.0, - data_type=double_quant_dtype, - return_int=double_quant_return_int, - full_range=False, - double_quant=False, - ) - if return_int: - if double_quant_return_int: - scale, hyper_scale, hyper_zp = scale - scale = scale.reshape(orig_scale_shape) - return weight, (scale, hyper_scale, hyper_zp), zp - else: - scale = scale.reshape(orig_scale_shape) - return weight, scale, zp - else: - scale = scale.reshape(orig_scale_shape) - if weight.shape[1] % group_size != 0: - if zp is not None: - weight1 = weight1.reshape(-1, group_size) - zp[:, :-1].reshape(-1, 1) - weight2 = weight2 - zp[:, -1].reshape(-1, 1) - else: - weight1 = weight1.reshape(-1, group_size) - weight1 = weight1 * scale[:, :-1].reshape(-1, 1) - weight1 = weight1.reshape(orig_shape[0], -1) - weight2 = weight2 * scale[:, -1].reshape(-1, 1) - weight = torch.cat([weight1, weight2], dim=1) - else: - if zp is not None: - weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1) - else: - weight = 
weight.reshape(-1, group_size) - weight = weight * scale.reshape(-1, 1) - weight = weight.reshape(orig_shape[0], -1) - return weight - else: - return q_state - - -def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False): - """Search best clip range of each linears in current block. - - Args: - m (torch.nn.Module): torch module. - num_bits (int, optional): num bits. - group_size (int, optional): how many elements share one scale/zp. - scheme (str, optional): sym or asym. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. - enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1). - - Returns: - best_clip_ratio (float): best percentile of clip - """ - org_weight = m.weight.data - logger.info("Searching the best clip range with RTN algorithm") - best_error = float("inf") - best_clip_ratio = None - n_grid = 200 - max_shrink = 0.2 - history = [] - for i_s in range(int(max_shrink * n_grid)): - ratio = 1 - i_s / n_grid # 1, 0.805-1.0 - cur_weight = quant_weight( - m.weight.data, - num_bits=num_bits, - group_size=group_size, - scheme=scheme, - data_type=data_type, - full_range=enable_full_range, - quantile=ratio, - ) - loss = (org_weight - cur_weight).float().pow(2).mean().item() - history.append(loss) - is_best = loss < best_error - if is_best: - best_error = loss - best_clip_ratio = ratio - logger.debug("The loss history of different clip range:{}".format(history)) - logger.debug("The best clip ratio is {}".format(best_clip_ratio)) - return best_clip_ratio +from .utility import quant_tensor, search_clip def rtn_quantize( model, - num_bits=4, + dtype="int", + bits=4, + scheme="sym", group_size=32, - scheme="asym", + group_dim=1, quantile=1.0, weight_config={}, - return_int=False, - data_type="int", - enable_full_range=False, - enable_mse_search=False, - group_dim=1, + export_compressed_model=False, + use_full_range=False, + use_mse_search=False, **kwargs, ): - """Quant the model with round to nearst method. + """Quant the model with round to nearest method and inplace is True. Args: model: torch module - num_bits: num bits. Defaults to 4. + bits: num bits. Defaults to 4. group_size (int, optional): how many elements share one scale/zp. Defaults to 32. - scheme (str, optional): sym or asym. Defaults to "asym". + scheme (str, optional): sym or asym. Defaults to "sym". quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. weight_config (dict, optional): specific layer wise configurations. Defaults to {}. For example, weight_config={ @@ -429,14 +59,13 @@ def rtn_quantize( 'bits': 4, 'group_size': 32, 'scheme': 'sym' - 'gptq_perm': [1, 1, ...] # for gptq perm } } - return_int (bool, optional): Choose return fp32 or int32 model. + export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False. - enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1). + use_full_range (bool, optional): Choose sym range whether use -2**(bits-1). Defaults to False. - enable_mse_search (bool, optional): Whether search clip range. + use_mse_search (bool, optional): Whether search clip range. Defaults to True. group_dim (int, optional): 0 means splitting output channel, 1 means splitting input channel. Defaults to 1. 
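Taken together with the renamed keywords above, the refactored entry point can be exercised on a toy module roughly as follows. This is a minimal sketch, assuming rtn_quantize is re-exported from neural_compressor.torch.algorithms (the new quantization_impl.py later in this patch imports it from there); the toy model and shapes are illustrative only.

    import torch

    from neural_compressor.torch.algorithms import rtn_quantize  # assumed re-export

    # two Linear layers whose input channels divide the default group size
    toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))
    # fake-quantize every Linear weight in place, using the renamed arguments
    toy = rtn_quantize(toy, dtype="int", bits=4, scheme="sym", group_size=32)

The same call with export_compressed_model=True would instead replace each Linear with a packed WeightOnlyLinear, as the body of the function in the next hunk shows.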
@@ -444,73 +73,85 @@ def rtn_quantize( Returns: model: fake quantized torch module """ + device = "cpu" assert isinstance(model, torch.nn.Module), "only support torch module" supported_layers = ["Linear"] + # initialize global configuration double_quant_dtype = kwargs.get("double_quant_dtype", "fp32") double_quant_config = { "double_quant": False if double_quant_dtype == "fp32" else True, "double_quant_dtype": double_quant_dtype, - "double_quant_num_bits": kwargs.get("double_quant_num_bits", 8), + "double_quant_bits": kwargs.get("double_quant_bits", 8), "double_quant_scheme": kwargs.get("double_quant_scheme", "sym"), "double_quant_group_size": kwargs.get("double_quant_group_size", 256), } - if return_int: - compression_dtype = kwargs.get("compression_dtype", torch.int32) - compression_dim = kwargs.get("compression_dim", 1) - scale_dtype = kwargs.get("scale_dtype", torch.float32) - device = kwargs.get("device", "cpu") + if export_compressed_model: + use_optimum_format = kwargs.get("use_optimum_format", True) for name, m in model.named_modules(): if m.__class__.__name__ not in supported_layers: continue if name in weight_config: # pragma: no cover - data_type = weight_config[name].get("dtype", "int") - num_bits = weight_config[name]["bits"] + # initialize op configuration + dtype = weight_config[name].get("dtype", "int") + bits = weight_config[name].get("bits", 4) group_size = weight_config[name]["group_size"] scheme = weight_config[name]["scheme"] quantile = weight_config[name].get("quantile", 1.0) + group_dim = weight_config[name]["group_dim"] + use_full_range = weight_config[name]["use_full_range"] + use_mse_search = weight_config[name]["use_mse_search"] + use_layer_wise = weight_config[name]["use_layer_wise"] + export_compressed_model = weight_config[name]["export_compressed_model"] + if export_compressed_model: + use_optimum_format = kwargs.get("use_optimum_format", True) + double_quant_dtype = weight_config[name]["double_quant_dtype"] + double_quant_config = { + "double_quant": False if double_quant_dtype == "fp32" else True, + "double_quant_dtype": double_quant_dtype, + "double_quant_bits": weight_config[name]["double_quant_bits"], + "double_quant_scheme": weight_config[name]["double_quant_scheme"], + "double_quant_group_size": weight_config[name]["double_quant_group_size"], + } log_msg = ( - f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " - + f"scheme={scheme}, quantile={quantile}" + f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}" ) - if data_type != "int": - log_msg += f", dtype={data_type}" + if dtype != "int": + log_msg += f", dtype={dtype}" elif scheme == "sym": # nf4/fp4 is always [-7,7] - log_msg += f", enable_full_range={enable_full_range}" - if data_type == "fp32": + log_msg += f", use_full_range={use_full_range}" + if dtype == "fp32": continue logger.debug(f"RTN quantized module:{name, m}") logger.debug(log_msg) - weight = m.weight.T if group_dim == 0 else m.weight - if enable_mse_search: - quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range) - if return_int: - int_weight, scale, zp = quant_weight( + weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight + if use_mse_search: + quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range) + if export_compressed_model: + int_weight, scale, zp = quant_tensor( weight, - num_bits, + dtype, + bits, group_size, scheme, quantile, - data_type=data_type, return_int=True, - 
full_range=enable_full_range, + full_range=use_full_range, **double_quant_config, ) - int_weight = int_weight.T if group_dim == 0 else int_weight - scale = scale.T if group_dim == 0 else scale - zp = zp.T if group_dim == 0 and zp is not None else zp + int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight + scale = scale.t_().contiguous() if group_dim == 0 else scale + zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp from neural_compressor.torch.quantization.layers import WeightOnlyLinear new_module = WeightOnlyLinear( m.in_features, m.out_features, - num_bits, + bits, group_size, - dtype=data_type, + dtype=dtype, zp=zp is not None, bias=m.bias is not None, - compression_dtype=compression_dtype, - compression_dim=compression_dim, - scale_dtype=scale_dtype, + use_optimum_format=use_optimum_format, device=device, ) new_module.pack(int_weight, scale, zp, m.bias) @@ -519,96 +160,16 @@ def rtn_quantize( else: set_module(model, name, new_module) else: - q_weight = quant_weight( + weight = quant_tensor( weight, - num_bits, + dtype, + bits, group_size, scheme, quantile, - data_type=data_type, - full_range=enable_full_range, + full_range=use_full_range, **double_quant_config, ) - q_weight = q_weight.T if group_dim == 0 else q_weight - m.weight.data.copy_(q_weight) + weight = weight.t_().contiguous() if group_dim == 0 else weight + m.weight.data.copy_(weight) return model - - -def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"): - """Quant and dequant tensor with group size. - - Args: - weight: input weight - scale: scale - zp: zero point - group_size (int, optional): how many elements share one scale/zp. Defaults to -1. - dtype: data type, for NF4 FP4 - - Returns: - output: int weight. - """ - device = weight.device - scale = scale.to(device) - # NF4 FP4 - if dtype in FLOAT_MAPPING.keys(): - int_weight = quantize_4bit( - weight, - quantile=1.0, - data_type=dtype, - return_int=True, - scale=scale, - )[0] - return int_weight - # INT - if zp is not None: - zp = zp.to(device) - if group_size == -1: - return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp) - int_weight = torch.zeros(weight.shape).to(device) - leng = weight.shape[1] // group_size - tail_flag = False if weight.shape[1] % group_size == 0 else True - for i in range(leng): - int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1) - if zp is not None: - int_weight_tmp += zp[:, i].unsqueeze(1) - int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp) - if tail_flag: - int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1) - if zp is not None: - int_weight_tmp += zp[:, -1].unsqueeze(1) - int_weight[:, leng * group_size :] = torch.round(int_weight_tmp) - return int_weight - - -from neural_compressor.torch.quantization.config import RTNWeightQuantConfig - - -def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: - # TODO (Yi) remove it - enable_full_range = quant_config.enable_full_range - enable_mse_search = quant_config.enable_mse_search - group_dim = quant_config.group_dim - dtype = quant_config.weight_dtype - num_bits = quant_config.weight_bits - scheme = "sym" if quant_config.weight_sym else "asym" - group_size = quant_config.weight_group_size - return_int = quant_config.return_int - double_quant_dtype = quant_config.double_quant_dtype - double_quant_num_bits = quant_config.double_quant_bits - 
double_quant_scheme = "sym" if quant_config.double_quant_sym else "asym" - double_quant_group_size = quant_config.double_quant_group_size - return rtn_quantize( - module, - num_bits, - group_size, - scheme, - return_int=return_int, - data_type=dtype, - enable_full_range=enable_full_range, - enable_mse_search=enable_mse_search, - group_dim=group_dim, - double_quant_dtype=double_quant_dtype, - double_quant_scheme=double_quant_scheme, - double_quant_num_bits=double_quant_num_bits, - double_quant_group_size=double_quant_group_size, - ) diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py new file mode 100644 index 00000000000..66ac2cad8ec --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -0,0 +1,440 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F + +from neural_compressor.common.logger import DEBUG, Logger, level +from neural_compressor.torch.utils.utility import set_module + +logger = Logger().get_logger() + + +NF4 = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, +] +FP4_BNB = [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0] +FP4_E2M1 = [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.0625, 0, 0.0625, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + +# the order is the same as float list, bit value range is [-7, 7] +# 1111 = -1, 1110 = -2, 1101= -3, ... + +NF4_BIT = [7, 1, 2, 3, 4, 5, 6, 0, -8, -7, -6, -5, -4, -3, -2, -1] +FP4_BNB_BIT = [-5, -6, -3, -4, -1, -2, -7, 0, 1, 6, 7, 4, 5, 2, 3] +FP4_E2M1_BIT = [-1, -2, -3, -4, -5, -6, -7, 0, 1, 2, 3, 4, 5, 6, 7] + +FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1} +INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT} + + +def quantize_4bit(tensor, quantile=1.0, dtype="nf4", return_int=False, **kwargs): + """Quantize tensor to NF4/FP4 data type. + + Args: + tensor: input tensor + quantile (float, optional): percentile of clip. Defaults to 1.0. + dtype (str, optional): data type. Defaults to 'nf4'. + return_int (bool, optional): whether return int data. Defaults to False. + + Returns: + q_tensor: fake quantized tensor + """ + assert dtype in FLOAT_MAPPING, "unexpected data type." 
+ allow_data = FLOAT_MAPPING[dtype] + allow_data_bit = INT_MAPPING[dtype] + # get scale and update tensor + if "scale" in kwargs: + scale = kwargs["scale"] + else: + scale = tensor.abs().max(1)[0] * quantile / max(allow_data) + scale.unsqueeze_(dim=-1) + tensor.div_(scale) + mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)] + q_tensor = torch.zeros_like(tensor) + for i in range(len(allow_data)): + data = allow_data_bit[i] if return_int else allow_data[i] + if i == 0: + q_tensor += torch.where(tensor <= mid_data[i], data, 0) + elif i == len(allow_data) - 1: + q_tensor += torch.where(tensor > mid_data[i - 1], data, 0) + else: + q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0) + tensor.copy_(q_tensor) + double_quant = kwargs.get("double_quant", False) + if return_int or double_quant: + return tensor, scale, None + return tensor.mul_(scale) + + +def qdq_weight_asym(weight, bits=4, quantile=1.0, return_int=False, **kwargs): + """Quant and dequant tensor with asym schema. + + Args: + weight: input weight + bits (int, optional): bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + + Returns: + output: qdq weight + """ + maxq = torch.tensor(2**bits - 1) + zeros = torch.zeros(weight.shape[0], device=weight.device) + wmin = torch.minimum(weight.min(1)[0], zeros) + wmax = torch.maximum(weight.max(1)[0], zeros) + wmin = wmin * quantile + wmax = wmax * quantile + tmp = (wmin == 0) & (wmax == 0) + wmin[tmp] = -1 + wmax[tmp] = +1 + scale = (wmax - wmin) / maxq + zp = torch.round(-wmin / scale) + scale.unsqueeze_(dim=-1) + zp.unsqueeze_(dim=-1) + weight.div_(scale) + weight.round_() + weight.clamp_(0, maxq) + double_quant = kwargs.get("double_quant", False) + if return_int or double_quant: + return weight, scale, zp + weight.sub_(zp) + return weight.mul_(scale) + + +def qdq_weight_sym(weight, bits=4, quantile=1.0, return_int=False, full_range=False, **kwargs): + """Quant and dequant tensor with sym schema. + + Args: + weight : input weight + bits (int, optional): bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + For example: 4 bit + scale = amax / 8 if full_range else amax / 7 + If True, scale = -scale if abs(min)> abs(max) else scale + Defaults to False. 
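Before moving on to the integer paths, the lookup-table mapping in quantize_4bit can be cross-checked by hand: after dividing by the per-row scale, every element should land on the nearest NF4 grid point, because the mid_data thresholds are the midpoints between adjacent codes. The snippet below is a standalone re-derivation written from this hunk rather than a documented API; it imports the NF4 table and quantize_4bit from the new utility module added here.

    import torch

    from neural_compressor.torch.algorithms.weight_only.utility import NF4, quantize_4bit

    w = torch.randn(8, 32)
    q = quantize_4bit(w.clone(), dtype="nf4")  # fake-quantized copy (the call mutates its argument)

    codebook = torch.tensor(NF4)
    scale = w.abs().max(1, keepdim=True)[0] / codebook.abs().max()        # same scale rule, quantile=1.0
    nearest = codebook[(w / scale).unsqueeze(-1).sub(codebook).abs().argmin(dim=-1)]
    assert torch.allclose(q, nearest * scale, atol=1e-5)                  # midpoint thresholds == nearest neighbor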
+ + Returns: + output: qdq weight + """ + # assert bits > 1, "symmetric scheme only supports bits > 1" + maxq = torch.tensor(2 ** (bits - 1) - 1).to(weight.device) + minq = torch.tensor(-(2 ** (bits - 1))).to(weight.device) + if bits == 1: # pragma: no cover + maxq = torch.tensor(2 ** (bits - 1)) + minq = torch.tensor(2 ** (bits - 1) - 1) + max_val = torch.max(weight, 1)[0] + min_val = torch.min(weight, 1)[0] + flip_flag = torch.abs(max_val) > torch.abs(min_val) + wmax = torch.max(torch.abs(max_val), torch.abs(min_val)) + wmax = wmax * quantile + tmp = wmax == 0 + wmax[tmp] = +1 + if full_range: + # use -8, 8 to make sure amax is not changed after fake quant + scale = wmax / (-minq) + tmp = scale * flip_flag.int() + scale -= 2 * tmp # set negetive scale with flip_flag + else: + scale = wmax / maxq + scale.unsqueeze_(dim=-1) + weight.div_(scale) + weight.round_() + weight.clamp_(minq, maxq) + double_quant = kwargs.get("double_quant", False) + if return_int or double_quant: + return weight, scale, None + return weight.mul_(scale) + + +def qdq_weight_actor(weight, bits, scheme, quantile=1.0, dtype="int", return_int=False, full_range=False, **kwargs): + """Quant and dequant tensor per channel. It is an in-place op. + + Args: + weight : input weight + bits (int, optional): bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + output: qdq weight + """ + assert bits > 0, "bits should be larger than 0" + + if dtype in FLOAT_MAPPING.keys(): + return quantize_4bit(weight, quantile=quantile, dtype=dtype, return_int=return_int, **kwargs) + if scheme == "sym": + return qdq_weight_sym(weight, bits, quantile, return_int, full_range, **kwargs) + else: + return qdq_weight_asym(weight, bits, quantile, return_int, **kwargs) + + +def quant_tensor( + weight, + dtype="int", + bits=4, + group_size=-1, + scheme="asym", + quantile=1.0, + return_int=False, + full_range=False, + **kwargs, +): + """Quant and dequant tensor with group size. + + Args: + weight: input weight + bits (int, optional): bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to -1. + scheme (str, optional): sym or asym. Defaults to "asym". + quantile (float, optional): percentile of clip. Defaults to 1.0. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + output: qdq weight. 
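The full_range note above is easiest to see with numbers: for 4 bits the symmetric scale is amax/7 by default and amax/8 with full_range, and the sign of the scale is flipped when the positive extreme is the larger one so that amax can still occupy the -8 code and round-trip exactly. A quick arithmetic restatement, independent of the library:

    import torch

    w = torch.tensor([[-1.2, 0.3, 0.9, -0.4]])
    amax = w.abs().max(dim=1, keepdim=True)[0]   # 1.2, coming from the negative extreme here
    scale = amax / 7                             # default: maxq = 2**(4 - 1) - 1 = 7
    scale_full = amax / 8                        # full_range: -minq = 2**(4 - 1) = 8
    # (when the positive extreme dominates instead, the hunk flips the sign of scale_full
    #  so that amax still lands on the -8 code)
    q = torch.clamp(torch.round(w / scale_full), -8, 7)
    qdq = q * scale_full                         # -1.2 -> code -8 -> -1.2 exactly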
+ """ + double_quant = kwargs.get("double_quant", False) + if bits <= 0: # pragma: no cover + return weight + # case 1, group size = -1 + if group_size == -1 or weight.shape[1] < group_size: + group_size = weight.shape[1] + # case 2, reshape based on group size + orig_shape = weight.shape + if weight.shape[1] % group_size == 0: + weight = weight.reshape(-1, group_size) + weight = qdq_weight_actor( + weight, + bits, + scheme=scheme, + quantile=quantile, + return_int=return_int, + full_range=full_range, + dtype=dtype, + **kwargs, + ) + if return_int or double_quant: + weight, scale, zp = weight + weight = weight.reshape(orig_shape) + scale = scale.reshape(orig_shape[0], -1) + if zp is not None: + zp = zp.reshape(orig_shape[0], -1) + q_state = weight, scale, zp + else: + return weight.reshape(orig_shape) + else: + # case 3, process left part split by group size + split_index = weight.shape[1] // group_size * group_size + weight1 = weight[:, :split_index] + weight1 = weight1.reshape(-1, group_size) + weight1 = qdq_weight_actor( + weight1, + bits, + scheme=scheme, + quantile=quantile, + return_int=return_int, + full_range=full_range, + dtype=dtype, + **kwargs, + ) + if return_int or double_quant: + weight1, scale1, zp1 = weight1 + scale1 = scale1.reshape(orig_shape[0], -1) + if zp1 is not None: + zp1 = zp1.reshape(orig_shape[0], -1) + weight1 = weight1.reshape(orig_shape[0], split_index) + weight2 = weight[:, split_index:] + weight2 = qdq_weight_actor( + weight2, + bits, + scheme=scheme, + dtype=dtype, + quantile=quantile, + return_int=return_int, + full_range=full_range, + **kwargs, + ) + if return_int or double_quant: + weight2, scale2, zp2 = weight2 + weight.copy_(torch.cat([weight1, weight2], dim=1)) + scale.copy_(torch.cat([scale1, scale2], dim=1)) + zp = None if zp2 is None else zp.copy_(torch.cat([zp1, zp2], dim=1)) + q_state = (weight, scale, zp) + else: + weight.copy_(torch.cat([weight1, weight2], dim=1)) + return weight + if double_quant: + weight, scale, zp = q_state + double_quant_dtype = kwargs.get("double_quant_dtype", "fp32") + double_quant_bits = kwargs.get("double_quant_bits", 8) + double_quant_scheme = kwargs.get("double_quant_scheme", "sym") + double_quant_group_size = kwargs.get("double_quant_group_size", 256) + double_quant_return_int = kwargs.get("double_quant_return_int", return_int) + # process scale + orig_scale_shape = scale.shape + scale = scale.reshape(1, -1) + scale = quant_tensor( + scale, + double_quant_bits, + double_quant_group_size, + scheme=double_quant_scheme, + quantile=1.0, + dtype=double_quant_dtype, + return_int=double_quant_return_int, + full_range=False, + double_quant=False, + ) + if return_int: + if double_quant_return_int: + scale, hyper_scale, hyper_zp = scale + scale = scale.reshape(orig_scale_shape) + return weight, (scale, hyper_scale, hyper_zp), zp + else: + scale = scale.reshape(orig_scale_shape) + return weight, scale, zp + else: + scale = scale.reshape(orig_scale_shape) + if weight.shape[1] % group_size != 0: + if zp is not None: + weight1 = weight1.reshape(-1, group_size).sub_(zp[:, :-1].reshape(-1, 1)) + weight2 = weight2.sub_(zp[:, -1].reshape(-1, 1)) + else: + weight1 = weight1.reshape(-1, group_size) + weight1 = weight1.mul_(scale[:, :-1].reshape(-1, 1)) + weight1 = weight1.reshape(orig_shape[0], -1) + weight2 = weight2.mul_(scale[:, -1].reshape(-1, 1)) + weight = torch.cat([weight1, weight2], dim=1) + else: + if zp is not None: + weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1) + else: + weight = weight.reshape(-1, 
group_size) + weight = weight.mul_(scale.reshape(-1, 1)) + weight = weight.reshape(orig_shape[0], -1) + return weight + else: + return q_state + + +def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_full_range=False): + """Search best clip range of each linear in current block. + + Args: + m (torch.nn.Module): torch module. + bits (int, optional): num bits. + group_size (int, optional): how many elements share one scale/zp. + scheme (str, optional): sym or asym. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. + enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + best_clip_ratio (float): best percentile of clip + """ + org_weight = m.weight.data.clone() + logger.info("Searching the best clip range with RTN algorithm") + best_error = float("inf") + best_clip_ratio = None + n_grid = 200 + max_shrink = 0.2 + history = [] + for i_s in range(int(max_shrink * n_grid)): + ratio = 1 - i_s / n_grid # 1, 0.805-1.0 + cur_weight = quant_tensor( + m.weight.data, + bits=bits, + group_size=group_size, + scheme=scheme, + dtype=dtype, + full_range=enable_full_range, + quantile=ratio, + ) + loss = (org_weight - cur_weight).float().pow(2).mean().item() + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_clip_ratio = ratio + logger.debug("The loss history of different clip range:{}".format(history)) + logger.debug("The best clip ratio is {}".format(best_clip_ratio)) + return best_clip_ratio + + +def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"): + """Quant and dequant tensor with group size. + + Args: + weight: input weight + scale: scale + zp: zero point + group_size (int, optional): how many elements share one scale/zp. Defaults to -1. + dtype: data type, for NF4 FP4 + + Returns: + output: int weight. + """ + device = weight.device + scale = scale.to(device) + # NF4 FP4 + if dtype in FLOAT_MAPPING.keys(): + int_weight = quantize_4bit( + weight, + quantile=1.0, + dtype=dtype, + return_int=True, + scale=scale, + )[0] + return int_weight + # INT + if zp is not None: + zp = zp.to(device) + if group_size == -1: + return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_() + int_weight = torch.zeros(weight.shape).to(device) + leng = weight.shape[1] // group_size + tail_flag = False if weight.shape[1] % group_size == 0 else True + for i in range(leng): + int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) + if zp is not None: + int_weight_tmp.add_(zp[:, i].unsqueeze(1)) + int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_()) + if tail_flag: + int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1)) + if zp is not None: + int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) + int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_()) + return int_weight diff --git a/neural_compressor/torch/algorithms/weight_only_algos.py b/neural_compressor/torch/algorithms/weight_only_algos.py deleted file mode 100644 index e3cef82d213..00000000000 --- a/neural_compressor/torch/algorithms/weight_only_algos.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
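With the helpers now collected in utility.py, the group-wise path of quant_tensor can be exercised directly. A minimal sketch written against the keyword order shown in this file (note that quant_tensor operates in place on the tensor it is given, so the clones below are deliberate):

    import torch

    from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor

    w = torch.randn(16, 64)
    # fake quantize-dequantize: returns an fp32 tensor of the original shape
    qdq = quant_tensor(w.clone(), dtype="int", bits=4, group_size=32, scheme="sym")
    # return_int=True additionally yields the per-group scale and (for asym) zero point
    q, scale, zp = quant_tensor(w.clone(), dtype="int", bits=4, group_size=32,
                                scheme="asym", return_int=True)
    # scale and zp carry one column per group: here both are (16, 2)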
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Dict, Tuple - -import torch - -from neural_compressor.common.logger import Logger -from neural_compressor.common.utility import GPTQ, RTN_WEIGHT_ONLY_QUANT -from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig -from neural_compressor.torch.utils.utility import fetch_module, register_algo, set_module - -logger = Logger().get_logger() - - -###################### RTN Algo Entry ################################## -@register_algo(name=RTN_WEIGHT_ONLY_QUANT) -def rtn_quantize_entry( - model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNWeightQuantConfig], *args, **kwargs -) -> torch.nn.Module: - """The main entry to apply rtn quantization.""" - from .weight_only.rtn import apply_rtn_on_single_module - - for (op_name, op_type), quant_config in configs_mapping.items(): - original_module = fetch_module(model, op_name) - if original_module is None: - continue - logger.info(f"Apply RTN on module: {op_name}, {original_module}") - rtn_module = apply_rtn_on_single_module(original_module, quant_config) - set_module(model, op_name, rtn_module) - return model - - -###################### GPTQ Algo Entry ################################## -@register_algo(name=GPTQ) -def gptq_quantize_entry( - model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs -) -> torch.nn.Module: - logger.info("Quantize model with the GPTQ algorithm.") - from .weight_only.gptq import apply_gptq_quantize - - model, quantization_perm = apply_gptq_quantize(model=model, configs_mapping=configs_mapping, *args, **kwargs) - # Assign the gptq config as an attribute of model - model._gptq_quantization_perm = quantization_perm - return model diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index d54393a21fc..2b8cbd68ca8 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -12,10 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from neural_compressor.torch.quantization.quantize import quantize, quantize_dynamic -from neural_compressor.torch.quantization.config import ( - RTNWeightQuantConfig, +from .config import ( get_default_rtn_config, - GPTQConfig, get_default_gptq_config, + RTNConfig, + GPTQConfig, ) +from .quantize import quantize, quantize_dynamic + +# TODO(Yi): move config to config.py +from .autotune import autotune, TuningConfig, get_default_tune_config + +### Quantization Function Registration ### +import neural_compressor.torch.quantization.weight_only +from neural_compressor.torch.utils import is_hpex_available + +if is_hpex_available(): + import neural_compressor.torch.quantization.fp8 diff --git a/neural_compressor/torch/tune.py b/neural_compressor/torch/quantization/autotune.py similarity index 97% rename from neural_compressor/torch/tune.py rename to neural_compressor/torch/quantization/autotune.py index 656d6c5b1be..3c734309a3c 100644 --- a/neural_compressor/torch/tune.py +++ b/neural_compressor/torch/quantization/autotune.py @@ -18,7 +18,7 @@ from neural_compressor.common.base_tune import BaseTuningConfig, FrameworkWrapper, Tuner, tuning_objectives from neural_compressor.common.logger import Logger -from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig +from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig logger = Logger().get_logger() @@ -68,4 +68,4 @@ def autotune( def get_default_tune_config(): # TODO use the registered default tuning config in the next PR - return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNWeightQuantConfig(weight_bits=[4, 8])]) + return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])]) diff --git a/neural_compressor/torch/quantization/layers.py b/neural_compressor/torch/quantization/layers.py index 9d279ee11b7..b50383d6319 100644 --- a/neural_compressor/torch/quantization/layers.py +++ b/neural_compressor/torch/quantization/layers.py @@ -25,9 +25,8 @@ from torch.autograd import Function from torch.nn import functional as F -from neural_compressor.common import logger -from neural_compressor.common.logger import DEBUG, level -from neural_compressor.torch.algorithms.weight_only.rtn import quant_weight +from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor +from neural_compressor.torch.utils import logger def get_torch_version(): @@ -165,7 +164,7 @@ def __init__( self.use_optimum_format = use_optimum_format self.dtype = dtype if "int" not in self.dtype: # for nf4, fp4 - from neural_compressor.torch.algorithms.weight_only.rtn import FLOAT_MAPPING, INT_MAPPING + from neural_compressor.torch.algorithms.weight_only import FLOAT_MAPPING, INT_MAPPING float_list = FLOAT_MAPPING[self.dtype] int_list = INT_MAPPING[self.dtype] @@ -201,7 +200,6 @@ def __init__( dtype=self.float_type, ).to(device), ) - self.scales = self.scales.T self.register_buffer( "qweight", torch.zeros( @@ -209,7 +207,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qweight = self.qweight.T self.register_buffer( "qzeros", torch.zeros( @@ -217,7 +214,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qzeros = self.qzeros.T self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: self.compression_dtype = compression_dtype @@ -271,6 +267,10 @@ def __init__( self.g_idx = None def pack(self, int_weight, scale, zp, bias, g_idx=None): + if self.use_optimum_format: + self.scales = 
self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow @@ -291,8 +291,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.T - self.qweight = self.qweight.T + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() origin_shape = int_weight.shape target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." @@ -308,15 +308,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.T + self.qweight = self.qweight.t_().contiguous() if zp is not None: zp = zp.to(self.device) if self.use_optimum_format: zp -= 1 if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - self.qzeros = self.qzeros.T + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() assert hasattr(self, "qzeros"), "zp is not set when initializing." target_shape = self.qzeros.shape for j in range(target_shape[1]): @@ -328,16 +328,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.T + self.qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() def recover(self): logger.debug(f"Recovering {self} weight") - scales = self.scales.T if self.use_optimum_format else self.scales - qweight = self.qweight.T if self.use_optimum_format else self.qweight + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) @@ -352,8 +352,8 @@ def recover(self): # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T - qweight = qweight.T + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() origin_shape = weight.shape target_shape = qweight.shape for j in range(target_shape[1]): @@ -368,7 +368,7 @@ def recover(self): tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T + weight = weight.t_().contiguous() if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) for k, v in self.int2float_mapping.items(): @@ -378,10 +378,10 @@ def recover(self): if hasattr(self, "qzeros"): zp_dtype = self.compression_dtype # to avoid overflow when weight-zp zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.T if self.use_optimum_format else 
self.qzeros + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - qzeros = qzeros.T + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() origin_shape = zp.shape target_shape = qzeros.shape for j in range(target_shape[1]): @@ -395,7 +395,7 @@ def recover(self): tmp &= mask zp[:, index] = tmp.type(zp_dtype) if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T + zp = zp.t_().contiguous() if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 @@ -410,12 +410,13 @@ def recover(self): return fp32_weight def forward(self, input): - weight = self.recover() - device = self.scales.device - if weight.dtype == torch.float16 and device.type == "cpu": - weight = weight.float() - self.bias = self.bias.float() if self.bias is not None else None - if level == DEBUG: + if not hasattr(self, "weight"): + weight = self.recover() + device = self.scales.device + if weight.dtype == torch.float16 and device.type == "cpu": + weight = weight.float() + self.bias = self.bias.float() if self.bias is not None else None + if True: # keep reusing self.weight due to recover is too slow. if not hasattr(self, "weight"): self.weight = weight input = input.type(self.weight.dtype) @@ -455,7 +456,7 @@ def forward(ctx, inputs, num_bits=4, group_size=1024, scheme="asym"): Returns: outputs: A Tensor of type output_dtype """ - return quant_weight(inputs, num_bits, group_size, scheme) + return quant_tensor(inputs, num_bits, group_size, scheme) @staticmethod def backward(ctx, grad_outputs): diff --git a/neural_compressor/torch/quantization/weight_only/__init__.py b/neural_compressor/torch/quantization/weight_only/__init__.py new file mode 100644 index 00000000000..5be35a73e0d --- /dev/null +++ b/neural_compressor/torch/quantization/weight_only/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantization_impl import ( + rtn_entry, + gptq_entry, +) diff --git a/neural_compressor/torch/quantization/weight_only/quantization_impl.py b/neural_compressor/torch/quantization/weight_only/quantization_impl.py new file mode 100644 index 00000000000..f6797e39252 --- /dev/null +++ b/neural_compressor/torch/quantization/weight_only/quantization_impl.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Tuple + +import torch + +from neural_compressor.common.utility import GPTQ, RTN # unified namespace +from neural_compressor.torch.algorithms import gptq_quantize, rtn_quantize +from neural_compressor.torch.quantization import GPTQConfig, RTNConfig +from neural_compressor.torch.utils import fetch_module, logger, register_algo + + +###################### RTN Algo Entry ################################## +@register_algo(name=RTN) +def rtn_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNConfig], *args, **kwargs +) -> torch.nn.Module: + """The main entry to apply rtn quantization.""" + # rebuild weight_config for rtn_quantize function + weight_config = {} + for (op_name, op_type), quant_config in configs_mapping.items(): + weight_config[op_name] = { + "dtype": quant_config.dtype, + "bits": quant_config.bits, + "scheme": "sym" if quant_config.use_sym else "asym", + "group_size": quant_config.group_size, + "group_dim": quant_config.group_dim, + "use_full_range": quant_config.use_full_range, + "use_mse_search": quant_config.use_mse_search, + "use_layer_wise": quant_config.use_layer_wise, + "export_compressed_model": quant_config.export_compressed_model, + "double_quant_dtype": quant_config.double_quant_dtype, + "double_quant_bits": quant_config.double_quant_bits, + "double_quant_scheme": "sym" if quant_config.double_quant_sym else "asym", + "double_quant_group_size": quant_config.double_quant_group_size, + } + + model = rtn_quantize(model, weight_config=weight_config) + return model + + +###################### GPTQ Algo Entry ################################## +@register_algo(name=GPTQ) +def gptq_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs +) -> torch.nn.Module: + logger.info("Quantize model with the GPTQ algorithm.") + from .weight_only.gptq import apply_gptq_quantize + + model, quantization_perm = gptq_quantize(model=model, configs_mapping=configs_mapping, *args, **kwargs) + # Assign the gptq config as an attribute of model + model._gptq_quantization_perm = quantization_perm + return model diff --git a/neural_compressor/torch/utils/__init__.py b/neural_compressor/torch/utils/__init__.py index 8989ae9d722..2c8a6c4704d 100644 --- a/neural_compressor/torch/utils/__init__.py +++ b/neural_compressor/torch/utils/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .environ import * +from .constants import * +from .utility import * diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index 74990a5bf42..a3d6109179c 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -26,7 +26,7 @@ "double_quant_sym": True, "double_quant_group_size": 8, }, - "BNB": { + "BNB_NF4": { "weight_dtype": "nf4", "weight_bits": 4, "weight_group_size": 32, diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py new file mode 100644 index 00000000000..3a5b1c1373c --- /dev/null +++ b/neural_compressor/torch/utils/environ.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
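Spelled out for one op, the weight_config that rtn_entry assembles is an ordinary nested dict keyed by op name, which rtn_quantize then consumes layer by layer. A hand-written illustration of its shape (the op name "model.fc1" and the concrete values are examples chosen to mirror the keys above, not defaults verified against RTNConfig):

    weight_config = {
        "model.fc1": {
            "dtype": "int",
            "bits": 4,
            "scheme": "sym",
            "group_size": 32,
            "group_dim": 1,
            "use_full_range": False,
            "use_mse_search": False,
            "use_layer_wise": False,
            "export_compressed_model": False,
            "double_quant_dtype": "fp32",   # "fp32" disables double quantization
            "double_quant_bits": 8,
            "double_quant_scheme": "sym",
            "double_quant_group_size": 256,
        }
    }
    # model = rtn_quantize(model, weight_config=weight_config)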
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# pylint:disable=import-error +try: + import deepspeed + import habana_frameworks.torch.hpex + + _hpex_available = True +except: + _hpex_available = False + + +def is_hpex_available(): + return _hpex_available + + +try: + import intel_extension_for_pytorch as ipex + + _ipex_available = True +except: + _ipex_available = False + + +def is_ipex_available(): + return _ipex_available diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index b1748d059eb..45e36d83911 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -15,14 +15,15 @@ from typing import Callable, Dict, List, Tuple -from neural_compressor.common.logger import Logger +import torch -logger = Logger().get_logger() +from neural_compressor.common.logger import Logger # Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) algos_mapping: Dict[str, Callable] = {} -import torch + +logger = Logger().get_logger() # All constants for torch WHITE_MODULE_LIST = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] @@ -33,7 +34,7 @@ def register_algo(name): Usage example: @register_algo(name=example_algo) - def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: + def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module: ... 
Args: @@ -119,17 +120,3 @@ def get_double_quant_config(double_quant_type, weight_sym=True): ) DOUBLE_QUANT_CONFIGS[double_quant_type]["weight_sym"] = weight_sym return DOUBLE_QUANT_CONFIGS[double_quant_type] - - -# pylint:disable=import-error -try: - import deepspeed - import habana_frameworks.torch.hpex - - _hpex_avalible = True -except: - _hpex_avalible = False - - -def is_hpex_avaliable(): - return _hpex_avalible diff --git a/test/3x/torch/algorithms/weight_only/test_rtn.py b/test/3x/torch/algorithms/weight_only/test_rtn.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/3x/torch/quantization/weight_only/test_gptq_algo.py b/test/3x/torch/quantization/weight_only/test_gptq.py similarity index 100% rename from test/3x/torch/quantization/weight_only/test_gptq_algo.py rename to test/3x/torch/quantization/weight_only/test_gptq.py diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index 00159b55828..c2671d6d95a 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -1,135 +1,27 @@ +import copy import unittest -import torch +import transformers -from neural_compressor.common.logger import Logger - -logger = Logger().get_logger() - - -def get_gpt_j(): - import transformers - - tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-GPTJForCausalLM", - torchscript=True, - ) - return tiny_gptj - - -def build_simple_torch_model(): - class Model(torch.nn.Module): - def __init__(self): - super(Model, self).__init__() - self.fc1 = torch.nn.Linear(8, 30) - self.fc2 = torch.nn.Linear(30, 60) - self.fc3 = torch.nn.Linear(60, 30) - self.fc4 = torch.nn.Linear(30, 50) - - def forward(self, x): - out = self.fc1(x) - out = self.fc2(out) - out = self.fc3(out) - out = self.fc4(out) - return out - - model = Model() - return model +from neural_compressor.torch.quantization import RTNConfig, get_default_rtn_config, quantize class TestRTNQuant(unittest.TestCase): @classmethod def setUpClass(self): - pass + self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + ) @classmethod def tearDownClass(self): pass - def setUp(self): - # print the test name - logger.info(f"Running TestRTNQuant test: {self.id()}") - - def _apply_rtn(self, quant_config): - logger.info(f"Test RTN with config {quant_config}") - from neural_compressor.torch import quantize - - fp32_model = build_simple_torch_model() - qmodel = quantize(fp32_model, quant_config) - self.assertIsNotNone(qmodel) - return qmodel - - def test_rtn(self): - from neural_compressor.torch import RTNWeightQuantConfig - - # some tests were skipped to accelerate the CI - rnt_options = { - "weight_dtype": ["int", "int8", "nf4", "fp4_e2m1_bnb"], - "weight_bits": [4, 1, 8], - "weight_group_size": [32, -1, 1, 512], - "weight_sym": [True, False], - "act_dtype": ["fp32"], - "enable_full_range": [False, True], - "enable_mse_search": [False], - "group_dim": [1, 0], - "return_int": [False, True], - } - from itertools import product - - keys = RTNWeightQuantConfig.params_list - for value in product(*rnt_options.values()): - d = dict(zip(keys, value)) - if (d["weight_dtype"] == "int" and d["weight_bits"] != 8) or ( - d["enable_full_range"] - and d["enable_mse_search"] - or (d["return_int"] and (d["group_dim"] != 1 or d["weight_bits"] != 8)) - ): - continue - quant_config = RTNWeightQuantConfig(**d) - 
self._apply_rtn(quant_config) - - def test_rtn_return_type(self): - from neural_compressor.torch import RTNWeightQuantConfig - - for return_int in [True, False]: - quant_config = RTNWeightQuantConfig(return_int=return_int) - qmodel = self._apply_rtn(quant_config) - - def test_rtn_mse_search(self): - from neural_compressor.torch import RTNWeightQuantConfig - - quant_config = RTNWeightQuantConfig(enable_mse_search=True) - qmodel = self._apply_rtn(quant_config) - - def test_rtn_recover(self): - from neural_compressor.torch import RTNWeightQuantConfig - - quant_config = RTNWeightQuantConfig(return_int=True) - qmodel = self._apply_rtn(quant_config) - input = torch.randn(4, 8) - # test forward - out = qmodel(input) - recovered_fc1 = qmodel.fc1.recover() - self.assertIsNotNone(recovered_fc1) - - def test_weight_only_linear(self): - from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize - - model = build_simple_torch_model() - options = { - "compression_dtype": [torch.int8, torch.int16, torch.int32, torch.int64], - "compression_dim": [0, 1], - "module": [model.fc1, model.fc2, model.fc3, model.fc4], - } - from itertools import product - - for compression_dtype, compression_dim, module in product(*options.values()): - q_model = rtn_quantize( - model=module, - return_int=True, - compression_dtype=compression_dtype, - compression_dim=compression_dim, - ) + def test_export_compressed_model(self): + model = copy.deepcopy(self.tiny_gptj) + quant_config = RTNConfig(export_compressed_model=True) + model = quantize(model, quant_config) + print(model) if __name__ == "__main__": diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py index 876311355f1..08c6cfb3171 100644 --- a/test/3x/torch/test_autotune.py +++ b/test/3x/torch/test_autotune.py @@ -62,12 +62,12 @@ def setUp(self): def test_autotune_api(self): logger.info("test_autotune_api") from neural_compressor.common.base_tune import tuning_objectives - from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune + from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune def eval_acc_fn(model) -> float: return 1.0 - custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2) + custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2) best_model = autotune( model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}] ) @@ -78,7 +78,7 @@ def eval_acc_fn(model) -> float: def test_autotune_api_2(self): logger.info("test_autotune_api") from neural_compressor.common.base_tune import tuning_objectives - from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune + from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune def eval_acc_fn(model) -> float: return 1.0 @@ -94,7 +94,7 @@ def eval_perf_fn(model) -> float: }, ] - custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2) + custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2) best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_fns) self.assertIsNone(best_model) self.assertEqual(len(tuning_objectives.eval_fn_registry), 2) @@ -102,9 +102,9 @@ def eval_perf_fn(model) -> float: @reset_tuning_target def test_autotune_not_eval_func(self): logger.info("test_autotune_api") - from neural_compressor.torch import 
RTNWeightQuantConfig, TuningConfig, autotune + from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune - custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2) + custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2) # Use assertRaises to check that an AssertionError is raised with self.assertRaises(AssertionError) as context: diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index 0e0925685c0..ba2cdd95b52 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -1,12 +1,11 @@ import copy import unittest +import torch import transformers -from neural_compressor.common.logger import Logger - -logger = Logger().get_logger() -import torch +from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize +from neural_compressor.torch.utils import logger def build_simple_torch_model(): @@ -68,20 +67,16 @@ def test_quantize_rtn_from_dict_beginner(self): self.assertIsNotNone(qmodel) def test_quantize_rtn_from_class_beginner(self): - from neural_compressor.torch import RTNWeightQuantConfig, quantize - - quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4", weight_group_size=32) + quant_config = RTNConfig(weight_bits=4, weight_dtype="nf4", weight_group_size=32) fp32_model = build_simple_torch_model() qmodel = quantize(fp32_model, quant_config) self.assertIsNotNone(qmodel) def test_quantize_rtndq_from_class_beginner(self): - from neural_compressor.torch import RTNWeightQuantConfig, quantize - - fp32_config = RTNWeightQuantConfig(weight_dtype="fp32") + fp32_config = RTNConfig(weight_dtype="fp32") fp32_model = copy.deepcopy(self.gptj) - quant_config = RTNWeightQuantConfig( + quant_config = RTNConfig( weight_bits=4, weight_dtype="int", weight_sym=False, @@ -96,7 +91,7 @@ def test_quantize_rtndq_from_class_beginner(self): from neural_compressor.torch.utils.utility import get_double_quant_config double_quant_config_dict = get_double_quant_config("GGML_TYPE_Q4_K", weight_sym=False) - quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict) + quant_config = RTNConfig.from_dict(double_quant_config_dict) quant_config.set_local("lm_head", fp32_config) qmodel = quantize(fp32_model, quant_config) out3 = qmodel(self.lm_input) @@ -104,7 +99,7 @@ def test_quantize_rtndq_from_class_beginner(self): fp32_model = copy.deepcopy(self.gptj) - quant_config = RTNWeightQuantConfig( + quant_config = RTNConfig( weight_bits=4, weight_dtype="nf4", weight_group_size=32, @@ -116,7 +111,7 @@ def test_quantize_rtndq_from_class_beginner(self): fp32_model = copy.deepcopy(self.gptj) # bitsandbytes double quant setting double_quant_config_dict = get_double_quant_config("BNB") - quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict) + quant_config = RTNConfig.from_dict(double_quant_config_dict) quant_config.set_local("lm_head", fp32_config) qmodel = quantize(fp32_model, quant_config) out5 = qmodel(self.lm_input) @@ -145,11 +140,9 @@ def test_quantize_rtn_from_dict_advance(self): self.assertIsNotNone(qmodel) def test_quantize_rtn_from_class_advance(self): - from neural_compressor.torch import RTNWeightQuantConfig, quantize - - quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") + quant_config = RTNConfig(weight_bits=4, weight_dtype="nf4") # set operator instance - fc1_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="int8") + fc1_config = RTNConfig(weight_bits=4, weight_dtype="int8") 
quant_config.set_local("model.fc1", fc1_config) # get model and quantize fp32_model = build_simple_torch_model() @@ -157,23 +150,20 @@ def test_quantize_rtn_from_class_advance(self): self.assertIsNotNone(qmodel) def test_config_white_lst(self): - from neural_compressor.torch import RTNWeightQuantConfig, quantize - - global_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") + global_config = RTNConfig(weight_bits=4, weight_dtype="nf4") # set operator instance - fc1_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="int8", white_list=["model.fc1"]) + fc1_config = RTNConfig(weight_bits=4, weight_dtype="int8", white_list=["model.fc1"]) # get model and quantize fp32_model = build_simple_torch_model() qmodel = quantize(fp32_model, quant_config=global_config + fc1_config) self.assertIsNotNone(qmodel) def test_config_white_lst2(self): - from neural_compressor.torch import RTNWeightQuantConfig from neural_compressor.torch.utils.utility import get_model_info - global_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") + global_config = RTNConfig(weight_bits=4, weight_dtype="nf4") # set operator instance - fc1_config = RTNWeightQuantConfig(weight_bits=6, weight_dtype="int8", white_list=["fc1"]) + fc1_config = RTNConfig(weight_bits=6, weight_dtype="int8", white_list=["fc1"]) quant_config = global_config + fc1_config # get model and quantize fp32_model = build_simple_torch_model() @@ -185,8 +175,6 @@ def test_config_white_lst2(self): self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].weight_bits == 4) def test_config_from_dict(self): - from neural_compressor.torch import RTNWeightQuantConfig - quant_config = { "rtn_weight_only_quant": { "global": { @@ -202,22 +190,18 @@ def test_config_from_dict(self): }, } } - config = RTNWeightQuantConfig.from_dict(quant_config["rtn_weight_only_quant"]) + config = RTNConfig.from_dict(quant_config["rtn_weight_only_quant"]) self.assertIsNotNone(config.local_config) def test_config_to_dict(self): - from neural_compressor.torch import RTNWeightQuantConfig - - quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") - fc1_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="int8") + quant_config = RTNConfig(weight_bits=4, weight_dtype="nf4") + fc1_config = RTNConfig(weight_bits=4, weight_dtype="int8") quant_config.set_local("model.fc1", fc1_config) config_dict = quant_config.to_dict() self.assertIn("global", config_dict) self.assertIn("local", config_dict) def test_same_type_configs_addition(self): - from neural_compressor.torch import RTNWeightQuantConfig - quant_config1 = { "rtn_weight_only_quant": { "weight_dtype": "nf4", @@ -225,7 +209,7 @@ def test_same_type_configs_addition(self): "weight_group_size": 32, }, } - q_config = RTNWeightQuantConfig.from_dict(quant_config1["rtn_weight_only_quant"]) + q_config = RTNConfig.from_dict(quant_config1["rtn_weight_only_quant"]) quant_config2 = { "rtn_weight_only_quant": { "global": { @@ -240,7 +224,7 @@ def test_same_type_configs_addition(self): }, } } - q_config2 = RTNWeightQuantConfig.from_dict(quant_config2["rtn_weight_only_quant"]) + q_config2 = RTNConfig.from_dict(quant_config2["rtn_weight_only_quant"]) q_config3 = q_config + q_config2 q3_dict = q_config3.to_dict() for op_name, op_config in quant_config2["rtn_weight_only_quant"]["local"].items(): @@ -251,8 +235,6 @@ def test_same_type_configs_addition(self): ) def test_diff_types_configs_addition(self): - from neural_compressor.torch import GPTQConfig, RTNWeightQuantConfig - quant_config1 = { 
"rtn_weight_only_quant": { "weight_dtype": "nf4", @@ -260,7 +242,7 @@ def test_diff_types_configs_addition(self): "weight_group_size": 32, }, } - q_config = RTNWeightQuantConfig.from_dict(quant_config1["rtn_weight_only_quant"]) + q_config = RTNConfig.from_dict(quant_config1["rtn_weight_only_quant"]) d_config = GPTQConfig(double_quant_bits=4) combined_config = q_config + d_config combined_config_d = combined_config.to_dict() @@ -269,8 +251,6 @@ def test_diff_types_configs_addition(self): self.assertIn("gptq", combined_config_d) def test_composable_config_addition(self): - from neural_compressor.torch import GPTQConfig, RTNWeightQuantConfig - quant_config1 = { "rtn_weight_only_quant": { "weight_dtype": "nf4", @@ -278,7 +258,7 @@ def test_composable_config_addition(self): "weight_group_size": 32, }, } - q_config = RTNWeightQuantConfig.from_dict(quant_config1["rtn_weight_only_quant"]) + q_config = RTNConfig.from_dict(quant_config1["rtn_weight_only_quant"]) d_config = GPTQConfig(double_quant_bits=4) combined_config = q_config + d_config combined_config_d = combined_config.to_dict() @@ -289,12 +269,11 @@ def test_composable_config_addition(self): combined_config3 = combined_config + combined_config2 def test_config_mapping(self): - from neural_compressor.torch import RTNWeightQuantConfig from neural_compressor.torch.utils.utility import get_model_info - quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") + quant_config = RTNConfig(weight_bits=4, weight_dtype="nf4") # set operator instance - fc1_config = RTNWeightQuantConfig(weight_bits=6, weight_dtype="int8") + fc1_config = RTNConfig(weight_bits=6, weight_dtype="int8") quant_config.set_local("fc1", fc1_config) # get model and quantize fp32_model = build_simple_torch_model() @@ -305,7 +284,7 @@ def test_config_mapping(self): self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].weight_bits == 6) self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].weight_bits == 4) # test regular matching - fc_config = RTNWeightQuantConfig(weight_bits=5, weight_dtype="int8") + fc_config = RTNConfig(weight_bits=5, weight_dtype="int8") quant_config.set_local("fc", fc_config) configs_mapping = quant_config.to_config_mapping(model_info=model_info) logger.info(configs_mapping) @@ -327,10 +306,9 @@ def test_gptq_config(self): class TestQuantConfigForAutotune(unittest.TestCase): def test_expand_config(self): # test the expand functionalities, the user is not aware it - from neural_compressor.torch import RTNWeightQuantConfig - tune_config = RTNWeightQuantConfig(weight_bits=[4, 6]) - expand_config_list = RTNWeightQuantConfig.expand(tune_config) + tune_config = RTNConfig(weight_bits=[4, 6]) + expand_config_list = RTNConfig.expand(tune_config) self.assertEqual(expand_config_list[0].weight_bits, 4) self.assertEqual(expand_config_list[1].weight_bits, 6)