From 3b2d6822fa04b0190149422e19722adea253ed21 Mon Sep 17 00:00:00 2001
From: Yi Liu <106061964+yiliu30@users.noreply.github.com>
Date: Fri, 19 Jul 2024 09:37:48 +0800
Subject: [PATCH] Add Google style docstrings to HQQ files

---
 .../scripts/codeScan/pydocstyle/scan_path.txt |   1 +
 .../algorithms/weight_only/hqq/bitpack.py     | 126 +++++++++++++-
 .../algorithms/weight_only/hqq/config.py      |  25 +++
 .../torch/algorithms/weight_only/hqq/core.py  | 155 +++++++++++++++++-
 .../algorithms/weight_only/hqq/optimizer.py   |  16 ++
 .../algorithms/weight_only/hqq/qtensor.py     |  95 +++++++++++
 .../algorithms/weight_only/hqq/quantizer.py   | 111 ++++++++++++-
 .../algorithms/weight_only/hqq/utility.py     |  18 +-
 8 files changed, 535 insertions(+), 12 deletions(-)

diff --git a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
index b524f1f61db..45af9e22e09 100644
--- a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
+++ b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
@@ -15,3 +15,4 @@
 /neural-compressor/neural_compressor/strategy
 /neural-compressor/neural_compressor/training.py
 /neural-compressor/neural_compressor/utils
+/neural-compressor/neural_compressor/torch/algorithms/weight_only/hqq
diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py b/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py
index 5500201a4ee..acf4470563b 100644
--- a/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py
+++ b/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py
@@ -30,27 +30,79 @@
 # Bit packing logic. 
format: pack/unpack_nBits_target- class BitPack: + """ + A class for bit packing logic. + + This class provides static methods for packing and unpacking tensors + with different bit-widths. + + Methods: + pack_8bit_u8(W_q): Packs an 8-bit tensor to uint8. + unpack_8bit_u8(W_q): Unpacks an 8-bit tensor from uint8. + pack_4bit_u8(W_q): Packs a 4-bit tensor to uint8. + unpack_4bit_u8(W_q): Unpacks a 4-bit tensor from uint8. + pack_2bit_u8(W_q): Packs a 2-bit tensor to uint8. + unpack_2bit_u8(W_q): Unpacks a 2-bit tensor from uint8. + pack_3bit_32(W_q_in): Packs a 3-bit tensor to int32. + unpack_3bit_32(W_q): Unpacks a 3-bit tensor from int32. + """ + # 8-bit ################################################ @staticmethod def pack_8bit_u8(W_q): + """ + Packs an 8-bit tensor to uint8. + + Args: + W_q (torch.Tensor): The tensor to be packed. + + Returns: + torch.Tensor: The packed tensor. + """ return W_q.to(torch.uint8) @staticmethod def unpack_8bit_u8(W_q): + """ + Unpacks an 8-bit tensor from uint8. + + Args: + W_q (torch.Tensor): The tensor to be unpacked. + + Returns: + torch.Tensor: The unpacked tensor. + """ return W_q # 4-bit ################################################ @staticmethod def pack_4bit_u8(W_q): # uint8 > uint8/2 + """ + Packs a 4-bit tensor to uint8. + + Args: + W_q (torch.Tensor): The tensor to be packed. + + Returns: + torch.Tensor: The packed tensor. + """ W_q = W_q.to(torch.uint8) _step = int(len(W_q) / 2) return (W_q[:_step] << 4) | W_q[_step:] - # A bit faster than the _cat version @staticmethod def unpack_4bit_u8(W_q): # uint8/2 > uint8 + """ + Unpacks a 4-bit tensor from uint8. + + Args: + W_q (torch.Tensor): The tensor to be unpacked. + + Returns: + torch.Tensor: The unpacked tensor. + """ _step = W_q.shape[0] tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) tmp[:_step] = (W_q & 0b11110000) >> 4 @@ -61,13 +113,30 @@ def unpack_4bit_u8(W_q): # uint8/2 > uint8 ################################################ @staticmethod def pack_2bit_u8(W_q): # uint8 > uint8/4 + """ + Packs a 2-bit tensor to uint8. + + Args: + W_q (torch.Tensor): The tensor to be packed. + + Returns: + torch.Tensor: The packed tensor. + """ W_q = W_q.to(torch.uint8) _step = int(len(W_q) / 4) return W_q[:_step] << 6 | W_q[_step : 2 * _step] << 4 | W_q[2 * _step : 3 * _step] << 2 | W_q[3 * _step :] - # A bit faster than the _cat version @staticmethod def unpack_2bit_u8(W_q): + """ + Unpacks a 2-bit tensor from uint8. + + Args: + W_q (torch.Tensor): The tensor to be unpacked. + + Returns: + torch.Tensor: The unpacked tensor. + """ _step = W_q.shape[0] tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) tmp[:_step] = (W_q & 0b11000000) >> 6 @@ -80,6 +149,15 @@ def unpack_2bit_u8(W_q): ################################################ @staticmethod def pack_3bit_32(W_q_in): + """ + Packs a 3-bit tensor to int32. + + Args: + W_q_in (torch.Tensor): The tensor to be packed. + + Returns: + torch.Tensor: The packed tensor. + """ W_q = torch.zeros( [int(10 * np.ceil(W_q_in.shape[0] / 10.0)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32 ) @@ -99,9 +177,17 @@ def pack_3bit_32(W_q_in): ) return W_q - # A bit faster than _cat version @staticmethod def unpack_3bit_32(W_q): + """ + Unpacks a 3-bit tensor from int32. + + Args: + W_q (torch.Tensor): The tensor to be unpacked. + + Returns: + torch.Tensor: The unpacked tensor. 
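+
+        Example:
+            Illustrative round-trip; assumes the input values fit in 3 bits (0-7).
+
+            >>> w = torch.randint(0, 8, (10, 4), dtype=torch.int32)
+            >>> packed = BitPack.pack_3bit_32(w)
+            >>> unpacked = BitPack.unpack_3bit_32(packed)[: w.shape[0]]
+            >>> torch.equal(unpacked.to(w.dtype), w)
+            True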
+ """ _step = W_q.shape[0] tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) tmp[:_step] = (W_q & 0b00111000000000000000000000000000) >> 27 @@ -118,6 +204,22 @@ def unpack_3bit_32(W_q): class Packer: + """ + A class for managing bit packing functions. + + This class provides methods to get the appropriate packing and unpacking + functions based on the number of bits. + + Attributes: + bit_to_packing (dict): A mapping from bit-width to packing format. + pack_fn_mapping (dict): A mapping from packing format to packing function. + unpack_fn_mapping (dict): A mapping from packing format to unpacking function. + + Methods: + get_pack_fn(nbits): Returns the packing function for the given bit-width. + get_unpack_fn(nbits): Returns the unpacking function for the given bit-width. + """ + # TODO: Refine the packer bit_to_packing = {8: "8bit_u8", 4: "4bit_u8", 3: "3bit_32", 2: "2bit_u8"} @@ -137,8 +239,26 @@ class Packer: @staticmethod def get_pack_fn(nbits: int): + """ + Returns the packing function for the given bit-width. + + Args: + nbits (int): The bit-width. + + Returns: + function: The packing function. + """ return Packer.pack_fn_mapping[Packer.bit_to_packing[nbits]] @staticmethod def get_unpack_fn(nbits: int): + """ + Returns the unpacking function for the given bit-width. + + Args: + nbits (int): The bit-width. + + Returns: + function: The unpacking function. + """ return Packer.unpack_fn_mapping[Packer.bit_to_packing[nbits]] diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/config.py b/neural_compressor/torch/algorithms/weight_only/hqq/config.py index a0ee29a22d7..40526699ac4 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/config.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/config.py @@ -33,6 +33,12 @@ class HQQGlobalOptions: + """ + Global options for HQQ. + + Attributes: + use_half (bool): Whether to use half precision. + """ use_half = os.getenv("HQQ_NOT_USE_HALF", "0") == "0" @@ -41,6 +47,17 @@ class HQQGlobalOptions: @dataclass class QTensorConfig: + """ + Configuration for quantized tensors. + + Attributes: + nbits (int): Number of bits for quantization. + channel_wise (bool): Whether to use channel-wise quantization. + group_size (int): Size of the quantization group. + optimize (bool): Whether to optimize the quantization. + round_zero (Optional[bool]): Whether to round zero. + pack (bool): Whether to pack the quantized tensor. + """ nbits: int channel_wise: bool = True group_size: int = 128 @@ -67,6 +84,14 @@ class HQQModuleConfig( ["weight", "scale", "zero"], ) ): + """ + Configuration for HQQ modules. + + Attributes: + weight (QTensorConfig): Configuration for weight quantization. + scale (QTensorConfig): Configuration for scale quantization. + zero (QTensorConfig): Configuration for zero quantization. + """ def __new__( cls, weight=default_weight_quant_config, diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/core.py b/neural_compressor/torch/algorithms/weight_only/hqq/core.py index 041e173671d..e2662180cd9 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/core.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/core.py @@ -39,6 +39,26 @@ class HQQTensorHandle: + """ + A class for handling quantized tensors using the HQQ algorithm. + + Attributes: + SUPPORTED_BITS (list): List of supported bit-widths for quantization. + optimize_weights (function): Function for optimizing weights. 
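+
+    Example:
+        Illustrative sketch; the tensor shape and config values are arbitrary.
+
+        >>> cfg = QTensorConfig(nbits=4, group_size=64)  # doctest: +SKIP
+        >>> q_weight = HQQTensorHandle.quantize(torch.randn(128, 128), tensor_quant_config=cfg)  # doctest: +SKIP
+        >>> w_approx = HQQTensorHandle.dequantize(q_weight)  # doctest: +SKIP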
+ + Methods: + quantize(float_tensor, tensor_quant_config=None): + Quantizes a given float tensor. + dequantize(q_weight): + Dequantizes a given quantized tensor. + _create_q_tensor(weight, meta): + Creates a QTensor object from the given weight and meta information. + _quantize(tensor, tensor_quant_config=None): + Internal method for quantizing a tensor. + _dequantize(W_q, meta): + Internal method for dequantizing a quantized tensor. + """ + # Refactored the code from https://github.com/mobiusml/hqq. # Store meta-data (we invert the scale for dequantization) @@ -47,6 +67,16 @@ class HQQTensorHandle: @classmethod def quantize(cls, float_tensor, tensor_quant_config: QTensorConfig = None): + """ + Quantizes a given float tensor. + + Args: + float_tensor (torch.Tensor): The float tensor to be quantized. + tensor_quant_config (QTensorConfig, optional): Configuration for tensor quantization. Defaults to None. + + Returns: + QTensor: The quantized tensor. + """ q_weight, q_tensor_meta = cls._quantize( tensor=float_tensor, tensor_quant_config=tensor_quant_config, @@ -56,6 +86,15 @@ def quantize(cls, float_tensor, tensor_quant_config: QTensorConfig = None): @classmethod def dequantize(cls, q_weight: "QTensor") -> torch.Tensor: + """ + Dequantizes a given quantized tensor. + + Args: + q_weight (QTensor): The quantized tensor to be dequantized. + + Returns: + torch.Tensor: The dequantized float tensor. + """ # Dequantized the Qtensor into float tensor meta = q_weight.meta_info.to_dict() meta["zero"] = q_weight.zero @@ -64,6 +103,16 @@ def dequantize(cls, q_weight: "QTensor") -> torch.Tensor: @classmethod def _create_q_tensor(cls, weight, meta) -> "QTensor": + """ + Creates a QTensor object from the given weight and meta information. + + Args: + weight (torch.Tensor): The quantized weight tensor. + meta (dict): Meta information for the quantized tensor. + + Returns: + QTensor: The created QTensor object. + """ scale = meta["scale"] zero = meta["zero"] meta_info = QTensorMetaInfo( @@ -77,6 +126,16 @@ def _create_q_tensor(cls, weight, meta) -> "QTensor": @classmethod def _quantize(cls, tensor, tensor_quant_config: QTensorConfig = None): + """ + Internal method for quantizing a tensor. + + Args: + tensor (torch.Tensor): The tensor to be quantized. + tensor_quant_config (QTensorConfig, optional): Configuration for tensor quantization. Defaults to None. + + Returns: + Tuple[torch.Tensor, dict]: The quantized tensor and its meta information. + """ nbits = tensor_quant_config.nbits channel_wise = tensor_quant_config.channel_wise group_size = tensor_quant_config.group_size if tensor_quant_config.group_size != -1 else None @@ -160,8 +219,18 @@ def _quantize(cls, tensor, tensor_quant_config: QTensorConfig = None): @classmethod def _dequantize(cls, W_q, meta): + """ + Internal method for dequantizing a quantized tensor. + + Args: + W_q (torch.Tensor): The quantized tensor. + meta (dict): Meta information for the quantized tensor. + + Returns: + torch.Tensor: The dequantized float tensor. + """ # Main dequantization: bit_unpacking > (W_q - z)*s > reshape - if meta["packing"]: + if (meta["packing"]): W_r = Packer.get_unpack_fn(meta["nbits"])(W_q) if hqq_global_option.use_half: W_r = W_r.half() @@ -176,6 +245,29 @@ def _dequantize(cls, W_q, meta): class HQQLinear(torch.nn.Linear): + """ + A class for quantizing linear layers using the HQQ algorithm. + + Attributes: + q_weight (QTensor): The quantized weight tensor. + quantized (bool): Whether the weight has been quantized. 
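+
+    Example:
+        Illustrative usage; the layer sizes and input are arbitrary, and the input
+        dtype should match the dequantized weight (half precision when
+        ``hqq_global_option.use_half`` is enabled).
+
+        >>> float_linear = torch.nn.Linear(64, 64)
+        >>> q_linear = HQQLinear.from_float(float_linear)  # doctest: +SKIP
+        >>> out = q_linear(torch.randn(2, 64).half())  # doctest: +SKIP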
+ + Methods: + quantize_weight(W, quant_config=default_hqq_module_config): + Quantizes the weight tensor. + dequantize_weight(): + Dequantizes the weight tensor. + forward(input): + Performs the forward pass using the quantized weight. + from_float(float_module, quant_config=default_hqq_module_config): + Creates a quantized linear module from a float linear module. + state_dict(*args, **kwargs): + Returns the state dictionary of the module. + _load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + Loads the state dictionary into the module. + _assign_state_dict(state_dict, strict=True, assign=False): + Assigns the state dictionary to the module. + """ def __init__( self, @@ -196,6 +288,16 @@ def quantize_weight( W: torch.Tensor, quant_config: HQQModuleConfig = default_hqq_module_config, ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """ + Quantizes the weight tensor. + + Args: + W (torch.Tensor): The weight tensor to be quantized. + quant_config (HQQModuleConfig, optional): Configuration for module quantization. Defaults to default_hqq_module_config. + + Returns: + Tuple[torch.Tensor, Dict[str, Any]]: The quantized weight tensor and its meta information. + """ weight_quant_config, scale_quant_config, zero_quant_config = ( quant_config.weight, quant_config.scale, @@ -227,6 +329,12 @@ def quantize_weight( self.quantized = True def dequantize_weight(self): + """ + Dequantizes the weight tensor. + + Returns: + torch.Tensor: The dequantized weight tensor. + """ assert self.quantized, "model was not quantized" # TODO: move below logic into `HQQTensorHandle` if self.q_weight.is_scale_quantized(): @@ -241,6 +349,15 @@ def dequantize_weight(self): return W_qdq def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass using the quantized weight. + + Args: + input (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ out = torch.matmul(input, self.dequantize_weight().t()) if self.bias is not None: out += self.bias @@ -252,6 +369,16 @@ def from_float( float_module: torch.nn.Linear, quant_config: HQQModuleConfig = default_hqq_module_config, ): + """ + Creates a quantized linear module from a float linear module. + + Args: + float_module (torch.nn.Linear): The float linear module. + quant_config (HQQModuleConfig, optional): Configuration for module quantization. Defaults to default_hqq_module_config. + + Returns: + HQQLinear: The created quantized linear module. + """ # Create the new module with a toy size to ensure initialization is fast fake_in_features, fake_out_features = 8, 8 new_mod = cls( @@ -280,6 +407,12 @@ def from_float( return new_mod def state_dict(self, *args, **kwargs): # nn.Module override compatible + """ + Returns the state dictionary of the module. + + Returns: + dict: The state dictionary of the module. + """ state_dict = self.q_weight.to_state_dict() if self.bias is not None: state_dict["bias"] = self.bias @@ -298,6 +431,18 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): + """ + Loads the state dictionary into the module. + + Args: + state_dict (dict): The state dictionary to be loaded. + prefix (str): The prefix for the keys in the state dictionary. + local_metadata (dict): Local metadata for the state dictionary. + strict (bool): Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `state_dict` function. + missing_keys (list): List to store missing keys. 
+ unexpected_keys (list): List to store unexpected keys. + error_msgs (list): List to store error messages. + """ all_expected_keys = ["val", "scale_quantized", "zero_quantized", "meta_info"] if self.bias is not None: all_expected_keys.append("bias") @@ -316,6 +461,14 @@ def _load_from_state_dict( self._assign_state_dict(cur_state_dict, strict) def _assign_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + """ + Assigns the state dictionary to the module. + + Args: + state_dict (dict): The state dictionary to be assigned. + strict (bool, optional): Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `state_dict` function. Defaults to True. + assign (bool, optional): Whether to assign the state dictionary. Defaults to False. + """ _scale_quantized = state_dict["scale_quantized"] _zero_quantized = state_dict["zero_quantized"] scale_state = state_dict["meta_info"]["scale"] diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py index e471e6c017a..250ac1561cc 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py @@ -35,6 +35,22 @@ def optimize_weights_proximal_legacy( opt_params={"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20}, verbose=False, ): + """ + Optimize weights using a proximal solver. + + Args: + tensor (torch.Tensor): The input tensor to be optimized. + scale (torch.Tensor): The scale tensor for quantization. + zero (torch.Tensor): The zero tensor for quantization. + min_max (tuple): The minimum and maximum values for quantization. + axis (int, optional): The axis along which to optimize. Defaults to 0. + device (str, optional): The device to use for computation. Defaults to "cuda". + opt_params (dict, optional): Optimization parameters. Defaults to {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20}. + verbose (bool, optional): Whether to print verbose output. Defaults to False. + + Returns: + tuple: The optimized scale and zero tensors. + """ lp_norm, beta, kappa, iters = ( opt_params["lp_norm"], opt_params["beta"], diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py index f1fbd5bce3a..e5f9b910639 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py @@ -25,6 +25,16 @@ @dataclass class QTensorMetaInfo: + """ + Meta information for quantized tensors. + + Attributes: + nbits (int): Number of bits for quantization. + group_size (int): Size of the quantization group. + shape (Tuple): Shape of the tensor. + axis (int): Axis for quantization. + packing (bool): Whether the tensor is packed. + """ nbits: int group_size: int shape: Tuple @@ -32,10 +42,43 @@ class QTensorMetaInfo: packing: bool def to_dict(self): + """ + Converts the QTensorMetaInfo object to a dictionary. + + Returns: + dict: A dictionary representation of the QTensorMetaInfo object. + """ return asdict(self) class QTensor: + """ + A class representing a quantized tensor. + + Attributes: + val (torch.Tensor): The quantized tensor values. + scale (Union[torch.Tensor, "QTensor"], optional): The scale tensor or quantized scale tensor. + zero (Union[torch.Tensor, "QTensor"], optional): The zero tensor or quantized zero tensor. 
+ meta_info (QTensorMetaInfo, optional): Meta information for the quantized tensor. + + Methods: + is_scale_quantized() -> bool: + Checks if the scale is quantized. + is_zero_quantized() -> bool: + Checks if the zero is quantized. + _get_scale_repr() -> str: + Returns a string representation of the scale. + _get_zero_repr() -> str: + Returns a string representation of the zero. + __repr__() -> str: + Returns a string representation of the QTensor object. + to(*args, **kwargs): + Moves the tensor to the specified device and dtype. + half(): + Converts the tensor to half precision. + to_state_dict() -> dict: + Converts the QTensor object to a state dictionary. + """ val: torch.Tensor scale: Union[torch.Tensor, "QTensor"] = None zero: Union[torch.Tensor, "QTensor"] = None @@ -57,12 +100,30 @@ def __init__(self, val, scale=None, zero=None, meta_info=None): self.meta_info = meta_info def is_scale_quantized(self) -> bool: + """ + Checks if the scale is quantized. + + Returns: + bool: True if the scale is quantized, False otherwise. + """ return isinstance(self.scale, QTensor) def is_zero_quantized(self) -> bool: + """ + Checks if the zero is quantized. + + Returns: + bool: True if the zero is quantized, False otherwise. + """ return isinstance(self.zero, QTensor) def _get_scale_repr(self) -> str: + """ + Returns a string representation of the scale. + + Returns: + str: A string representation of the scale. + """ if not self.is_scale_quantized(): if self.scale is not None: return ( @@ -76,6 +137,12 @@ def _get_scale_repr(self) -> str: return self.scale.__repr__() + "\n" def _get_zero_repr(self) -> str: + """ + Returns a string representation of the zero. + + Returns: + str: A string representation of the zero. + """ if not self.is_zero_quantized(): if self.zero is not None: return ( @@ -89,6 +156,12 @@ def _get_zero_repr(self) -> str: return self.zero.__repr__() + "\n" def __repr__(self) -> str: + """ + Returns a string representation of the QTensor object. + + Returns: + str: A string representation of the QTensor object. + """ # TODO: refine it later return ( f"QTensor(\n" @@ -101,12 +174,28 @@ def __repr__(self) -> str: ) def to(self, *args, **kwargs): + """ + Moves the tensor to the specified device and dtype. + + Args: + *args: Positional arguments for the `to` method. + **kwargs: Keyword arguments for the `to` method. + + Returns: + QTensor: The QTensor object moved to the specified device and dtype. + """ self.val = self.val.to(*args, **kwargs) self.scale = self.scale.to(*args, **kwargs) self.zero = self.zero.to(*args, **kwargs) return self def half(self): + """ + Converts the tensor to half precision. + + Returns: + QTensor: The QTensor object in half precision. + """ # TODO: refine it later if self.val.dtype == torch.float32: self.val = self.val.half() @@ -117,6 +206,12 @@ def half(self): return self def to_state_dict(self): + """ + Converts the QTensor object to a state dictionary. + + Returns: + dict: A state dictionary representation of the QTensor object. + """ state = {} state["val"] = self.val state["meta_info"] = self.meta_info.to_dict() diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py index 43b1dda1b4a..ff783100f52 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py @@ -25,6 +25,15 @@ def _has_child(module: torch.nn.Module) -> bool: + """ + Check if a module has any child modules. 
+ + Args: + module (torch.nn.Module): The module to check. + + Returns: + bool: True if the module has child modules, False otherwise. + """ return len(list(module.named_children())) > 0 @@ -35,8 +44,16 @@ def _replace_with_custom_fn_if_matches_filter( cur_fqn: str = "", config_mapping: Optional[ConfigMappingType] = None, ) -> None: - """For each `child` in `model`, replaces it with `replacement_fn(child)` - if `filter_fn(child)` is `True`""" + """ + Replace child modules in a model with a custom function if they match a filter function. + + Args: + model (torch.nn.Module): The model containing the child modules. + replacement_fn (Callable): The function to replace the child modules with. + filter_fn (Callable): The function to filter the child modules. + cur_fqn (str, optional): The current fully qualified name of the module. Defaults to "". + config_mapping (Optional[ConfigMappingType], optional): Configuration mapping for the modules. Defaults to None. + """ name_to_child = dict(model.named_children()) for name, child in name_to_child.items(): if cur_fqn == "": @@ -64,44 +81,101 @@ def _replace_with_custom_fn_if_matches_filter( def patch_hqq_moduile(mod, config): + """ + Patch a module with HQQ configuration. + + Args: + mod (torch.nn.Module): The module to be patched. + config (HQQModuleConfig): The HQQ configuration. + + Returns: + torch.nn.Module: The patched module. + """ new_mod = HQQLinear.from_float(mod, config) return new_mod def filter_fn(mod: torch.nn.Module, name: str, config_mapping: ConfigMappingType) -> bool: + """ + Filter function to check if a module should be replaced. + + Args: + mod (torch.nn.Module): The module to check. + name (str): The name of the module. + config_mapping (ConfigMappingType): Configuration mapping for the modules. + + Returns: + bool: True if the module should be replaced, False otherwise. + """ return isinstance(mod, torch.nn.Linear) and name in config_mapping def replacement_fn(mod: torch.nn.Module, name: str, config_mapping: ConfigMappingType) -> torch.nn.Module: + """ + Replacement function to replace a module with a patched module. + + Args: + mod (torch.nn.Module): The module to be replaced. + name (str): The name of the module. + config_mapping (ConfigMappingType): Configuration mapping for the modules. + + Returns: + torch.nn.Module: The replaced module. + """ config = config_mapping.get(name, None) logger.debug("Replace module %s", name) return patch_hqq_moduile(mod, config) class HQQuantizer(Quantizer): + """ + A class for quantizing models using the HQQ algorithm. + + Attributes: + quant_config (ConfigMappingType): Configuration for quantization. + + Methods: + prepare(model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn.Module]: + Prepares a given model for quantization. + convert(model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn.Module]: + Converts a prepared model to a quantized model. + save(model, path): + Saves the quantized model to the specified path. + _convert_hqq_module_config(config) -> HQQModuleConfig: + Converts a configuration to HQQModuleConfig. + _parse_hqq_configs_mapping(configs_mapping): + Parses HQQ configuration mappings. + """ + def __init__(self, quant_config: ConfigMappingType) -> None: - """Init a HQQuantizer object. + """ + Init a HQQuantizer object. Args: - quant_config (ConfigMappingType): quantization config for ops. + quant_config (ConfigMappingType): Configuration for quantization. 
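+
+        Example:
+            Illustrative sketch; ``config_mapping`` and ``model`` are hypothetical
+            placeholders, where ``config_mapping`` maps ``(op_name, op_type)`` pairs
+            to per-op weight-only quantization configs.
+
+            >>> quantizer = HQQuantizer(quant_config=config_mapping)  # doctest: +SKIP
+            >>> q_model = quantizer.convert(model)  # doctest: +SKIP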
""" quant_config = self._parse_hqq_configs_mapping(quant_config) super().__init__(quant_config=quant_config) @torch.no_grad() def prepare(self, model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn.Module]: - """Prepares a given model for quantization. + """ + Prepares a given model for quantization. Will return model directly in HQQ algorithm. Args: model (torch.nn.Module): The model to be prepared. + + Returns: + Optional[torch.nn.Module]: The prepared model. """ return model @torch.no_grad() def convert(self, model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn.Module]: - """Converts a prepared model to a quantized model. + """ + Converts a prepared model to a quantized model. Args: model (torch.nn.Module): The prepared model to be converted. @@ -115,10 +189,26 @@ def convert(self, model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn. return model def save(self, model, path): + """ + Saves the quantized model to the specified path. + + Args: + model (torch.nn.Module): The quantized model to be saved. + path (str): The path to save the model. + """ # TODO: to implement it in the next PR pass def _convert_hqq_module_config(self, config) -> HQQModuleConfig: + """ + Converts a configuration to HQQModuleConfig. + + Args: + config: The configuration to be converted. + + Returns: + HQQModuleConfig: The converted HQQModuleConfig. + """ # TODO: (Yi) Please note that the configuration defined by INC should be separated from the algorithm. # * 3.x API use `bits` for woq while HQQ internal API use `nbits`, we should change it in algorithm_entry.py nbits = config.bits @@ -145,6 +235,15 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig: return hqq_module_config def _parse_hqq_configs_mapping(self, configs_mapping): + """ + Parses HQQ configuration mappings. + + Args: + configs_mapping: The configuration mappings to be parsed. + + Returns: + dict: The parsed configuration mappings. + """ qconfig_mapping = {} for (op_name, op_type), quant_config in configs_mapping.items(): if quant_config is not None and quant_config.dtype == "fp32": diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/utility.py b/neural_compressor/torch/algorithms/weight_only/hqq/utility.py index 9c9b3700cf6..6d3fe833f2a 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/utility.py @@ -29,6 +29,16 @@ def is_divisible(val1, val2): + """ + Check if val1 is divisible by val2. + + Args: + val1 (int): The dividend. + val2 (int): The divisor. + + Returns: + bool: True if val1 is divisible by val2, False otherwise. + """ return int(val2 * np.ceil(val1 / val2)) == val1 @@ -54,10 +64,14 @@ def see_cuda_memory_usage(message, force=False): # pragma: no cover def dump_elapsed_time(customized_msg=""): - """Get the elapsed time for decorated functions. + """ + Decorator to measure and log the elapsed time for a function. Args: - customized_msg (string, optional): The parameter passed to decorator. Defaults to None. + customized_msg (str, optional): Custom message to include in the log. Defaults to "". + + Returns: + function: The decorated function with elapsed time logging. """ def f(func):