diff --git a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh index 062bfe414ee..5631dcc0917 100644 --- a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh +++ b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh @@ -20,6 +20,7 @@ apt-get install -y --no-install-recommends --fix-missing \ build-essential pip install -r /neural-compressor/requirements.txt +pip install -r /neural-compressor/requirements_pt.txt pip install cmake pip install torch \ diff --git a/neural_compressor/torch/algorithms/layer_wise/load.py b/neural_compressor/torch/algorithms/layer_wise/load.py index 09700044a8f..a883bfe3848 100644 --- a/neural_compressor/torch/algorithms/layer_wise/load.py +++ b/neural_compressor/torch/algorithms/layer_wise/load.py @@ -32,7 +32,7 @@ _open_zipfile_reader, ) -from neural_compressor.adaptor.torch_utils.layer_wise_quant import modified_pickle as pickle +from neural_compressor.torch.algorithms.layer_wise import modified_pickle as pickle from .utils import torch diff --git a/neural_compressor/torch/algorithms/layer_wise/utils.py b/neural_compressor/torch/algorithms/layer_wise/utils.py index 464a25cdee0..bbe59de3fe4 100644 --- a/neural_compressor/torch/algorithms/layer_wise/utils.py +++ b/neural_compressor/torch/algorithms/layer_wise/utils.py @@ -27,10 +27,11 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from neural_compressor.common import options +from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear from .load import load -LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp") +LWQ_WORKSPACE = os.path.join(options.workspace, "lwq_tmpdir") class QDQLayer(torch.nn.Module): @@ -215,6 +216,9 @@ def _get_path(pretrained_model_name_or_path): return path +get_path = _get_path + + def load_value(model, param_name, path): if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True): input_embeddings = model.get_input_embeddings() @@ -281,6 +285,12 @@ def clean_module_weight(module): else: submodule = module + if isinstance(module, WeightOnlyLinear): + for n, m in submodule._buffers.items(): + old_value = getattr(submodule, n) + with torch.no_grad(): + submodule._buffers[n] = torch.zeros(old_value.shape, device="meta") + for n, m in submodule.named_parameters(): is_buffer = n in submodule._buffers old_value = getattr(submodule, n) diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 4c2df596282..eae9f7c3a84 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -230,11 +230,13 @@ def __init__( # device self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name() - self.model.to(self.device) + if not use_layer_wise: + self.model.to(self.device) self.is_ready = False self.use_layer_wise = use_layer_wise - self.model_path = model_path + if use_layer_wise: + self.prepare_layer_wise(model_path) # dataloader self.use_max_length = use_max_length @@ -243,6 +245,20 @@ def __init__( self.dataloader = [] self.nsamples = nsamples + def prepare_layer_wise(self, model_path): + import os + + from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, get_path, register_weight_hooks + + os.makedirs(LWQ_WORKSPACE, exist_ok=True) + if model_path == "": + model_path = self.model.path + assert model_path, "model_path should not be None." 
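The new prepare_layer_wise step above creates the LWQ workspace, resolves the checkpoint location via get_path, and registers weight hooks so that each block's weights are pulled from disk only while that block is being processed. A minimal illustrative sketch of that lazy-loading idea, assuming the per-layer tensors have already been gathered into a dict named layer_state_dicts (attach_lazy_weight_hooks and layer_state_dicts are hypothetical names, not the register_weight_hooks implementation):

import torch

def attach_lazy_weight_hooks(model: torch.nn.Module, layer_state_dicts: dict, device="cpu"):
    """Conceptual sketch: materialize a submodule's weights just before it runs, release them after."""
    handles = []
    for name, module in model.named_modules():
        if name not in layer_state_dicts:
            continue

        def load_hook(mod, inputs, layer_name=name):
            mod.to_empty(device=device)                          # allocate real (uninitialized) storage
            mod.load_state_dict(layer_state_dicts[layer_name])   # fill it from the saved per-layer state

        def clean_hook(mod, inputs, output):
            mod.to_empty(device="meta")                          # drop the storage again after the forward

        handles.append(module.register_forward_pre_hook(load_hook))
        handles.append(module.register_forward_hook(clean_hook))
    return handles
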
+ self.model_path = get_path(model_path) + register_weight_hooks( + self.model, self.model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE + ) + def get_full_layer_name(self, sub_layer_name, block_idx): transformer_name = self.gptq_related_blocks["transformers_name"] return ".".join([transformer_name, str(block_idx), sub_layer_name]) @@ -413,7 +429,6 @@ def execute_quantization(self, means=None, stds=None): # Step1: prepare quantization (calibration datasets) logger.info("Begin ====>") - model_path = self.model_path # Step2: run gptq quantization in a transformer block-wise manner. gptq_config = {} @@ -450,7 +465,7 @@ def execute_quantization(self, means=None, stds=None): if self.use_layer_wise: # pragma: no cover from neural_compressor.torch.algorithms.layer_wise import load_value - W = load_value(self.model, full_layer_name + ".weight", model_path) + W = load_value(self.model, full_layer_name + ".weight", self.model_path) else: W = sub_layers[layer_name].weight.data.clone() @@ -489,7 +504,7 @@ def tmp(_, inp, out): from neural_compressor.torch.algorithms.layer_wise import load_value full_layer_name = self.get_full_layer_name(layer_name, block_idx) - W = load_value(self.model, full_layer_name + ".weight", model_path) + W = load_value(self.model, full_layer_name + ".weight", self.model_path) else: W = sub_layers[layer_name].weight.data.clone() accelerator.mark_step() @@ -518,7 +533,7 @@ def tmp(_, inp, out): if n == "weight": set_module_tensor_to_device(self.model, param_name, self.device, Q) else: - value = load_value(self.model, param_name, model_path) + value = load_value(self.model, param_name, self.model_path) set_module_tensor_to_device(self.model, param_name, self.device, value) # sub_layer.weight.data = Q torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") @@ -562,7 +577,13 @@ def tmp(_, inp, out): gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] else: gptq_perm = None - Q = sub_layers[layer_name].weight.data + if self.use_layer_wise: + state_dict = torch.load(LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt") + Q = state_dict["weight"].data + bias = state_dict["bias"] if "bias" in state_dict.keys() else None + + else: + Q = sub_layers[layer_name].weight.data if weight_config_this_layer["act_order"]: Q.copy_(Q[:, gptq_perm]) if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D): @@ -591,6 +612,9 @@ def tmp(_, inp, out): scale = scale.t_().contiguous() zp = zp.t_().contiguous() if zp is not None else zp + if not self.use_layer_wise: + bias = sub_layers[layer_name].bias + new_module = WeightOnlyLinear( in_features, out_features, @@ -598,11 +622,11 @@ def tmp(_, inp, out): bits=weight_config_this_layer["bits"], group_size=weight_config_this_layer["group_size"], zp=gptq_zp is not None, - bias=sub_layers[layer_name].bias is not None, + bias=bias is not None, g_idx=gptq_perm is not None, device=self.device, ) - new_module.pack(int_weight, gptq_scale, gptq_zp, sub_layers[layer_name].bias, gptq_perm) + new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm) set_module(transformer_block, layer_name, new_module) del gptq_for_this_block torch.cuda.empty_cache() @@ -1019,8 +1043,10 @@ def prepare( def convert(self, model, *args, **kwargs): self.gptq_quantizer.model = model self.gptq_quantizer.remove_prepare_for_calibration() + q_model, gptq_config = self.gptq_quantizer.execute_quantization() - q_model = q_model.to(self.model_device) + if not 
self.gptq_quantizer.use_layer_wise: + q_model = q_model.to(self.model_device) q_model.gptq_config = gptq_config logger.info("GPTQ quantizing done.") return q_model diff --git a/neural_compressor/torch/algorithms/weight_only/modules.py b/neural_compressor/torch/algorithms/weight_only/modules.py index dcb2ff421f4..503a469b0c7 100644 --- a/neural_compressor/torch/algorithms/weight_only/modules.py +++ b/neural_compressor/torch/algorithms/weight_only/modules.py @@ -19,6 +19,7 @@ # since the model classes inherit torch.nn.Module. import math +import numba import numpy as np import torch from torch.autograd import Function @@ -175,7 +176,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): self.scales = self.scales.T.contiguous() self.qweight = self.qweight.T.contiguous() self.qzeros = self.qzeros.T.contiguous() - int_weight = int_weight.to(self.device) + if int_weight.device.type != "meta": + int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow int_weight = int_weight.type(torch.int32) @@ -300,7 +302,253 @@ def unpack_tensor_with_torch(self, packed_tensor): accelerator.synchronize() return unpacked_tensor - def pack_tensor_with_numpy(self, raw_tensor): + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b4_c32( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 7] & 0b1111) << 28) + | ((raw_array[:, i * n_pack + 6] & 0b1111) << 24) + | ((raw_array[:, i * n_pack + 5] & 0b1111) << 20) + | ((raw_array[:, i * n_pack + 4] & 0b1111) << 16) + | ((raw_array[:, i * n_pack + 3] & 0b1111) << 12) + | ((raw_array[:, i * n_pack + 2] & 0b1111) << 8) + | ((raw_array[:, i * n_pack + 1] & 0b1111) << 4) + | (raw_array[:, i * n_pack] & 0b1111) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b4_c16( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 3] & 0b1111) << 12) + | ((raw_array[:, i * n_pack + 2] & 0b1111) << 8) + | ((raw_array[:, i * n_pack + 1] & 0b1111) << 4) + | (raw_array[:, i * n_pack] & 0b1111) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b4_c8( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ((raw_array[:, i * n_pack + 1] & 0b1111) << 4) | (raw_array[:, i * n_pack] & 0b1111) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b4_c64( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 15] & 0b1111) << 60) + | ((raw_array[:, i * n_pack + 14] & 0b1111) << 56) + | ((raw_array[:, i * n_pack + 13] & 0b1111) << 52) + | ((raw_array[:, i * n_pack + 12] & 0b1111) << 48) + | ((raw_array[:, i * n_pack + 11] & 0b1111) << 44) + | ((raw_array[:, i * n_pack + 10] & 0b1111) << 40) + | ((raw_array[:, i * n_pack + 9] & 0b1111) << 36) + | ((raw_array[:, i * n_pack + 8] & 0b1111) << 32) + | ((raw_array[:, i * n_pack + 7] & 0b1111) << 28) + | ((raw_array[:, i * n_pack + 6] & 0b1111) << 24) + | 
((raw_array[:, i * n_pack + 5] & 0b1111) << 20) + | ((raw_array[:, i * n_pack + 4] & 0b1111) << 16) + | ((raw_array[:, i * n_pack + 3] & 0b1111) << 12) + | ((raw_array[:, i * n_pack + 2] & 0b1111) << 8) + | ((raw_array[:, i * n_pack + 1] & 0b1111) << 4) + | (raw_array[:, i * n_pack] & 0b1111) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b8_c32( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 3] & 0b11111111) << 24) + | ((raw_array[:, i * n_pack + 2] & 0b11111111) << 16) + | ((raw_array[:, i * n_pack + 1] & 0b11111111) << 8) + | (raw_array[:, i * n_pack] & 0b11111111) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b8_c16( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 3] & 0b11111111) << 24) + | ((raw_array[:, i * n_pack + 2] & 0b11111111) << 16) + | ((raw_array[:, i * n_pack + 1] & 0b11111111) << 8) + | (raw_array[:, i * n_pack] & 0b11111111) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b8_c8( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = raw_array[:, i * n_pack] & 0b11111111 + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b8_c64( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 7] & 0b11111111) << 56) + | ((raw_array[:, i * n_pack + 6] & 0b11111111) << 48) + | ((raw_array[:, i * n_pack + 5] & 0b11111111) << 40) + | ((raw_array[:, i * n_pack + 4] & 0b11111111) << 32) + | ((raw_array[:, i * n_pack + 3] & 0b11111111) << 24) + | ((raw_array[:, i * n_pack + 2] & 0b11111111) << 16) + | ((raw_array[:, i * n_pack + 1] & 0b11111111) << 8) + | (raw_array[:, i * n_pack] & 0b11111111) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b2_c32( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 15] & 0b11) << 30) + | ((raw_array[:, i * n_pack + 14] & 0b11) << 28) + | ((raw_array[:, i * n_pack + 13] & 0b11) << 26) + | ((raw_array[:, i * n_pack + 12] & 0b11) << 24) + | ((raw_array[:, i * n_pack + 11] & 0b11) << 22) + | ((raw_array[:, i * n_pack + 10] & 0b11) << 20) + | ((raw_array[:, i * n_pack + 9] & 0b11) << 18) + | ((raw_array[:, i * n_pack + 8] & 0b11) << 16) + | ((raw_array[:, i * n_pack + 7] & 0b11) << 14) + | ((raw_array[:, i * n_pack + 6] & 0b11) << 12) + | ((raw_array[:, i * n_pack + 5] & 0b11) << 10) + | ((raw_array[:, i * n_pack + 4] & 0b11) << 8) + | ((raw_array[:, i * n_pack + 3] & 0b11) << 6) + | ((raw_array[:, i * n_pack + 2] & 0b11) << 4) + | ((raw_array[:, i * n_pack + 1] & 0b11) << 2) + | (raw_array[:, i * n_pack] & 0b11) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b2_c16( + raw_array: np.ndarray, 
packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 7] & 0b11) << 14) + | ((raw_array[:, i * n_pack + 6] & 0b11) << 12) + | ((raw_array[:, i * n_pack + 5] & 0b11) << 10) + | ((raw_array[:, i * n_pack + 4] & 0b11) << 8) + | ((raw_array[:, i * n_pack + 3] & 0b11) << 6) + | ((raw_array[:, i * n_pack + 2] & 0b11) << 4) + | ((raw_array[:, i * n_pack + 1] & 0b11) << 2) + | (raw_array[:, i * n_pack] & 0b11) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b2_c8( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 3] & 0b11) << 6) + | ((raw_array[:, i * n_pack + 2] & 0b11) << 4) + | ((raw_array[:, i * n_pack + 1] & 0b11) << 2) + | (raw_array[:, i * n_pack] & 0b11) + ) + return packed_array + + @staticmethod + @numba.jit(nopython=True, parallel=True) + def pack_array_with_numba_b2_c64( + raw_array: np.ndarray, packed_array: np.ndarray, n_pack: int, new_in_features: int + ) -> np.ndarray: + for i in range(new_in_features): + packed_array[:, i] = ( + ((raw_array[:, i * n_pack + 31] & 0b11) << 62) + | ((raw_array[:, i * n_pack + 30] & 0b11) << 60) + | ((raw_array[:, i * n_pack + 29] & 0b11) << 58) + | ((raw_array[:, i * n_pack + 28] & 0b11) << 56) + | ((raw_array[:, i * n_pack + 27] & 0b11) << 54) + | ((raw_array[:, i * n_pack + 26] & 0b11) << 52) + | ((raw_array[:, i * n_pack + 25] & 0b11) << 50) + | ((raw_array[:, i * n_pack + 24] & 0b11) << 48) + | ((raw_array[:, i * n_pack + 23] & 0b11) << 46) + | ((raw_array[:, i * n_pack + 22] & 0b11) << 44) + | ((raw_array[:, i * n_pack + 21] & 0b11) << 42) + | ((raw_array[:, i * n_pack + 20] & 0b11) << 40) + | ((raw_array[:, i * n_pack + 19] & 0b11) << 38) + | ((raw_array[:, i * n_pack + 18] & 0b11) << 36) + | ((raw_array[:, i * n_pack + 17] & 0b11) << 34) + | ((raw_array[:, i * n_pack + 16] & 0b11) << 32) + | ((raw_array[:, i * n_pack + 15] & 0b11) << 30) + | ((raw_array[:, i * n_pack + 14] & 0b11) << 28) + | ((raw_array[:, i * n_pack + 13] & 0b11) << 26) + | ((raw_array[:, i * n_pack + 12] & 0b11) << 24) + | ((raw_array[:, i * n_pack + 11] & 0b11) << 22) + | ((raw_array[:, i * n_pack + 10] & 0b11) << 20) + | ((raw_array[:, i * n_pack + 9] & 0b11) << 18) + | ((raw_array[:, i * n_pack + 8] & 0b11) << 16) + | ((raw_array[:, i * n_pack + 7] & 0b11) << 14) + | ((raw_array[:, i * n_pack + 6] & 0b11) << 12) + | ((raw_array[:, i * n_pack + 5] & 0b11) << 10) + | ((raw_array[:, i * n_pack + 4] & 0b11) << 8) + | ((raw_array[:, i * n_pack + 3] & 0b11) << 6) + | ((raw_array[:, i * n_pack + 2] & 0b11) << 4) + | ((raw_array[:, i * n_pack + 1] & 0b11) << 2) + | (raw_array[:, i * n_pack] & 0b11) + ) + return packed_array + + def pack_array_with_numba( + self, raw_array: np.ndarray, n_pack: int, bits: int, compress_bits: int, compression_dtype=np.int32 + ) -> np.ndarray: + """Packs the input array by combining elements into a specified bit-width format using NumPy. + + Args: + raw_array (np.ndarray): The array to be packed. Shape: [out_features, in_features] or [1, in_features]. + n_pack (int): The number of elements to be packed together. + bits (int): The number of bits for each element. + compress_bits (int): The number of bits for each element of the compressed array, supported 2, 4, 8. 
+ compression_dtype (np.dtype, optional): The data type of the compressed array. Defaults to np.int32. + + Returns: + np.ndarray: The packed array. + """ + out_features, in_features = raw_array.shape + new_in_features = (in_features + n_pack - 1) // n_pack + packed_array = np.zeros((out_features, new_in_features), dtype=compression_dtype) + raw_array = raw_array.astype(compression_dtype) + + pack_method_name = f"pack_array_with_numba_b{bits}_c{compress_bits}" + pack_method = getattr(self, pack_method_name) + return pack_method(raw_array, packed_array, n_pack, new_in_features) + + def pack_tensor_with_numpy_impl(self, raw_tensor): raw_array = raw_tensor.cpu().numpy() target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int) target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype @@ -318,6 +566,15 @@ def pack_tensor_with_numpy(self, raw_tensor): packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device) return packed_tensor + def pack_tensor_with_numpy(self, raw_tensor): + if self.bits not in [2, 4, 8]: + return self.pack_tensor_with_numpy_impl(raw_tensor) + compression_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype + packed_array = self.pack_array_with_numba( + raw_tensor.cpu().numpy(), self.n_pack, self.bits, self.compress_bits, compression_dtype + ) + return torch.from_numpy(packed_array).to(device=raw_tensor.device) + def unpack_tensor_with_numpy(self, packed_tensor): packed_array = packed_tensor.cpu().numpy() target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8 diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py index 6a95bec4550..c04327a62f4 100644 --- a/neural_compressor/torch/algorithms/weight_only/rtn.py +++ b/neural_compressor/torch/algorithms/weight_only/rtn.py @@ -73,6 +73,8 @@ def convert( quantile=1.0, use_full_range=False, use_mse_search=False, + use_layer_wise=False, + model_path="", quant_lm_head=False, *args, **kwargs, @@ -122,7 +124,20 @@ def convert( "double_quant_group_size": kwargs.get("double_quant_group_size", 256), } use_optimum_format = kwargs.get("use_optimum_format", True) + + if use_layer_wise: + from neural_compressor.common.utils import DEFAULT_WORKSPACE + from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module, register_weight_hooks + + if model_path == "": + model_path = model.path + assert model_path, "model_path should not be None." 
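The Numba kernels added in modules.py above all follow the same pattern: n_pack low-bit values are masked and OR-ed little-end-first into one compressed integer, and pack_array_with_numba picks the matching kernel by name via getattr (pack_array_with_numba_b{bits}_c{compress_bits}). A NumPy-only reference of the 4-bit/int32 layout, useful for sanity-checking the kernels (pack_b4_c32_reference is an illustrative name, not part of the patch):

import numpy as np

def pack_b4_c32_reference(raw: np.ndarray) -> np.ndarray:
    """Pack eight 4-bit values per int32, lowest nibble first (same layout as the b4_c32 kernel)."""
    out_features, in_features = raw.shape
    n_pack = 32 // 4                                                        # 8 values per int32
    packed = np.zeros((out_features, (in_features + n_pack - 1) // n_pack), dtype=np.int32)
    raw = raw.astype(np.int32)
    for j in range(packed.shape[1]):
        for k in range(n_pack):
            col = j * n_pack + k
            if col < in_features:
                packed[:, j] |= (raw[:, col] & 0b1111) << (4 * k)           # mask nibble, shift into place
    return packed

vals = np.random.randint(0, 16, size=(2, 16))
packed = pack_b4_c32_reference(vals)
# The lowest nibble of the first packed word must be the first raw value:
assert np.array_equal(packed[:, 0] & 0b1111, vals[:, 0])
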
+ model_path = get_path(model_path) + + register_weight_hooks(model, model_path, device=device, clean_weight=True) + for name, m in model.named_modules(): + if not isinstance(m, supported_layers): continue if name in weight_config: # pragma: no cover @@ -131,7 +146,8 @@ def convert( if dtype == "fp32": continue # Move modules to the accelerator device layer-by-layer - m.to(device) + if not use_layer_wise: + m.to(device) ### FP8 cast part if dtype in ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"]: logger.debug("Cast module {} to FP8 using qdq mode, no scaling".format(name)) @@ -146,7 +162,6 @@ def convert( group_dim = weight_config[name]["group_dim"] use_full_range = weight_config[name]["use_full_range"] use_mse_search = weight_config[name]["use_mse_search"] - use_layer_wise = weight_config[name]["use_layer_wise"] use_optimum_format = kwargs.get("use_optimum_format", True) # double quant config double_quant_config = { @@ -171,6 +186,10 @@ def convert( continue logger.debug(f"RTN quantized module:{name, m}") logger.debug(log_msg) + + if use_layer_wise: + load_module(model, name, model_path, device=device) + # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight. if is_transformers_imported(): transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D)) @@ -219,12 +238,17 @@ def convert( device=device, ) new_module.pack(int_weight, scale, zp, m.bias) + + if use_layer_wise: + m = m.to_empty(device=torch.device("meta")) if name == "": return new_module else: set_module(model, name, new_module) # Move modules back to the model device layer-by-layer - m.to(model_device) - new_module.to(model_device) - model.to(model_device) + if not use_layer_wise: + m.to(model_device) + new_module.to(model_device) + if not use_layer_wise: + model.to(model_device) return model diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index e5e0a1e627d..3b9872de3ec 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -83,16 +83,21 @@ def rtn_entry( "group_dim": quant_config.group_dim, "use_full_range": quant_config.use_full_range, "use_mse_search": quant_config.use_mse_search, - "use_layer_wise": quant_config.use_layer_wise, "use_double_quant": quant_config.use_double_quant, "double_quant_dtype": quant_config.double_quant_dtype, "double_quant_bits": quant_config.double_quant_bits, "double_quant_scheme": "sym" if quant_config.double_quant_use_sym else "asym", "double_quant_group_size": quant_config.double_quant_group_size, } - + kwargs.update( + { + "use_layer_wise": quant_config.use_layer_wise, + "model_path": quant_config.model_path, + "quant_lm_head": quant_config.quant_lm_head, + } + ) quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config) - model = quantizer.execute(model, mode=mode, quant_lm_head=quant_config.quant_lm_head) + model = quantizer.execute(model, mode=mode, *args, **kwargs) model.qconfig = configs_mapping model.save = MethodType(save, model) postprocess_model(model, mode, quantizer) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index f7686dd0a8a..40c51644866 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -217,7 +217,9 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> 
OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - self.set_local(LM_HEAD_NAMES, RTNConfig(dtype="fp32")) + self.set_local( + LM_HEAD_NAMES, RTNConfig(dtype="fp32", use_layer_wise=self.use_layer_wise, model_path=self.model_path) + ) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping @@ -380,7 +382,9 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - self.set_local(LM_HEAD_NAMES, GPTQConfig(dtype="fp32")) + self.set_local( + LM_HEAD_NAMES, GPTQConfig(dtype="fp32", use_layer_wise=self.use_layer_wise, model_path=self.model_path) + ) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping @@ -402,7 +406,9 @@ def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig @classmethod def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "GPTQConfig"]: pre_defined_configs: Dict[torch_utils.ProcessorType, GPTQConfig] = {} - pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) + pre_defined_configs[torch_utils.ProcessorType.Client] = cls( + use_layer_wise=True + ) # , model_path=self.model_path) pre_defined_configs[torch_utils.ProcessorType.Server] = cls() return pre_defined_configs @@ -456,6 +462,7 @@ def __init__( use_full_range: bool = False, use_mse_search: bool = False, use_layer_wise: bool = False, + model_path: str = "", # double quant use_double_quant: bool = False, double_quant_dtype: str = "int", @@ -482,6 +489,7 @@ def __init__( use_full_range (bool): Enables full range for activations, default is False. use_mse_search (bool): Enables mean squared error (MSE) search, default is False. use_layer_wise (bool): Enables quantize model per layer. Defaults to False. + model_path (str): Model path that is used to load state_dict per layer. use_double_quant (bool): Enables double quantization, default is False. double_quant_dtype (str): Data type for double_quant scale, default is "int". double_quant_bits (int): Number of bits used to represent double_quant scale, default is 4. 
@@ -503,6 +511,7 @@ def __init__( self.use_full_range = use_full_range self.use_mse_search = use_mse_search self.use_layer_wise = use_layer_wise + self.model_path = model_path # double quant self.use_double_quant = use_double_quant self.double_quant_bits = double_quant_bits @@ -529,7 +538,9 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - self.set_local(LM_HEAD_NAMES, AWQConfig(dtype="fp32")) + self.set_local( + LM_HEAD_NAMES, AWQConfig(dtype="fp32", use_layer_wise=self.use_layer_wise, model_path=self.model_path) + ) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping diff --git a/neural_compressor/torch/utils/__init__.py b/neural_compressor/torch/utils/__init__.py index dab02a017c6..25aadaa6d66 100644 --- a/neural_compressor/torch/utils/__init__.py +++ b/neural_compressor/torch/utils/__init__.py @@ -15,3 +15,4 @@ from .environ import * from .constants import * from .utility import * +from neural_compressor.torch.algorithms.layer_wise import load_empty_model diff --git a/requirements_pt.txt b/requirements_pt.txt index 94667b64665..5f18aead98d 100644 --- a/requirements_pt.txt +++ b/requirements_pt.txt @@ -1,3 +1,4 @@ +numba numpy < 2.0 peft==0.10.0 prettytable diff --git a/test/3x/torch/algorithms/weight_only/test_woq_module.py b/test/3x/torch/algorithms/weight_only/test_woq_module.py new file mode 100644 index 00000000000..0f06f358beb --- /dev/null +++ b/test/3x/torch/algorithms/weight_only/test_woq_module.py @@ -0,0 +1,52 @@ +import copy + +import pytest +import torch + +from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear +from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor + + +class TestWeightOnlyLinear: + @pytest.mark.parametrize( + "bits, compression_dtype", + [ + (8, torch.int8), + (8, torch.int16), + (8, torch.int32), + (8, torch.int64), + (4, torch.int8), + (4, torch.int16), + (4, torch.int32), + (4, torch.int64), + (2, torch.int8), + (2, torch.int16), + (2, torch.int32), + (2, torch.int64), + ], + ) + def test_pack_with_numba(self, bits, compression_dtype): + m = torch.nn.Linear(64, 32) + dtype = "int" + weight = m.weight.detach() + int_weight, scale, zp = quant_tensor( + weight, + dtype=dtype, + bits=bits, + return_int=True, + group_size=32, + ) + new_module = WeightOnlyLinear( + m.in_features, + m.out_features, + dtype=dtype, + bits=bits, + group_size=32, + zp=zp is not None, + bias=m.bias is not None, + use_optimum_format=False, + compression_dtype=compression_dtype, + ) + new_module.pack(int_weight, scale, zp, m.bias) + unpacked_int_weight = new_module.unpack_tensor(new_module.qweight) + assert torch.equal(unpacked_int_weight, int_weight) diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index dfbc39c25e7..8608e1801a4 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -175,13 +175,24 @@ def test_act_order(self): # compare atol, this case is an ideal case. assert atol_false > atol_true, "act_order=True doesn't help accuracy, maybe is reasonable, please double check." 
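The new test_layer_wise for GPTQ below exercises the end-to-end layer-wise path: an empty (meta-device) model from load_empty_model, a GPTQConfig with use_layer_wise=True and model_path, then prepare / calibration / convert. A condensed standalone sketch of the same flow (the example input tensor is illustrative and stands in for the test's run_fn calibration):

import torch

from neural_compressor.torch.quantization import GPTQConfig, convert, prepare
from neural_compressor.torch.utils import load_empty_model

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = load_empty_model(model_name)                       # weights stay on the meta device
quant_config = GPTQConfig(use_layer_wise=True, model_path=model_name)
model = prepare(model, quant_config)

example = torch.tensor([[10, 20, 30, 40, 50]], dtype=torch.long)
model(example)                                             # calibration pass, like run_fn in the test
model = convert(model)                                     # each block is loaded, quantized, and released
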
- # def test_layer_wise(self): - # model = copy.deepcopy(self.tiny_gptj) - # quant_config = GPTQConfig( - # use_layer_wise=True, - # ) - # model = quantize(model, quant_config, run_fn=run_fn) - # TODO: (Xin) not implemented + def test_layer_wise(self): + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig() + model = prepare(model, quant_config) + run_fn(model) + model = convert(model) + q_label = model(self.example_inputs)[0] + + from neural_compressor.torch.utils import load_empty_model + + model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM") + + quant_config = GPTQConfig(use_layer_wise=True, model_path="hf-internal-testing/tiny-random-GPTJForCausalLM") + model = prepare(model, quant_config) + run_fn(model) + model = convert(model) + out = model(self.example_inputs)[0] + assert torch.equal(out, q_label), "use_layer_wise=True output should be same. Please double check." @pytest.mark.parametrize("dtype", ["nf4", "int4"]) @pytest.mark.parametrize("double_quant_bits", [6]) diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index 04f6c444485..cc4a0df6172 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -22,8 +22,8 @@ class ModelConv1d(torch.nn.Module): def __init__(self): super(ModelConv1d, self).__init__() - self.fc1 = transformers.Conv1D(50, 32) - self.fc2 = torch.nn.Linear(50, 32) + self.fc1 = transformers.Conv1D(64, 32) + self.fc2 = torch.nn.Linear(64, 32) self.fc3 = torch.nn.Linear(32, 5) def forward(self, x): @@ -44,7 +44,7 @@ def setup_class(self): self.label = self.tiny_gptj(self.example_inputs)[0] # test_default_config model = copy.deepcopy(self.tiny_gptj) - quant_config = get_default_rtn_config() + quant_config = get_default_rtn_config("Server") model = prepare(model, quant_config) model = convert(model) # record q_label for comparison @@ -167,13 +167,16 @@ def test_quant_lm_head(self): ), "The tied lm_head weight is not deep copied, please check!" def test_layer_wise(self): - model = copy.deepcopy(self.tiny_gptj) + from neural_compressor.torch.utils import load_empty_model + + model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM") quant_config = RTNConfig( use_layer_wise=True, ) model = prepare(model, quant_config) model = convert(model) - # TODO: (Xin) not implemented + out = model(self.example_inputs)[0] + assert torch.equal(out, self.q_label), "use_layer_wise=True output should be same. Please double check." @pytest.mark.parametrize( "dtype",
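The updated test_layer_wise in test_rtn.py above follows the same pattern without a calibration step. A condensed sketch, assuming model_path may be omitted because the RTN convert path falls back to model.path, as the new assert in rtn.py indicates:

from neural_compressor.torch.quantization import RTNConfig, convert, prepare
from neural_compressor.torch.utils import load_empty_model

model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM")
quant_config = RTNConfig(use_layer_wise=True)   # model_path omitted: convert falls back to model.path
model = prepare(model, quant_config)
model = convert(model)                          # RTN needs no calibration; weights stream from disk per layer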