diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md
index de5d70c6d09..addc6490ed5 100644
--- a/docs/source/quantization_weight_only.md
+++ b/docs/source/quantization_weight_only.md
@@ -93,18 +93,19 @@ To support low memory inference, Neural Compressor implemented WeightOnlyLinear,
 **Export arguments**
 | export args | default value | comments |
 |:----------:|:-------------:|:-------------------------------------------------------------------:|
-| qweight_config_path | None | If need to export model with fp32_model and json file, set the path of qconfig.json |
+| use_optimum_format | True | Whether to use the popular format used by [Optimum](https://github.com/huggingface/optimum/blob/e0927976d06d163ed09fe5bd80d013e1cfa0c463/docs/source/llm_quantization/usage_guides/quantization.mdx#L5) |
 | sym_full_range | False | Whether to leverage the full compression range under symmetric quantization |
-| compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64] |
-| compression_dim | 1 | 0 means output channel while 1 means input channel |
-| scale_dtype | torch.float32 | Data type for scale and bias |
-| use_hf_format | False | Whether to use the popular format present on HuggingFace hub |
+| compression_dtype | torch.int32 | Data type of the compressed weight, selected from [torch.int8\|16\|32\|64]. Fixed to torch.int32 when use_optimum_format=True |
+| compression_dim | 1 | 0 means output channel while 1 means input channel. It is 1 for weight and 0 for zero-point when use_optimum_format=True |
+| scale_dtype | torch.float32 | Data type for scale and bias. Fixed to torch.float16 when use_optimum_format=True |
+| qweight_config_path | None | Set the path of qconfig.json when exporting the model from an fp32 model plus a JSON configuration file |
+| gptq_config_path | None | Set the path of gptq_config.json when exporting a GPTQ-quantized model from an fp32 model plus a JSON configuration file |
-**Note:** HuggingFace format is quite special, the main differences are as follows:
+**Note:** The Optimum format is accepted by transformers out of the box, which makes it easy to use. However, this format is rather special; the main differences are as follows:
 > 1: Compression Dimension: weight = 1, zero = 0 and both are transposed.
 > 2: Zero Point: zero_point-= 1 before compression. zero_point is always required even for sym.
-> 3: Group Index: Use the same number for a group instead of recording channel order.
+> 3: Group Index: Use the same number for a group instead of recording the channel order.
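For quick orientation, here is a minimal sketch of how the export arguments above are typically passed. Assumptions: `q_model` is the weight-only quantized model returned by `neural_compressor.quantization.fit`, as in the User Code Example that follows; the surrounding setup is illustrative only.

```python
import torch

# Optimum-compatible layout (default): int32 qweight, transposed tensors, fp16 scales.
compressed_model = q_model.export_compressed_model(use_optimum_format=True)

# Legacy layout with the packing knobs spelled out explicitly.
compressed_model = q_model.export_compressed_model(
    use_optimum_format=False,
    compression_dtype=torch.int32,  # container dtype the low-bit values are packed into
    compression_dim=1,              # pack along the input channel
    scale_dtype=torch.float32,      # keep fp32 scales and bias
)
```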
### **User Code Example** diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 1710cd8cae5..a3cb3f1ea09 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -4582,10 +4582,12 @@ def rtn_quantize(self, model, tune_cfg): enable_full_range = self.recipes["rtn_args"].get("enable_full_range", False) enable_mse_search = self.recipes["rtn_args"].get("enable_mse_search", False) group_dim = self.recipes["rtn_args"].get("group_dim", 1) + return_int = self.recipes["rtn_args"].get("return_int", False) else: # pragma: no cover enable_full_range = False enable_mse_search = False group_dim = 1 + return_int = False from .torch_utils.util import fetch_module, set_module from .torch_utils.weight_only import rtn_quantize @@ -4623,7 +4625,7 @@ def rtn_quantize(self, model, tune_cfg): num_bits, group_size, scheme, - return_int=False, + return_int=return_int, data_type=dtype, enable_full_range=enable_full_range, enable_mse_search=enable_mse_search, diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py index 57103566d9d..6e9df2d5392 100644 --- a/neural_compressor/adaptor/torch_utils/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py @@ -217,10 +217,10 @@ def __init__( compression_dim=1, g_idx=False, device="cpu", - use_hf_format=False, + use_optimum_format=True, ): super().__init__() - self.use_hf_format = use_hf_format + self.use_optimum_format = use_optimum_format self.dtype = dtype if "int" not in self.dtype: # for nf4, fp4 from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING @@ -245,13 +245,13 @@ def __init__( dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64} self.compress_bits = dtype_bits_mapping[compression_dtype] self.n_pack = self.compress_bits // self.bits - self.compressed_dtype = compression_dtype - self.float_type = scale_dtype # K is input channel, N is output channel assert compression_dim in [0, 1], ( "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel." 
) - if self.use_hf_format: + if self.use_optimum_format: + self.float_type = torch.float16 + self.compressed_dtype = torch.int32 self.register_buffer( "scales", torch.zeros( @@ -276,7 +276,10 @@ def __init__( ).to(device), ) self.qzeros = self.qzeros.T + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: + self.compressed_dtype = compression_dtype + self.float_type = scale_dtype self.register_buffer( "scales", torch.zeros( @@ -316,18 +319,18 @@ def __init__( dtype=self.compressed_dtype, ).to(device), ) + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.bias = None if g_idx: self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) else: self.g_idx = None - if bias: - self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) - else: - self.bias = None def pack(self, int_weight, scale, zp, bias, g_idx=None): int_weight = int_weight.to(self.device) - if self.use_hf_format and zp is None: + if self.use_optimum_format and zp is None: # to avoid overflow int_weight = int_weight.type(torch.int32) shift_bias = 2 ** (self.bits - 1) @@ -339,13 +342,13 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): if g_idx is not None: assert hasattr(self, "g_idx"), "g_idx is not set when initializing." self.g_idx = g_idx.type(torch.int32).to(self.device) - if self.use_hf_format: + if self.use_optimum_format: invperm = torch.argsort(self.g_idx) self.g_idx = invperm // self.groupsize self.g_idx = self.g_idx.type(torch.int32).to(self.device) assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: int_weight = int_weight.T self.qweight = self.qweight.T origin_shape = int_weight.shape @@ -362,14 +365,14 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] &= mask tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: self.qweight = self.qweight.T if zp is not None: zp = zp.to(self.device) - if self.use_hf_format: + if self.use_optimum_format: zp -= 1 - if self.use_hf_format or self.compression_dim == 0: + if self.use_optimum_format or self.compression_dim == 0: zp = zp.T self.qzeros = self.qzeros.T assert hasattr(self, "qzeros"), "zp is not set when initializing." 
@@ -382,23 +385,19 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] &= mask tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] - if self.use_hf_format or self.compression_dim == 0: + if self.use_optimum_format or self.compression_dim == 0: self.qzeros = self.qzeros.T - if self.use_hf_format: + if self.use_optimum_format: self.scales = self.scales.T self.qweight = self.qweight.T - self.g_idx = self.g_idx self.qzeros = self.qzeros.T def recover(self): logger.debug(f"Recovering {self} weight") - if self.use_hf_format: - # Prevent broken id links of self.scales and self.scales - self.scales = self.scales.T - self.qweight = self.qweight.T - self.g_idx = self.g_idx - self.qzeros = self.qzeros.T - device = self.scales.device + scales = self.scales.T if self.use_optimum_format else self.scales + qweight = self.qweight.T if self.use_optimum_format else self.qweight + + device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) if self.g_idx is None: # used for recovering fp32_weight @@ -410,8 +409,7 @@ def recover(self): weight_dtype = torch.int8 # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) - qweight = self.qweight - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: weight = weight.T qweight = qweight.T origin_shape = weight.shape @@ -427,7 +425,7 @@ def recover(self): if weight_dtype == torch.uint8: tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: weight = weight.T if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) @@ -437,9 +435,9 @@ def recover(self): # unpack zero_point if hasattr(self, "qzeros"): zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp - zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros - if self.use_hf_format or self.compression_dim == 0: + zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) + qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + if self.use_optimum_format or self.compression_dim == 0: zp = zp.T qzeros = qzeros.T origin_shape = zp.shape @@ -454,30 +452,34 @@ def recover(self): tmp = tmp >> self.compress_bits - self.bits tmp &= mask zp[:, index] = tmp.type(zp_dtype) - if self.use_hf_format or self.compression_dim == 0: + if self.use_optimum_format or self.compression_dim == 0: zp = zp.T - if self.use_hf_format: + if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 zp = torch.where(zp > (2**self.bits - 1), 0, zp) # recover fp32 weight with int_weight, scale, and zero_point for idx in range(self.in_features): - fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]] + fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]] else: # recover fp32 weight with int_weight, scale for idx in range(self.in_features): - fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]] + fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]] return fp32_weight def forward(self, input): + weight = self.recover() + device = self.scales.device + if weight.dtype == torch.float16 and device.type == "cpu": + weight = weight.float() + 
self.bias = self.bias.float() if self.bias is not None else None if level == DEBUG: if not hasattr(self, "weight"): - self.weight = self.recover() + self.weight = weight input = input.type(self.weight.dtype) logger.debug(f"Calculating {self}") return F.linear(input, self.weight, self.bias) else: - weight = self.recover() input = input.type(weight.dtype) return F.linear(input, weight, self.bias) @@ -489,8 +491,8 @@ def extra_repr(self) -> str: self.groupsize, self.bias is not None, ) - if self.use_hf_format: - tmp_str += ", use_hf_format=True" + if self.use_optimum_format: + tmp_str += ", use_optimum_format=True" return tmp_str diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index eb404d139f8..c29994f7755 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -396,7 +396,7 @@ def rtn_quantize( compression_dim = kwargs.get("compression_dim", 1) scale_dtype = kwargs.get("scale_dtype", torch.float32) device = kwargs.get("device", "cpu") - use_hf_format = kwargs.get("use_hf_format", False) + use_optimum_format = kwargs.get("use_optimum_format", True) for name, m in model.named_modules(): if m.__class__.__name__ not in supported_layers: continue @@ -452,7 +452,7 @@ def rtn_quantize( compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, - use_hf_format=use_hf_format, + use_optimum_format=use_optimum_format, ) new_module.pack(int_weight, scale, zp, m.bias) if name == "": diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index fb7046a1607..395b9c007fe 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -459,7 +459,7 @@ def export_compressed_model( scale_dtype=torch.float32, gptq_config_path=None, device="cpu", - use_hf_format=False, + use_optimum_format=True, ): """Convert Linear to WeightOnlyLinear for low memory inference. @@ -475,7 +475,7 @@ def export_compressed_model( Defaults to torch.float32. gptq_config_path (str, optional): Path of gptq_config.json. Defaults to None. device (str, optional): choose device for compression. Defaults to cpu. - use_hf_format (bool, optional): use the popular huggingface compression format. + use_optimum_format (bool, optional): use the popular huggingface compression format. 1: compression_dim: weight = 1, zeros = 0 and both are transposed. 2: zeros -= 1 before compression. Why we need it? 3: g_idx: use same number for one group instead of recording the channel order. 
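The `WeightOnlyLinear.pack()`/`recover()` changes above all revolve around the same bit-packing scheme. For reference, here is a small standalone sketch of that scheme (not the library code; the 4-bit/int32 choice and tensor shapes are arbitrary):

```python
import torch

bits, container_bits = 4, 32
n_pack = container_bits // bits      # 8 four-bit values fit into one int32 word
mask = (1 << bits) - 1

# One output channel, eight input channels of unsigned 4-bit weights.
int_weight = torch.tensor([[3, 15, 0, 7, 9, 1, 12, 5]], dtype=torch.int32)

# Pack: each group of n_pack values shares one int32, shifted into its own 4-bit slot.
qweight = torch.zeros(1, int_weight.shape[1] // n_pack, dtype=torch.int32)
for j in range(qweight.shape[1]):
    for e in range(n_pack):
        qweight[:, j] |= (int_weight[:, j * n_pack + e] & mask) << (bits * e)

# Unpack: shift each slot back down and mask away its neighbours.
recovered = torch.zeros_like(int_weight)
for j in range(qweight.shape[1]):
    for e in range(n_pack):
        recovered[:, j * n_pack + e] = (qweight[:, j] >> (bits * e)) & mask

assert torch.equal(recovered, int_weight)

# In the Optimum layout, zero points are additionally stored as (zp - 1) before packing
# and restored with (zp + 1) in recover(), and qweight/scales/qzeros are kept transposed.
```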
@@ -520,7 +520,7 @@ def export_compressed_model( compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, - use_hf_format=use_hf_format, + use_optimum_format=use_optimum_format, ) set_module(self.model, k, new_module) continue @@ -551,7 +551,7 @@ def export_compressed_model( compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, - use_hf_format=use_hf_format, + use_optimum_format=use_optimum_format, ) new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm) set_module(self.model, k, new_module) @@ -578,7 +578,7 @@ def export_compressed_model( compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, - use_hf_format=use_hf_format, + use_optimum_format=use_optimum_format, ) set_module(self.model, k, mod) return self.model diff --git a/neural_compressor/torch/quantization/modules.py b/neural_compressor/torch/quantization/modules.py index 6dd646fe6ae..ccba214e0f8 100644 --- a/neural_compressor/torch/quantization/modules.py +++ b/neural_compressor/torch/quantization/modules.py @@ -134,10 +134,10 @@ def __init__( compression_dim=1, g_idx=False, device="cpu", - use_hf_format=False, + use_optimum_format=True, ): super().__init__() - self.use_hf_format = use_hf_format + self.use_optimum_format = use_optimum_format self.dtype = dtype if "int" not in self.dtype: # for nf4, fp4 from neural_compressor.torch.algorithms.weight_only.rtn import FLOAT_MAPPING, INT_MAPPING @@ -162,13 +162,13 @@ def __init__( dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64} self.compress_bits = dtype_bits_mapping[compression_dtype] self.n_pack = self.compress_bits // self.bits - self.compressed_dtype = compression_dtype - self.float_type = scale_dtype # K is input channel, N is output channel assert compression_dim in [0, 1], ( "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel." ) - if self.use_hf_format: + if self.use_optimum_format: + self.float_type = torch.float16 + self.compressed_dtype = torch.int32 self.register_buffer( "scales", torch.zeros( @@ -193,7 +193,10 @@ def __init__( ).to(device), ) self.qzeros = self.qzeros.T + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: + self.compressed_dtype = compression_dtype + self.float_type = scale_dtype self.register_buffer( "scales", torch.zeros( @@ -233,18 +236,18 @@ def __init__( dtype=self.compressed_dtype, ).to(device), ) + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.bias = None if g_idx: self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) else: self.g_idx = None - if bias: - self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) - else: - self.bias = None def pack(self, int_weight, scale, zp, bias, g_idx=None): int_weight = int_weight.to(self.device) - if self.use_hf_format and zp is None: + if self.use_optimum_format and zp is None: # to avoid overflow int_weight = int_weight.type(torch.int32) shift_bias = 2 ** (self.bits - 1) @@ -256,13 +259,13 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): if g_idx is not None: assert hasattr(self, "g_idx"), "g_idx is not set when initializing." 
self.g_idx = g_idx.type(torch.int32).to(self.device) - if self.use_hf_format: + if self.use_optimum_format: invperm = torch.argsort(self.g_idx) self.g_idx = invperm // self.groupsize self.g_idx = self.g_idx.type(torch.int32).to(self.device) assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: int_weight = int_weight.T self.qweight = self.qweight.T origin_shape = int_weight.shape @@ -279,14 +282,14 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] &= mask tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: self.qweight = self.qweight.T if zp is not None: zp = zp.to(self.device) - if self.use_hf_format: + if self.use_optimum_format: zp -= 1 - if self.use_hf_format or self.compression_dim == 0: + if self.use_optimum_format or self.compression_dim == 0: zp = zp.T self.qzeros = self.qzeros.T assert hasattr(self, "qzeros"), "zp is not set when initializing." @@ -299,23 +302,19 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] &= mask tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] - if self.use_hf_format or self.compression_dim == 0: + if self.use_optimum_format or self.compression_dim == 0: self.qzeros = self.qzeros.T - if self.use_hf_format: + if self.use_optimum_format: self.scales = self.scales.T self.qweight = self.qweight.T - self.g_idx = self.g_idx self.qzeros = self.qzeros.T def recover(self): logger.debug(f"Recovering {self} weight") - if self.use_hf_format: - # Prevent broken id links of self.scales and self.scales - self.scales = self.scales.T - self.qweight = self.qweight.T - self.g_idx = self.g_idx - self.qzeros = self.qzeros.T - device = self.scales.device + scales = self.scales.T if self.use_optimum_format else self.scales + qweight = self.qweight.T if self.use_optimum_format else self.qweight + + device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) if self.g_idx is None: # used for recovering fp32_weight @@ -327,8 +326,7 @@ def recover(self): weight_dtype = torch.int8 # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) - qweight = self.qweight - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: weight = weight.T qweight = qweight.T origin_shape = weight.shape @@ -344,7 +342,7 @@ def recover(self): if weight_dtype == torch.uint8: tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) - if not self.use_hf_format and self.compression_dim == 0: + if not self.use_optimum_format and self.compression_dim == 0: weight = weight.T if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) @@ -354,9 +352,9 @@ def recover(self): # unpack zero_point if hasattr(self, "qzeros"): zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp - zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros - if self.use_hf_format or self.compression_dim == 0: + zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) + qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + if self.use_optimum_format or 
self.compression_dim == 0: zp = zp.T qzeros = qzeros.T origin_shape = zp.shape @@ -371,30 +369,34 @@ def recover(self): tmp = tmp >> self.compress_bits - self.bits tmp &= mask zp[:, index] = tmp.type(zp_dtype) - if self.use_hf_format or self.compression_dim == 0: + if self.use_optimum_format or self.compression_dim == 0: zp = zp.T - if self.use_hf_format: + if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 zp = torch.where(zp > (2**self.bits - 1), 0, zp) # recover fp32 weight with int_weight, scale, and zero_point for idx in range(self.in_features): - fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]] + fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]] else: # recover fp32 weight with int_weight, scale for idx in range(self.in_features): - fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]] + fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]] return fp32_weight def forward(self, input): + weight = self.recover() + device = self.scales.device + if weight.dtype == torch.float16 and device.type == "cpu": + weight = weight.float() + self.bias = self.bias.float() if self.bias is not None else None if level == DEBUG: if not hasattr(self, "weight"): - self.weight = self.recover() + self.weight = weight input = input.type(self.weight.dtype) logger.debug(f"Calculating {self}") return F.linear(input, self.weight, self.bias) else: - weight = self.recover() input = input.type(weight.dtype) return F.linear(input, weight, self.bias) @@ -406,8 +408,8 @@ def extra_repr(self) -> str: self.groupsize, self.bias is not None, ) - if self.use_hf_format: - tmp_str += ", use_hf_format=True" + if self.use_optimum_format: + tmp_str += ", use_optimum_format=True" return tmp_str diff --git a/neural_compressor/utils/load_huggingface.py b/neural_compressor/utils/load_huggingface.py index fff4c050603..43b68ef4c47 100644 --- a/neural_compressor/utils/load_huggingface.py +++ b/neural_compressor/utils/load_huggingface.py @@ -235,7 +235,7 @@ def save_for_huggingface_upstream(model, tokenizer, output_dir): def export_compressed_model( model, saved_dir=None, - use_hf_format=False, + use_optimum_format=True, enable_full_range=False, compression_dtype=torch.int32, compression_dim=1, @@ -247,7 +247,7 @@ def export_compressed_model( Args: model (torch.nn.Module): origin fp32 model. saved_dir (_type_, optional): the dir path of compression info. Defaults to None. - use_hf_format (bool, optional): whether use HuggingFace format. Defaults to False. + use_optimum_format (bool, optional): whether use HuggingFace format. Defaults to True. enable_full_range (bool, optional): Whether to leverage the full compression range under symmetric quantization. Defaults to False. compression_dtype (torch.Tensor, optional): The target dtype after comoression. 
@@ -277,6 +277,6 @@ def export_compressed_model( scale_dtype=scale_dtype, gptq_config_path=gptq_config_path, device=device, - use_hf_format=use_hf_format, + use_optimum_format=use_optimum_format, ) return inc_model.model diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py index 47202b86b52..a2da94ac822 100644 --- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py @@ -88,7 +88,7 @@ def test_RTN_int_quant(self): out2 = q_model(input) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) self.assertFalse(torch.all(out1 == out2)) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_optimum_format=False) out3 = compressed_model(input) self.assertTrue("fc1.qweight" in compressed_model.state_dict().keys()) self.assertTrue("fc1.qzeros" not in compressed_model.state_dict().keys()) @@ -99,13 +99,14 @@ def test_RTN_int_quant(self): model = Model() new_model = load("saved", model, weight_only=True) inc_model = INCModel(new_model) - inc_model.export_compressed_model(qweight_config_path="saved/qconfig.json", use_hf_format=True) + inc_model.export_compressed_model(qweight_config_path="saved/qconfig.json", use_optimum_format=True) out4 = inc_model.model(input) self.assertTrue("fc1.qzeros" in inc_model.model.state_dict().keys()) model = Model() - compressed_model = export_compressed_model(model, saved_dir="saved", use_hf_format=True) + compressed_model = export_compressed_model(model, saved_dir="saved", use_optimum_format=True) self.assertTrue("fc1.qzeros" in inc_model.model.state_dict().keys()) - self.assertTrue(torch.all(out3 == out4)) + # output gap is because of torch.float16 is used in hf_format + self.assertTrue(torch.allclose(out3, out4, atol=1e-3)) model = Model() out1 = model(input) @@ -120,7 +121,7 @@ def test_RTN_int_quant(self): out2 = q_model(input) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) self.assertFalse(torch.all(out1 == out2)) - compressed_model = q_model.export_compressed_model(enable_full_range=True) + compressed_model = q_model.export_compressed_model(use_optimum_format=False, enable_full_range=True) out3 = compressed_model(input) self.assertTrue(torch.all(out3 == out2)) @@ -181,6 +182,7 @@ def test_RTN_int_quant(self): ) q_model = quantization.fit(model, conf, eval_func=eval_func) out2 = q_model(input) + self.assertTrue(isinstance(q_model.model.fc1, WeightOnlyLinear)) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) self.assertFalse(torch.all(out1 == out2)) @@ -245,7 +247,7 @@ def test_RTN_int_quant(self): model_size1 = os.path.getsize("saved/best_model.pt") / 1024 print("FP32 Model size:{:.3f}M".format(model_size1)) inc_model = INCModel(new_model) - inc_model.export_compressed_model(qweight_config_path="saved/qconfig.json") + inc_model.export_compressed_model(use_optimum_format=False, qweight_config_path="saved/qconfig.json") torch.save(inc_model.state_dict(), "saved/tmp.pt") model_size2 = os.path.getsize("saved/tmp.pt") / 1024 print("WeightOnlyLinear Model size:{:.3f}M".format(model_size2)) @@ -273,7 +275,7 @@ def test_RTN_4bit_quant(self): out2 = q_model(self.lm_input) self.assertTrue(torch.all(torch.isclose(out1[0], out2[0], atol=1e-1))) self.assertFalse(torch.all(out1[0] == out2[0])) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_optimum_format=False) out3 = 
compressed_model(self.lm_input) self.assertTrue(torch.all(out3[0] == out2[0])) @@ -324,7 +326,7 @@ def test_AWQ_quant(self): fp32_model = copy.deepcopy(self.gptj) reload_model = load("saved", fp32_model, weight_only=True) out2 = reload_model(input) - q_model.export_compressed_model() + q_model.export_compressed_model(use_optimum_format=False) out3 = q_model(input) # no idea about the gap at 1e-08, use allclose instead of out1==out2 self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) @@ -428,7 +430,7 @@ def test_AWQ_nf4_quant(self): ) out2 = q_model(input) self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01)) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_optimum_format=False) out3 = compressed_model(input) self.assertTrue(torch.all(out3[0] == out2[0])) @@ -529,7 +531,7 @@ def __iter__(self): q_model.save("saved") out1 = q_model.model(input) self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_optimum_format=False) out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) @@ -554,10 +556,13 @@ def __iter__(self): ) q_model.save("saved") out1 = q_model.model(input) - compressed_model = q_model.export_compressed_model(use_hf_format=True) + compressed_model = q_model.export_compressed_model(use_optimum_format=True) out2 = compressed_model(input) + print(out1[0]) + print(out2[0]) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") - self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) + # hf_format uses fp16 for scale, so output atol is higher. + self.assertTrue(torch.allclose(out1[0], out2[0], atol=2e-04)) # # case 2: list or tuple model_3 = copy.deepcopy(self.gptj) @@ -569,7 +574,7 @@ def __iter__(self): ) q_model.save("saved") out1 = q_model.model(input) - compressed_model = q_model.export_compressed_model(use_hf_format=True) + compressed_model = q_model.export_compressed_model(use_optimum_format=False) out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) @@ -650,7 +655,8 @@ def __iter__(self): compressed_model = q_model.export_compressed_model() out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") - self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) + # hf_format uses fp16 for scale, so output atol is higher. + self.assertTrue(torch.allclose(out1[0], out2[0], atol=2e-04)) # # case 2: list or tuple model_2 = copy.deepcopy(self.gptj) @@ -662,7 +668,7 @@ def __iter__(self): ) q_model.save("saved") out1 = q_model.model(input) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_optimum_format=False) out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) @@ -680,7 +686,8 @@ def __iter__(self): compressed_model = q_model.export_compressed_model() out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") - self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) + # hf_format uses fp16 for scale, so output atol is higher. 
+ self.assertTrue(torch.allclose(out1[0], out2[0], atol=2e-04)) print("GPTQ with unfixed length Done") diff --git a/test/model/test_model_pytorch.py b/test/model/test_model_pytorch.py index 05edfd9c6fb..f0990b6558c 100644 --- a/test/model/test_model_pytorch.py +++ b/test/model/test_model_pytorch.py @@ -117,6 +117,8 @@ def test_WeightOnlyLinear(self): inc_model.export_compressed_model( qweight_config_path="saved/qconfig.json", compression_dtype=dtype, + scale_dtype=torch.float32, + use_optimum_format=False, ) out2 = q_model(input) torch.save(inc_model.state_dict(), "saved/tmp.pt") @@ -136,6 +138,7 @@ def test_WeightOnlyLinear(self): inc_model.export_compressed_model( qweight_config_path="saved/qconfig.json", compression_dim=dim, + use_optimum_format=False, ) out2 = q_model(input) torch.save(inc_model.state_dict(), "saved/tmp.pt") @@ -154,7 +157,6 @@ def test_WeightOnlyLinear(self): inc_model = INCModel(new_model) inc_model.export_compressed_model( qweight_config_path="saved/qconfig.json", - scale_dtype=torch.float16, ) out2 = q_model(input) torch.save(inc_model.state_dict(), "saved/tmp.pt")
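One note on the relaxed tolerances in the tests above: with `use_optimum_format=True` the scales (and bias) are stored in `torch.float16`, so a small dequantization gap versus fp32 scales is expected. A rough, self-contained check of the magnitude involved (sizes and value ranges are arbitrary):

```python
import torch

torch.manual_seed(0)
scales = torch.rand(1024) * 0.1                      # plausible per-group scales
int_weight = torch.randint(-8, 8, (1024,)).float()   # signed 4-bit values

w_fp32_scales = int_weight * scales
w_fp16_scales = int_weight * scales.half().float()   # scales stored as fp16, upcast on CPU

# Typically on the order of 1e-4, which is why the tests use atol=2e-04 here.
print((w_fp32_scales - w_fp16_scales).abs().max())
```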