diff --git a/neural_compressor/torch/algorithms/weight_only/__init__.py b/neural_compressor/torch/algorithms/weight_only/__init__.py index b5a6e44ab54..c23bde7de64 100644 --- a/neural_compressor/torch/algorithms/weight_only/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/__init__.py @@ -15,5 +15,6 @@ from .rtn import rtn_quantize from .gptq import gptq_quantize from .awq import awq_quantize +from .hqq import hqq_quantize from .modules import WeightOnlyLinear from .utility import * diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py b/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py new file mode 100644 index 00000000000..7acb79401a2 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantizer import HQQuantizer +from .config import HQQModuleConfig, QTensorConfig +from .quant_api import hqq_quantize diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/auto_accelerator.py b/neural_compressor/torch/algorithms/weight_only/hqq/auto_accelerator.py new file mode 100644 index 00000000000..0026f76936e --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/auto_accelerator.py @@ -0,0 +1,223 @@ +# Copyright (c) 2023-2024 Microsoft Corporation and Intel Corporation + +# This code is based on Microsoft Corporation's DeepSpeed library and +# the accelerators implementation in this library. It has been modified +# from its original forms to simplify and adapt it for use in +# the Intel® Neural Compressor. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTICE: The design adapted from: +# https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py. +# TODO: move it into torch/utils + + +# To keep it simply, only add the APIs we need. + +import os +from abc import ABC, abstractmethod +from typing import Any, Callable, List + +import torch + +from neural_compressor.torch.utils import logger + +PRIORITY_CUDA = 100 +PRIORITY_CPU = 90 + + +class AcceleratorRegistry: + registered_accelerators = {} + + @classmethod + def register_accelerator_impl(cls, name: str, priority: float = 0): + """Register new accelerator implementation. + + Usage example: + @AcceleratorRegistry.register_accelerator(name="cpu", priority=100) + class CPU_Accelerator: + ... + + Args: + name: the accelerator name. + priority: priority: the priority of the accelerator. 
A larger number indicates a higher priority, + """ + + def decorator(accelerator_cls): + cls.registered_accelerators.setdefault(name, {}) + cls.registered_accelerators[name] = (accelerator_cls, priority) + return accelerator_cls + + return decorator + + @classmethod + def get_sorted_accelerators(cls) -> List["Auto_Accelerator"]: + """Get registered accelerators sorted by priority.""" + accelerator_pairs = cls.registered_accelerators.values() + sorted_accelerators_pairs = sorted(accelerator_pairs, key=lambda x: x[1], reverse=True) + sorted_accelerators = [pair[0] for pair in sorted_accelerators_pairs] + return sorted_accelerators + + @classmethod + def get_accelerator_cls_by_name(cls, name: str) -> "Auto_Accelerator": + """Get accelerator by name.""" + accelerator_cls, _ = cls.registered_accelerators.get(name, (None, None)) + return accelerator_cls + + +accelerator_registry = AcceleratorRegistry() + + +def register_accelerator(name: str, priority: float = 0) -> Callable[..., Any]: + """Register new accelerator. + + Usage example: + @register_accelerator(name="cuda", priority=100) + class CUDA_Accelerator: + ... + + Args: + name: the accelerator name. + priority: the priority of the accelerator. A larger number indicates a higher priority, + """ + + return accelerator_registry.register_accelerator_impl(name=name, priority=priority) + + +class Auto_Accelerator(ABC): + @classmethod + @abstractmethod + def is_available(cls) -> bool: + pass + + @abstractmethod + def name(self) -> str: + pass + + @abstractmethod + def device_name(self, device_indx) -> str: + pass + + @abstractmethod + def set_device(self, device_index): + pass + + @abstractmethod + def current_device(self): + pass + + @abstractmethod + def current_device_name(self): + pass + + @abstractmethod + def device(self, device_index=None): + pass + + @abstractmethod + def empty_cache(self): + pass + + @abstractmethod + def synchronize(self): + pass + + +@register_accelerator(name="cpu", priority=PRIORITY_CPU) +class CPU_Accelerator(Auto_Accelerator): + def __init__(self) -> None: + self._name = "cpu" + + def name(self) -> str: + return self._name + + @classmethod + def is_available(cls) -> bool: + return True + + def device_name(self, device_indx) -> str: + return "cpu" + + def set_device(self, device_index): + pass + + def current_device(self): + return "cpu" + + def current_device_name(self): + return "cpu" + + def device(self, device_index=None): + pass + + def empty_cache(self): + pass + + def synchronize(self): + pass + + +@register_accelerator(name="cuda", priority=PRIORITY_CUDA) +class CUDA_Accelerator(Auto_Accelerator): + def __init__(self) -> None: + self._name = "cuda" + + def name(self) -> str: + return self._name + + @classmethod + def is_available(cls) -> bool: + return torch.cuda.is_available() + + def device_name(self, device_indx) -> str: + if device_indx is None: + return "cuda" + return f"cuda:{device_indx}" + + def synchronize(self): + return torch.cuda.synchronize() + + def set_device(self, device_index): + return torch.cuda.set_device(device_index) + + def current_device(self): + return torch.cuda.current_device() + + def current_device_name(self): + return "cuda:{}".format(torch.cuda.current_device()) + + def device(self, device_index=None): + return torch.cuda.device(device_index) + + def empty_cache(self): + return torch.cuda.empty_cache() + + +def auto_detect_accelerator() -> Auto_Accelerator: + # if runtime_accelerator.accelerator: + # return runtime_accelerator.accelerator + FORCE_DEVICE = 
os.environ.get("FORCE_DEVICE", None) + if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None: + logger.warning("Force use %s accelerator.", FORCE_DEVICE) + return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)() + for accelerator_cls in accelerator_registry.get_sorted_accelerators(): + if accelerator_cls.is_available(): + logger.debug("Auto detect accelerator: %s.", accelerator_cls.__name__) + accelerator = accelerator_cls() + return accelerator + + +# Force use cpu accelerator even if cuda is available. +# FORCE_DEVICE = "cpu" python ... +# or +# CUDA_VISIBLE_DEVICES="" python ... diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py b/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py new file mode 100644 index 00000000000..5500201a4ee --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py @@ -0,0 +1,144 @@ +# Copyright (c) 2023-2024 Mobiusml and Intel Corporation + +# This code is based on Mobiusml's HQQ library. It has been modified +# from its original forms to simplify and adapt it for use in +# the Intel® Neural Compressor. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Notice: Copied from from https://github.com/mobiusml/hqq +# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023 +##################################################### + +import numpy as np +import torch + +from .utility import is_divisible + +__all__ = ["Packer"] + + +# Bit packing logic. 
format: pack/unpack_nBits_target- +class BitPack: + # 8-bit + ################################################ + @staticmethod + def pack_8bit_u8(W_q): + return W_q.to(torch.uint8) + + @staticmethod + def unpack_8bit_u8(W_q): + return W_q + + # 4-bit + ################################################ + @staticmethod + def pack_4bit_u8(W_q): # uint8 > uint8/2 + W_q = W_q.to(torch.uint8) + _step = int(len(W_q) / 2) + return (W_q[:_step] << 4) | W_q[_step:] + + # A bit faster than the _cat version + @staticmethod + def unpack_4bit_u8(W_q): # uint8/2 > uint8 + _step = W_q.shape[0] + tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) + tmp[:_step] = (W_q & 0b11110000) >> 4 + tmp[_step:] = W_q & 0b00001111 + return tmp + + # 2-bit + ################################################ + @staticmethod + def pack_2bit_u8(W_q): # uint8 > uint8/4 + W_q = W_q.to(torch.uint8) + _step = int(len(W_q) / 4) + return W_q[:_step] << 6 | W_q[_step : 2 * _step] << 4 | W_q[2 * _step : 3 * _step] << 2 | W_q[3 * _step :] + + # A bit faster than the _cat version + @staticmethod + def unpack_2bit_u8(W_q): + _step = W_q.shape[0] + tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) + tmp[:_step] = (W_q & 0b11000000) >> 6 + tmp[_step : 2 * _step] = (W_q & 0b00110000) >> 4 + tmp[2 * _step : 3 * _step] = (W_q & 0b00001100) >> 2 + tmp[3 * _step :] = W_q & 0b00000011 + return tmp + + # 3bit + ################################################ + @staticmethod + def pack_3bit_32(W_q_in): + W_q = torch.zeros( + [int(10 * np.ceil(W_q_in.shape[0] / 10.0)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32 + ) + W_q[: len(W_q_in)] = W_q_in + _step = int(len(W_q) / 10) + W_q = ( + (W_q[:_step] << 27) + | (W_q[_step : _step * 2] << 24) + | (W_q[_step * 2 : _step * 3] << 21) + | (W_q[_step * 3 : _step * 4] << 18) + | (W_q[_step * 4 : _step * 5] << 15) + | (W_q[_step * 5 : _step * 6] << 12) + | (W_q[_step * 6 : _step * 7] << 9) + | (W_q[7 * _step : _step * 8] << 6) + | (W_q[_step * 8 : _step * 9] << 3) + | (W_q[_step * 9 :]) + ) + return W_q + + # A bit faster than _cat version + @staticmethod + def unpack_3bit_32(W_q): + _step = W_q.shape[0] + tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) + tmp[:_step] = (W_q & 0b00111000000000000000000000000000) >> 27 + tmp[1 * _step : 2 * _step] = (W_q & 0b00000111000000000000000000000000) >> 24 + tmp[2 * _step : 3 * _step] = (W_q & 0b00000000111000000000000000000000) >> 21 + tmp[3 * _step : 4 * _step] = (W_q & 0b00000000000111000000000000000000) >> 18 + tmp[4 * _step : 5 * _step] = (W_q & 0b00000000000000111000000000000000) >> 15 + tmp[5 * _step : 6 * _step] = (W_q & 0b00000000000000000111000000000000) >> 12 + tmp[6 * _step : 7 * _step] = (W_q & 0b00000000000000000000111000000000) >> 9 + tmp[7 * _step : 8 * _step] = (W_q & 0b00000000000000000000000111000000) >> 6 + tmp[8 * _step : 9 * _step] = (W_q & 0b00000000000000000000000000111000) >> 3 + tmp[9 * _step :] = W_q & 0b00000000000000000000000000000111 + return tmp + + +class Packer: + # TODO: Refine the packer + bit_to_packing = {8: "8bit_u8", 4: "4bit_u8", 3: "3bit_32", 2: "2bit_u8"} + + pack_fn_mapping = { + "8bit_u8": BitPack.pack_8bit_u8, + "4bit_u8": BitPack.pack_4bit_u8, + "3bit_32": BitPack.pack_3bit_32, + "2bit_u8": BitPack.pack_2bit_u8, + } + + unpack_fn_mapping = { + "8bit_u8": BitPack.unpack_8bit_u8, + "4bit_u8": BitPack.unpack_4bit_u8, + "3bit_32": BitPack.unpack_3bit_32, + "2bit_u8": BitPack.unpack_2bit_u8, + } + + @staticmethod 
+ def get_pack_fn(nbits: int): + return Packer.pack_fn_mapping[Packer.bit_to_packing[nbits]] + + @staticmethod + def get_unpack_fn(nbits: int): + return Packer.unpack_fn_mapping[Packer.bit_to_packing[nbits]] diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/config.py b/neural_compressor/torch/algorithms/weight_only/hqq/config.py new file mode 100644 index 00000000000..a0ee29a22d7 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/config.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from collections import namedtuple +from dataclasses import dataclass +from typing import Dict, Optional + +from typing_extensions import TypeAlias + +__all__ = [ + "ConfigMappingType", + "HQQModuleConfig", + "QTensorConfig", + "hqq_global_option", + "default_hqq_module_config", + "default_weight_quant_config", + "default_scale_quant_config", + "default_zero_quant_config", +] + + +class HQQGlobalOptions: + use_half = os.getenv("HQQ_NOT_USE_HALF", "0") == "0" + + +hqq_global_option = HQQGlobalOptions() + + +@dataclass +class QTensorConfig: + nbits: int + channel_wise: bool = True + group_size: int = 128 + optimize: bool = True + round_zero: Optional[bool] = False + pack: bool = True + + def __repr__(self) -> str: + return ( + f"QTensorConfig(nbits={self.nbits}, channel_wise={self.channel_wise}, " + f"group_size={self.group_size}, optimize={self.optimize}, " + f"round_zero={self.round_zero}, pack={self.pack})" + ) + + +default_weight_quant_config = QTensorConfig(nbits=4, channel_wise=True, group_size=128, optimize=True, round_zero=True) +default_scale_quant_config = QTensorConfig(nbits=8, channel_wise=True, group_size=64, optimize=False, round_zero=None) +default_zero_quant_config = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False, round_zero=None) + + +class HQQModuleConfig( + namedtuple( + "HQQModuleConfig", + ["weight", "scale", "zero"], + ) +): + def __new__( + cls, + weight=default_weight_quant_config, + scale=default_scale_quant_config, + zero=default_zero_quant_config, + ): + return super().__new__(cls, weight, scale, zero) + + def __repr__(self) -> str: + return ( + f"HQQModuleConfig(\n" f" weight={self.weight},\n" f" scale={self.scale},\n" f" zero={self.zero}\n)" + ) + + +default_hqq_module_config = HQQModuleConfig( + weight=default_weight_quant_config, + scale=default_scale_quant_config, + zero=default_zero_quant_config, +) + + +ConfigMappingType: TypeAlias = Dict[str, HQQModuleConfig] diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/core.py b/neural_compressor/torch/algorithms/weight_only/hqq/core.py new file mode 100644 index 00000000000..21d9eb0027a --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/core.py @@ -0,0 +1,280 @@ +# Copyright (c) 2023-2024 Mobiusml and Intel Corporation + +# This code is based on Mobiusml's HQQ library. 
It has been modified +# from its original forms to simplify and adapt it for use in +# the Intel® Neural Compressor. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTICE: the original `Quantizer` has been modified to `HQQTensorHandle` +# and `QTensor` to decouple the data structure and the quantization logic. + +from typing import Any, Dict, Tuple + +import torch + +from neural_compressor.torch.utils import logger + +from .auto_accelerator import auto_detect_accelerator +from .bitpack import Packer +from .config import HQQModuleConfig, QTensorConfig, default_hqq_module_config, hqq_global_option +from .optimizer import optimize_weights_proximal +from .qtensor import QTensor, QTensorMetaInfo +from .utility import dump_elapsed_time, is_divisible + +__all__ = [ + "HQQTensorHandle", + "HQQLinear", +] + + +class HQQTensorHandle: + # Refactored the code from https://github.com/mobiusml/hqq. + + # Store meta-data (we invert the scale for dequantization) + SUPPORTED_BITS = [8, 4, 3, 2] + optimize_weights = optimize_weights_proximal + + @classmethod + def quantize(cls, float_tensor, tensor_quant_config: QTensorConfig = None): + q_weight, q_tensor_meta = cls._quantize( + tensor=float_tensor, + tensor_quant_config=tensor_quant_config, + ) + q_weight = cls._create_q_tensor(q_weight, q_tensor_meta) + return q_weight + + @classmethod + def dequantize(cls, q_weight: "QTensor") -> torch.Tensor: + # Dequantized the Qtensor into float tensor + meta = q_weight.meta_info.to_dict() + meta["zero"] = q_weight.zero + meta["scale"] = q_weight.scale + return cls._dequantize(q_weight.val, meta) + + @classmethod + def _create_q_tensor(cls, weight, meta) -> "QTensor": + scale = meta["scale"] + zero = meta["zero"] + meta_info = QTensorMetaInfo( + nbits=meta["nbits"], + group_size=meta["group_size"], + shape=meta["shape"], + axis=meta["axis"], + packing=meta["packing"], + ) + return QTensor(weight, scale, zero, meta_info) + + @classmethod + def _quantize(cls, tensor, tensor_quant_config: QTensorConfig = None): + nbits = tensor_quant_config.nbits + channel_wise = tensor_quant_config.channel_wise + group_size = tensor_quant_config.group_size + optimize = tensor_quant_config.optimize + round_zero = tensor_quant_config.round_zero + axis = 0 # *Note did not exposed to the user + bitpack = tensor_quant_config.pack + + assert nbits in cls.SUPPORTED_BITS, "nbits=" + str(nbits) + " not supported." + assert axis in [0, 1], "axis should be either 0 or 1, but got {}".format(axis) + if group_size is not None: + assert is_divisible(tensor.numel(), group_size), ( + "group_size should be divisible by the total tensor dimensions. 
shape: " + + str(tensor.shape) + + ", group_size: " + + str(group_size) + ) + + W = tensor.float() + shape = W.shape + + # Reshape for grouping + if (group_size is not None) and channel_wise: + W = W.reshape([-1, group_size]) if (axis == 1) else W.reshape([group_size, -1]) + + # Get min/max values + if not channel_wise: + _min, _max = W.min(), W.max() + optimize = False + else: + _min = W.min(axis=axis, keepdim=True)[0] + _max = W.max(axis=axis, keepdim=True)[0] + + max_v = 2**nbits - 1 + min_v = 0 + min_max = [min_v, max_v] + + # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, + # the scale is inverted later on. + scale = (max_v / (_max - _min)).clamp(max=2e4) # clamp to avoid half-precision problems + zero = -_min * scale + + # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14 + if round_zero: + zero = torch.round(zero) + + # Fine-tune weights + if optimize: + scale, zero = cls.optimize_weights(tensor=W, scale=scale, zero=zero, min_max=min_max, axis=axis) + + # Quantize + scale, zero = ( + scale.clone(), + zero.clone(), + ) # Necessary for fake quantization backprop + W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1]) + + # Store meta-data (we invert the scale for dequantization) + meta = { + "nbits": nbits, + "group_size": group_size, + "shape": shape, + "scale": 1.0 / scale, + "zero": zero, + "axis": axis, + "packing": bitpack, + } + + # Pack bits + if bitpack: + W_q = Packer.get_pack_fn(meta["nbits"])(W_q) + else: + W_q = W_q.to(tensor.dtype) + meta["packing"] = None + + # cleanup + del W, _min, _max + auto_detect_accelerator().empty_cache() + + return W_q, meta + + @classmethod + def _dequantize(cls, W_q, meta): + # Main dequantization: bit_unpacking > (W_q - z)*s > reshape + if meta["packing"]: + W_r = Packer.get_unpack_fn(meta["nbits"])(W_q) + if hqq_global_option.use_half: + W_r = W_r.half() + if (meta["group_size"] is not None) and (meta["nbits"] == 3): + W_r = W_r[: meta["group_size"]] if (meta["axis"] == 0) else W_r[:, : meta["group_size"]] + else: + if hqq_global_option.use_half: + W_r = W_q.half() + # TODO: double check the correctness, the official impl is also error... + W_r = ((W_r - meta["zero"]) * meta["scale"]).reshape(meta["shape"]) + return W_r + + +class HQQLinear(torch.nn.Linear): + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + q_weight: QTensor = None, + device=None, + dtype=None, + ) -> None: + super().__init__(in_features, out_features, bias, device, dtype) + self.q_weight = q_weight + self.quantized = q_weight is not None + + @dump_elapsed_time("Quantize linear module into HQQ module.") + def quantize_weight( + self, + W: torch.Tensor, + quant_config: HQQModuleConfig = default_hqq_module_config, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + weight_quant_config, scale_quant_config, zero_quant_config = ( + quant_config.weight, + quant_config.scale, + quant_config.zero, + ) + need_quant_scale = scale_quant_config is not None + need_quant_zero = zero_quant_config is not None + + self.in_features, self.out_features = W.t().shape + + # Quantize weight + q_weight = HQQTensorHandle.quantize(float_tensor=W, tensor_quant_config=weight_quant_config) + self.q_weight = q_weight + + # * The dequantization process only happens in the first forward pass. + # * It will change the `q_weight` but faster. + # * we should not save the state after doing the forward. 
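+        # * Concretely: on that first forward, `dequantize_weight` replaces the quantized
+        # * scale/zero stored in `q_weight` with their dequantized tensors, so later
+        # * forwards skip the scale/zero dequantization step.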
+ if need_quant_scale: # Quantize scale + q_scale_tensor = HQQTensorHandle.quantize( + float_tensor=self.q_weight.scale, tensor_quant_config=scale_quant_config + ) + self.q_weight.scale = q_scale_tensor + if need_quant_zero: # Quantize zero + q_zero_tensor = HQQTensorHandle.quantize( + float_tensor=self.q_weight.zero, + tensor_quant_config=zero_quant_config, + ) + self.q_weight.zero = q_zero_tensor + self.quantized = True + + def dequantize_weight(self): + assert self.quantized, "model was not quantized" + # TODO: move below logic into `HQQTensorHandle` + if self.q_weight.is_scale_quantized(): + scale_qdq = HQQTensorHandle.dequantize(self.q_weight.scale) + self.q_weight.scale = scale_qdq + + if self.q_weight.is_zero_quantized(): + zero_qdq = HQQTensorHandle.dequantize(self.q_weight.zero) + self.q_weight.zero = zero_qdq + + W_qdq = HQQTensorHandle.dequantize(self.q_weight) + return W_qdq + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = torch.matmul(input, self.dequantize_weight().t()) + if self.bias is not None: + out += self.bias + return out + + @classmethod + def from_float( + cls, + float_module: torch.nn.Linear, + quant_config: HQQModuleConfig = default_hqq_module_config, + ): + # Create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, + fake_out_features, + bias=float_module.bias is not None, + ) + new_mod.requires_grad_ = False + # Construct the q weight frpm float weight + new_mod.quantize_weight(float_module.weight, quant_config=quant_config) + # Update the linear module attributes + new_mod.in_features = float_module.in_features + new_mod.out_features = float_module.out_features + new_mod.weight = None + new_mod.bias = float_module.bias + if hqq_global_option.use_half and new_mod.bias is not None: + new_mod.bias = torch.nn.Parameter(float_module.bias.half()) + # TODO: refine it to support cuda/hpu/cpu + device_to_use = next(float_module.parameters()).device + if hqq_global_option.use_half: + new_mod.q_weight = new_mod.q_weight.half() + new_mod.to(device_to_use) + new_mod.q_weight.to(device_to_use) + # !!! Delete the float explicitly to save memory + del float_module + return new_mod diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py new file mode 100644 index 00000000000..ba03a4844c7 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023-2024 Mobiusml and Intel Corporation +# +# This code is based on Mobiusml's HQQ library. +# https://github.com/mobiusml/hqq +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
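+
+# This module provides the proximal solver that HQQ uses to refine the quantization
+# zero-point by minimizing || W - dequantize(quantize(W)) ||_p^p; see
+# `optimize_weights_proximal_legacy` below.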
+ + +import numpy as np +import torch + +from neural_compressor.torch.utils import logger + +from .auto_accelerator import auto_detect_accelerator + + +# Proximal solver || W - dequantize(quantize(W))||_p^p +@torch.inference_mode() +def optimize_weights_proximal_legacy( + tensor, + scale, + zero, + min_max, + axis=0, + device="cuda", + opt_params={"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20}, + verbose=False, +): + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + device = auto_detect_accelerator().current_device() + + # TODO: refine it for cpu device + if auto_detect_accelerator().name() == "cuda": + dtype = torch.float16 + else: + dtype = torch.float32 + W_f = tensor.to(dtype).to(device) + scale = scale.to(dtype).to(device) + zero = zero.to(dtype).to(device) + + if lp_norm == 1: + shrink_op = lambda x, beta: torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + else: + shrink_op = lambda x, beta, p=lp_norm: torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), p - 1) + ) + + best_error = 1e4 + for i in range(iters): + W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1]) + W_r = (W_q - zero) / scale + W_e = shrink_op(W_f - W_r, beta) + zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(W_f - W_r).mean()) + if verbose: + logger.info(i, np.round(current_error, 6)) + if current_error < best_error: + best_error = current_error + else: + break + + scale = scale.to(tensor.device) + zero = zero.to(tensor.device) + del W_f, W_q, W_r, W_e + auto_detect_accelerator().empty_cache() + + return scale, zero + + +optimize_weights_proximal = optimize_weights_proximal_legacy diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py new file mode 100644 index 00000000000..950fa2243dc --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
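+
+# `QTensor` bundles a quantized weight with its scale, zero-point and
+# `QTensorMetaInfo` (nbits, group_size, shape, axis, packing). The scale and
+# zero may themselves be `QTensor` instances when they are quantized as well.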
+ +from dataclasses import asdict, dataclass +from typing import Tuple, Union + +import torch + +__all__ = [ + "QTensor", + "QTensorMetaInfo", +] + + +@dataclass +class QTensorMetaInfo: + nbits: int + group_size: int + shape: Tuple + axis: int + packing: bool + + def to_dict(self): + return asdict(self) + + +class QTensor: + val: torch.Tensor + scale: Union[torch.Tensor, "QTensor"] = None + zero: Union[torch.Tensor, "QTensor"] = None + meta_info: QTensorMetaInfo = None + """ + val: torch.Tensor + scale: + val: torch.Tensor + scale: torch.Tensor + zero: torch.Tensor + zero: + torch.Tensor + """ + + def __init__(self, val, scale=None, zero=None, meta_info=None): + self.val = val + self.scale = scale + self.zero = zero + self.meta_info = meta_info + + def is_scale_quantized(self) -> bool: + return isinstance(self.scale, QTensor) + + def is_zero_quantized(self) -> bool: + return isinstance(self.zero, QTensor) + + def _get_scale_repr(self) -> str: + if not self.is_scale_quantized(): + if self.scale is not None: + return ( + f"scale_shape={self.scale.shape}, " + f"scale_dtype={self.scale.dtype}, " + f"scale_device={self.scale.device}\n" + ) + else: + return "scale is None\n" + else: + return self.scale.__repr__() + "\n" + + def _get_zero_repr(self) -> str: + if not self.is_zero_quantized(): + if self.zero is not None: + return ( + f"zero_shape={self.zero.shape}, " + f"zero_dtype={self.zero.dtype}, " + f"zero_device={self.zero.device}\n" + ) + else: + return "zero is None\n" + else: + return self.zero.__repr__() + "\n" + + def __repr__(self) -> str: + # TODO: refine it later + return ( + f"QTensor(\n" + f"val_shape={self.val.shape}, val_dtype={self.val.dtype}, val_device={self.val.device}\n" + f"scale_quantized={self.is_scale_quantized()},\n" + f"zero_quantized={self.is_zero_quantized()},\n" + f"zero=({self._get_zero_repr()})" + f"scale=({self._get_scale_repr()})" + f"meta_info={self.meta_info}\n)" + ) + + def to(self, *args, **kwargs): + self.val = self.val.to(*args, **kwargs) + self.scale = self.scale.to(*args, **kwargs) + self.zero = self.zero.to(*args, **kwargs) + return self + + def half(self): + # TODO: refine it later + if self.val.dtype == torch.float32: + self.val = self.val.half() + if self.scale is not None: + self.scale = self.scale.half() + if self.zero is not None: + self.zero = self.zero.half() + return self diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/quant_api.py b/neural_compressor/torch/algorithms/weight_only/hqq/quant_api.py new file mode 100644 index 00000000000..de8062926d7 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/quant_api.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
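+
+# This module maps the 3.x weight-only settings (`bits`, `group_size`, `quant_zero`,
+# `quant_scale`, `scale_quant_group_size`, `skip_lm_head`) onto HQQ's internal
+# `HQQModuleConfig`/`QTensorConfig` and runs `HQQuantizer` over the model.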
+"""HQQ quantization APIs.""" + +import torch + +from neural_compressor.torch.utils import logger + +from .config import HQQModuleConfig, QTensorConfig +from .quantizer import HQQuantizer + +__all__ = ["hqq_quantize"] + + +def _convert_hqq_module_config(config) -> HQQModuleConfig: + # * 3.x API use `bits` for woq while HQQ internal API use `nbits` + nbits = config.bits + group_size = config.group_size + quant_zero = config.quant_zero + quant_scale = config.quant_scale + scale_quant_group_size = config.scale_quant_group_size + + weight_qconfig = QTensorConfig( + nbits=nbits, channel_wise=True, group_size=group_size, optimize=True, round_zero=True if nbits == 4 else False + ) + zero_qconfig = None + if quant_zero: + zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False) + scale_qconfig = None + if quant_scale: + scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False) + hqq_module_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig) + logger.debug(hqq_module_config) + return hqq_module_config + + +def _parse_hqq_configs_mapping(configs_mapping): + qconfig_mapping = {} + for (op_name, op_type), quant_config in configs_mapping.items(): + if quant_config.skip_lm_head and "lm_head" in op_name: + logger.warning("Skip quantizing %s due to `skip_lm_head` is True.", op_name) + continue + qconfig_mapping[op_name] = _convert_hqq_module_config(quant_config) + return qconfig_mapping + + +@torch.no_grad() +def hqq_quantize(model: torch.nn.Module, configs_mapping, *args, **kwargs) -> torch.nn.Module: + qconfig_mapping = _parse_hqq_configs_mapping(configs_mapping) + hqq_quantizer = HQQuantizer(qconfig_mapping) + q_model = hqq_quantizer.prepare(model) + return q_model diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py new file mode 100644 index 00000000000..b0d9b6a37e4 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
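+
+# `HQQuantizer` walks the module tree and replaces each `torch.nn.Linear` listed in
+# the config mapping with an `HQQLinear` created by `HQQLinear.from_float`.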
+ +from typing import Callable, List, Optional, Tuple + +import torch + +from neural_compressor.torch.utils import logger + +from .auto_accelerator import auto_detect_accelerator +from .config import ConfigMappingType, default_hqq_module_config, hqq_global_option +from .core import HQQLinear + + +def _has_child(module: torch.nn.Module) -> bool: + return len(list(module.named_children())) > 0 + + +def _replace_with_custom_fn_if_matches_filter( + model: torch.nn.Module, + replacement_fn: Callable, + filter_fn: Callable, + cur_fqn: str = "", + config_mapping: Optional[ConfigMappingType] = None, +) -> None: + """For each `child` in `model`, replaces it with `replacement_fn(child)` + if `filter_fn(child)` is `True`""" + name_to_child = dict(model.named_children()) + for name, child in name_to_child.items(): + if cur_fqn == "": + new_fqn = name + else: + new_fqn = f"{cur_fqn}.{name}" + if filter_fn(child, new_fqn, config_mapping): + new_child = replacement_fn(child.to(auto_detect_accelerator().current_device()), new_fqn, config_mapping) + logger.debug("Quantize linear module %s.", new_fqn) + setattr(model, name, new_child) + elif not _has_child(child): # TODO: merge it into `filter_fn` + if hqq_global_option.use_half: + logger.debug("Half module %s.", new_fqn) + child = child.half() + new_child = child.to(auto_detect_accelerator().current_device()) + setattr(model, name, new_child) + else: + _replace_with_custom_fn_if_matches_filter( + model=child, + replacement_fn=replacement_fn, + filter_fn=filter_fn, + cur_fqn=new_fqn, + config_mapping=config_mapping, + ) + + +def patch_hqq_moduile(mod, config): + new_mod = HQQLinear.from_float(mod, config) + return new_mod + + +def filter_fn(mod: torch.nn.Module, name: str, config_mapping: ConfigMappingType) -> bool: + return isinstance(mod, torch.nn.Linear) and name in config_mapping + + +def replacement_fn(mod: torch.nn.Module, name: str, config_mapping: ConfigMappingType) -> torch.nn.Module: + config = config_mapping.get(name, None) + logger.debug("Replace module %s", name) + return patch_hqq_moduile(mod, config) + + +class EagerModeQuantizer: + def __init__(self, config_mapping) -> None: + self.config_mapping = config_mapping + + def prepare(self, model: torch.nn.Module, inplace=True) -> Optional[torch.nn.Module]: + pass + + def convert(self, model: torch.nn.Module, inplace=True) -> Optional[torch.nn.Module]: + pass + + def save(self): + pass + + +class HQQuantizer(EagerModeQuantizer): + def __init__(self, config_mapping: ConfigMappingType) -> None: + super().__init__(config_mapping) + + def prepare(self, model: torch.nn.Module, inplace=True) -> Optional[torch.nn.Module]: + _replace_with_custom_fn_if_matches_filter( + model, replacement_fn=replacement_fn, filter_fn=filter_fn, config_mapping=self.config_mapping + ) + return model + + def save(self, model, path): + # TODO: to implement it in the next PR + pass diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/utility.py b/neural_compressor/torch/algorithms/weight_only/hqq/utility.py new file mode 100644 index 00000000000..9c9b3700cf6 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/hqq/utility.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import gc +import time + +import numpy as np +import psutil +import torch + +from neural_compressor.torch.utils import logger + +__all__ = [ + "is_divisible", + "dump_elapsed_time", +] + + +def is_divisible(val1, val2): + return int(val2 * np.ceil(val1 / val2)) == val1 + + +def see_cuda_memory_usage(message, force=False): # pragma: no cover + # Copied from https://github.com/microsoft/DeepSpeed + # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports + gc.collect() + + # logger.info message except when distributed but not rank 0 + logger.info(message) + logger.info( + f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ + Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ + CA {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024),2)} GB \ + Max_CA {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024))} GB " + ) + vm_stats = psutil.virtual_memory() + used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2) + logger.info(f"CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%") + + # get the peak memory to report correct data, so reset the counter for the next call + torch.cuda.reset_peak_memory_stats() + + +def dump_elapsed_time(customized_msg=""): + """Get the elapsed time for decorated functions. + + Args: + customized_msg (string, optional): The parameter passed to decorator. Defaults to None. + """ + + def f(func): + def fi(*args, **kwargs): + start = time.time() + res = func(*args, **kwargs) + end = time.time() + logger.info( + "%s elapsed time: %s ms" + % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2)) + ) + return res + + return fi + + return f diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index 73902a892cc..2269b292da7 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -25,6 +25,8 @@ get_default_static_config, SmoothQuantConfig, get_default_sq_config, + HQQConfig, + get_default_hqq_config, ) from neural_compressor.torch.quantization.autotune import ( diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 4d3f3959ce8..9680b834f66 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Tuple +from typing import Any, Callable, Dict, Tuple import torch -from neural_compressor.common.utils import AWQ, FP8_QUANT, GPTQ, RTN # unified namespace -from neural_compressor.torch.algorithms.weight_only import awq_quantize, gptq_quantize, rtn_quantize -from neural_compressor.torch.quantization import AWQConfig, GPTQConfig, RTNConfig +from neural_compressor.common.utils import AWQ, FP8_QUANT, GPTQ, HQQ, RTN +from neural_compressor.torch.algorithms.weight_only import awq_quantize, gptq_quantize, hqq_quantize, rtn_quantize +from neural_compressor.torch.quantization import AWQConfig, GPTQConfig, HQQConfig, RTNConfig from neural_compressor.torch.utils import logger, register_algo @@ -154,6 +154,17 @@ def awq_quantize_entry( return model +###################### HQQ Algo Entry ################################## +@register_algo(name=HQQ) +@torch.no_grad() +def hqq_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str, Callable], HQQConfig], *args, **kwargs +) -> torch.nn.Module: + logger.info("Quantize model with the HQQ algorithm.") + q_model = hqq_quantize(model, configs_mapping) + return q_model + + ###################### Habana FP8 Algo Entry ################################## from neural_compressor.torch.utils import is_hpex_available diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index c0097de2a48..b313dd4d15f 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -31,19 +31,22 @@ DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, + HQQ, OP_NAME_OR_MODULE_TYPE, RTN, SMOOTH_QUANT, STATIC_QUANT, ) from neural_compressor.torch.utils import is_hpex_available, logger -from neural_compressor.torch.utils.constants import PRIORITY_AWQ, PRIORITY_GPTQ, PRIORITY_RTN +from neural_compressor.torch.utils.constants import PRIORITY_AWQ, PRIORITY_GPTQ, PRIORITY_HQQ, PRIORITY_RTN __all__ = [ "RTNConfig", "get_default_rtn_config", "GPTQConfig", "get_default_gptq_config", + "HQQConfig", + "get_default_hqq_config", ] @@ -663,6 +666,76 @@ def get_default_sq_config() -> SmoothQuantConfig: return SmoothQuantConfig() +######################## HQQ Config ############################### +@register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ) +class HQQConfig(BaseConfig): + # Half-Quadratic Quantization (HQQ), more details: + # Blog: https://mobiusml.github.io/hqq_blog/ + # Code: https://github.com/mobiusml/hqq + + name = HQQ + params_list = [ + "bits", + "group_size", + "quant_zero", + "quant_scale", + "scale_quant_group_size", + "skip_lm_head", + ] + supported_configs: List[OperatorConfig] = [] + + def __init__( + self, + bits: int = 4, + group_size: int = 64, + quant_zero: bool = True, + quant_scale: bool = False, + scale_quant_group_size: int = 128, + skip_lm_head: bool = True, + white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + ): + super().__init__(white_list=white_list) + self.bits = bits + self.group_size = group_size + self.quant_zero = quant_zero + self.quant_scale = quant_scale + self.scale_quant_group_size = scale_quant_group_size + self.skip_lm_head = skip_lm_head + self._post_init() + + @staticmethod + def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: + white_list = (torch.nn.Linear,) + filter_result = [] + for op_name, module in model.named_modules(): + if isinstance(module, white_list): + pair = (op_name, type(module).__name__) + filter_result.append(pair) + return 
filter_result + + @classmethod + def register_supported_configs(cls) -> List[OperatorConfig]: + # TODO: to be refined + supported_configs = [] + linear_hqq_config = HQQConfig() + operators = [torch.nn.Linear] + supported_configs.append(OperatorConfig(config=linear_hqq_config, operators=operators)) + cls.supported_configs = supported_configs + + @classmethod + def get_config_set_for_tuning(cls) -> Union[None, "HQQConfig", List["HQQConfig"]]: + return HQQConfig(bits=[4, 8]) + + +def get_default_hqq_config() -> HQQConfig: + """Generate the default HQQ config. + + Returns: + the default HQQ config. + """ + return HQQConfig() + + ######################## FP8 Config ############################### if is_hpex_available(): diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index 1f50f3ce66d..8be5446efb9 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -43,6 +43,7 @@ } # Setting priorities for algorithms, a higher number indicates a higher priority. -PRIORITY_RTN = 80 PRIORITY_GPTQ = 90 +PRIORITY_RTN = 80 +PRIORITY_HQQ = 75 PRIORITY_AWQ = 70 diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index 41bc6230a1a..de7a7c2ceb4 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -13,12 +13,15 @@ # limitations under the License. -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Tuple, Union import torch +from typing_extensions import TypeAlias from neural_compressor.common.utils import logger +OP_NAME_AND_TYPE_TUPLE_TYPE: TypeAlias = Tuple[str, Union[torch.nn.Module, Callable]] + # Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) algos_mapping: Dict[str, Callable] = {} diff --git a/test/3x/torch/quantization/weight_only/hqq/test_auto_accelerator.py b/test/3x/torch/quantization/weight_only/hqq/test_auto_accelerator.py new file mode 100644 index 00000000000..46e4c792fe2 --- /dev/null +++ b/test/3x/torch/quantization/weight_only/hqq/test_auto_accelerator.py @@ -0,0 +1,68 @@ +import os + +import pytest +import torch + +from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import ( + accelerator_registry, + auto_detect_accelerator, +) + + +class Test_CPU_Accelerator: + @pytest.fixture + def force_use_cpu(self, monkeypatch): + # Force use CPU + monkeypatch.setenv("FORCE_DEVICE", "cpu") + + def test_cpu_accelerator(self, force_use_cpu): + print(f"FORCE_DEVICE: {os.environ.get('FORCE_DEVICE', None)}") + accelerator = auto_detect_accelerator() + assert accelerator.current_device() == "cpu", f"{accelerator.current_device()}" + assert accelerator.current_device_name() == "cpu" + assert accelerator.is_available() + assert accelerator.set_device(1) is None + assert accelerator.device() is None + assert accelerator.empty_cache() is None + assert accelerator.synchronize() is None + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +class Test_CUDA_Accelerator: + + @pytest.fixture + def force_use_cuda(self, monkeypatch): + # Force use CUDA + monkeypatch.setenv("FORCE_DEVICE", "cuda") + + def test_cuda_accelerator(self, force_use_cuda): + print(f"FORCE_DEVICE: {os.environ.get('FORCE_DEVICE', None)}") + accelerator = auto_detect_accelerator() + assert accelerator.current_device() == 0, f"{accelerator.current_device()}" + assert accelerator.current_device_name() == "cuda:0" + 
assert accelerator.device() is not None + assert accelerator.empty_cache() is None + assert accelerator.synchronize() is None + assert accelerator.set_device(0) is None + assert accelerator.device_name(0) == "cuda:0" + assert accelerator.is_available() is True + assert accelerator.name() == "cuda" + assert accelerator.device_name(1) == "cuda:1" + assert accelerator.set_device(1) is None + assert accelerator.device_name(1) == "cuda:1" + assert accelerator.current_device() == 1 + assert accelerator.current_device_name() == "cuda:1" + assert accelerator.synchronize() is None + assert accelerator.empty_cache() is None + + +class TestAutoAccelerator: + + @pytest.fixture + def set_cuda_available(self, monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + + def test_auto_accelerator(self, set_cuda_available): + accelerator = auto_detect_accelerator() + all_accelerators = accelerator_registry.get_sorted_accelerators() + assert accelerator.name() == all_accelerators[0]().name() diff --git a/test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py b/test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py new file mode 100644 index 00000000000..bdfd2145aff --- /dev/null +++ b/test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py @@ -0,0 +1,44 @@ +from neural_compressor.torch.algorithms.weight_only.hqq.config import ( + HQQModuleConfig, + QTensorConfig, + default_hqq_module_config, + default_scale_quant_config, + default_weight_quant_config, + default_zero_quant_config, +) +from neural_compressor.torch.algorithms.weight_only.hqq.qtensor import QTensorMetaInfo + + +def test_default_hqq_module_config(): + config = default_hqq_module_config + print(config) + assert isinstance(config, HQQModuleConfig) + assert config.weight == default_weight_quant_config + assert config.zero == default_zero_quant_config + assert config.scale == default_scale_quant_config + + +def test_default_weight_quant_config(): + config = default_weight_quant_config + assert isinstance(config, QTensorConfig) + assert config.nbits == 4 + assert config.channel_wise is True + + +def test_default_zero_quant_config(): + config = default_zero_quant_config + assert isinstance(config, QTensorConfig) + assert config.nbits == 8 + assert config.channel_wise is False + + +def test_default_scale_quant_config(): + config = default_scale_quant_config + assert isinstance(config, QTensorConfig) + assert config.nbits == 8 + assert config.channel_wise is True + + +def test_qtensor_meta_info(): + meta_info = QTensorMetaInfo + print(meta_info) diff --git a/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py new file mode 100644 index 00000000000..724af3a7e2c --- /dev/null +++ b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py @@ -0,0 +1,114 @@ +import os +from copy import deepcopy + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option +from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear + + +def _common_cpu_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128): + # Parse config + weight_qconfig = QTensorConfig( + nbits=nbits, channel_wise=True, group_size=group_size, optimize=True, round_zero=True if nbits == 4 else False + ) + zero_qconfig = None + if quant_zero: + zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, 
group_size=None, optimize=False) + scale_qconfig = None + if quant_scale: + scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False) + hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig) + device = "cpu" + + # Create HQQ Linear + bs = 4 + in_features = 64 + out_features = 128 + float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features) + if hqq_global_option.use_half: + print(f"hqq_global_option use half: {hqq_global_option.use_half}") + float_linear = float_linear.half() + float_linear.to(device) + float_linear_copy = deepcopy(float_linear) + hqq_linear = HQQLinear.from_float(float_linear_copy, quant_config=hqq_quant_config) + + # Forward + input = torch.randn(bs, in_features, device=device) + if hqq_global_option.use_half: + input = input.half() + float_output = float_linear(input) + input_for_hqq = deepcopy(input) + hqq_output = hqq_linear(input_for_hqq) + hqq_output_2 = hqq_linear(input_for_hqq) + torch.allclose(float_output, hqq_output, atol=0.5) + torch.allclose(hqq_output, hqq_output_2) + del float_linear, hqq_linear + del float_output, hqq_output, hqq_output_2 + + +class TestHQQCPU: + + @classmethod + def setup_class(cls): + torch.manual_seed(0) + + @pytest.fixture + def force_use_cpu(self, monkeypatch): + # Force use CPU + monkeypatch.setenv("FORCE_DEVICE", "cpu") + + @pytest.fixture + def force_not_half(self, monkeypatch): + monkeypatch.setattr(hqq_global_option, "use_half", False) + + def test_hqq_quant(self, force_use_cpu, force_not_half): + from neural_compressor.torch.quantization import get_default_hqq_config, quantize + + hqq_global_option.use_half = False + model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu") + # test_default_config + quant_config = get_default_hqq_config() + model = quantize(model, quant_config) + q_label = model(example_inputs)[0] + print(q_label) + + @pytest.mark.parametrize( + "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size", + [ + (4, 64, True, False, 128), + (4, 64, False, False, 128), + (4, 64, True, True, 128), + (4, 64, False, True, 128), + (8, 64, True, False, 128), + (8, 64, False, False, 128), + (8, 64, True, True, 128), + (8, 64, False, True, 128), + (4, 64, True, False, 64), + (4, 64, False, False, 64), + (4, 64, True, True, 64), + (4, 64, False, True, 64), + ], + ) + def test_hqq_module_cpu( + self, force_use_cpu, force_not_half, nbits, group_size, quant_zero, quant_scale, scale_quant_group_size + ): + _common_cpu_test( + nbits=nbits, + group_size=group_size, + quant_zero=quant_zero, + quant_scale=quant_scale, + scale_quant_group_size=scale_quant_group_size, + ) + + +# _common_cpu_test( +# nbits=4, +# group_size=64, +# quant_zero=False, +# quant_scale=False, +# scale_quant_group_size=128 +# ) diff --git a/test/3x/torch/quantization/weight_only/hqq/test_hqq_cuda.py b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cuda.py new file mode 100644 index 00000000000..bb45b971ffa --- /dev/null +++ b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cuda.py @@ -0,0 +1,119 @@ +from copy import deepcopy + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import auto_detect_accelerator +from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option 
+from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear +from neural_compressor.torch.algorithms.weight_only.hqq.utility import see_cuda_memory_usage + + +def _common_cuda_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128): + # Parse config + weight_qconfig = QTensorConfig( + nbits=nbits, channel_wise=True, group_size=group_size, optimize=True, round_zero=True if nbits == 4 else False + ) + zero_qconfig = None + if quant_zero: + zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False) + scale_qconfig = None + if quant_scale: + scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False) + hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig) + device = torch.cuda.current_device() + + # Create HQQ Linear + bs = 4 + in_features = 64 + out_features = 128 + see_cuda_memory_usage(message="Before create float linear") + float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features) + if hqq_global_option.use_half: + float_linear = float_linear.half() + see_cuda_memory_usage(message="After create float linear") + float_linear.to(device) + float_linear_copy = deepcopy(float_linear) + see_cuda_memory_usage(message="After copy the float linear") + hqq_linear = HQQLinear.from_float(float_linear_copy, quant_config=hqq_quant_config) + see_cuda_memory_usage(message="After create hqq linear") + + # Forward + input = torch.randn(bs, in_features, device=device) + if hqq_global_option.use_half: + input = input.half() + float_output = float_linear(input) + input_for_hqq = deepcopy(input) + hqq_output = hqq_linear(input_for_hqq) + hqq_output_2 = hqq_linear(input_for_hqq) + float_qdq_diff = 0.1 # hard code it first + torch.allclose(float_output, hqq_output, atol=float_qdq_diff) + torch.allclose(hqq_output, hqq_output_2) + del float_linear, hqq_linear + del float_output, hqq_output, hqq_output_2 + see_cuda_memory_usage("At the end of test") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU") +class TestHQQCUDA: + @classmethod + def setup_class(cls): + torch.manual_seed(0) + torch.cuda.manual_seed(0) + hqq_global_option.use_half = True + + def test_hqq_quant(self): + from neural_compressor.torch.quantization import get_default_hqq_config, quantize + + model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + example_inputs = torch.tensor( + [[10, 20, 30, 40, 50, 60]], dtype=torch.long, device=auto_detect_accelerator().current_device() + ) + # test_default_config + quant_config = get_default_hqq_config() + model = quantize(model, quant_config) + q_label = model(example_inputs)[0] + print(q_label) + + @pytest.mark.parametrize( + "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size", + [ + (4, 64, True, False, 128), + (4, 64, False, False, 128), + (4, 64, True, True, 128), + (4, 64, False, True, 128), + (8, 64, True, False, 128), + (8, 64, False, False, 128), + (8, 64, True, True, 128), + (8, 64, False, True, 128), + (4, 64, True, False, 64), + (4, 64, False, False, 64), + (4, 64, True, True, 64), + (4, 64, False, True, 64), + ], + ) + def test_hqq_module_cuda( + self, + nbits, + group_size, + quant_zero, + quant_scale, + scale_quant_group_size, + ): + _common_cuda_test( + nbits=nbits, + group_size=group_size, + quant_zero=quant_zero, + quant_scale=quant_scale, + scale_quant_group_size=scale_quant_group_size, + ) + + +# _common_cuda_test( +# nbits=4, 
+# group_size=64, +# quant_zero=False, +# quant_scale=False, +# scale_quant_group_size=128 +# ) diff --git a/test/3x/torch/quantization/weight_only/hqq/test_packer.py b/test/3x/torch/quantization/weight_only/hqq/test_packer.py new file mode 100644 index 00000000000..d25073ce74b --- /dev/null +++ b/test/3x/torch/quantization/weight_only/hqq/test_packer.py @@ -0,0 +1,16 @@ +import pytest +import torch + +from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer + + +@pytest.mark.parametrize("nbits", [2, 4, 8]) +def test_packer(nbits): + # TODO: add test for 3 bits + range_max = 2**nbits + dims = 16 + W = torch.randint(0, range_max, (dims, dims)).to(torch.uint8) + W_pack = Packer.get_pack_fn(nbits)(W) + W_pack_unpack = Packer.get_unpack_fn(nbits)(W_pack) + assert torch.allclose(W, W_pack_unpack) + print("Packer test passed!") diff --git a/test/3x/torch/quantization/weight_only/hqq/test_q_tensor.py b/test/3x/torch/quantization/weight_only/hqq/test_q_tensor.py new file mode 100644 index 00000000000..0548c10e3f1 --- /dev/null +++ b/test/3x/torch/quantization/weight_only/hqq/test_q_tensor.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from neural_compressor.torch.algorithms.weight_only.hqq.qtensor import QTensor, QTensorMetaInfo + + +class TestQTensor: + def test_q_tensor(self): + in_feats = 3 + out_feats = 4 + + val = torch.randn(out_feats, in_feats) + scale = torch.randn(out_feats) + zero = torch.randint(1, 10, (out_feats,)) + q_tensor_meta = QTensorMetaInfo(nbits=4, group_size=64, shape=(out_feats, in_feats), axis=0, packing=False) + q_tensor = QTensor(val, scale, zero, q_tensor_meta) + print(q_tensor) + q_tensor_half = q_tensor.half() + print(q_tensor_half) + + def test_q_tensor2(self): + in_feats = 64 + out_feats = 64 + + val = torch.randn(out_feats, in_feats) + scale = torch.randn(out_feats) + zero = torch.randint(1, 10, (out_feats,)) + q_tensor_meta = QTensorMetaInfo(nbits=4, group_size=64, shape=(out_feats, in_feats), axis=0, packing=False) + q_tensor = QTensor(val, scale, zero, q_tensor_meta) + q_scale_meta = QTensorMetaInfo(nbits=8, group_size=64, shape=(out_feats,), axis=0, packing=False) + q_scale_scale = torch.randn(out_feats) + q_scale_zero = torch.randint(1, 10, (1,)) + q_scale = QTensor(scale, q_scale_scale, q_scale_zero, q_tensor_meta) + q_tensor.scale = q_scale + print(q_tensor) + print(q_tensor.half()) + + def test_qtensor_meta_info(self): + in_feats = 64 + out_feats = 64 + meta_config = QTensorMetaInfo(nbits=4, group_size=64, shape=(out_feats, in_feats), axis=0, packing=False) + print(meta_config) + print(meta_config.to_dict) + assert meta_config.to_dict() == { + "nbits": 4, + "group_size": 64, + "shape": (out_feats, in_feats), + "axis": 0, + "packing": False, + } diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt index 793bdfc0350..81a6eaa0301 100644 --- a/test/3x/torch/requirements.txt +++ b/test/3x/torch/requirements.txt @@ -1,2 +1,5 @@ +numpy +psutil 
pytest +torch transformers diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index 760a2d27727..ee3d36e4f24 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -7,9 +7,11 @@ from neural_compressor.torch.quantization import ( AWQConfig, GPTQConfig, + HQQConfig, RTNConfig, SmoothQuantConfig, StaticQuantConfig, + get_default_hqq_config, get_default_rtn_config, quantize, ) @@ -288,6 +290,12 @@ def test_smooth_quant_config(self): sq_config2 = SmoothQuantConfig.from_dict(quant_config_dict["sq"]) self.assertEqual(sq_config1.to_dict(), sq_config2.to_dict()) + def test_hqq_config(self): + hqq_config = HQQConfig(bits=4, group_size=64, quant_zero=True) + quant_config_dict = {"hqq": {"bits": 4, "group_size": 64, "quant_zero": True}} + hqq_config2 = HQQConfig.from_dict(quant_config_dict["hqq"]) + self.assertEqual(hqq_config.to_dict(), hqq_config2.to_dict()) + class TestQuantConfigForAutotune(unittest.TestCase): def test_expand_config(self):
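
For reference, a minimal end-to-end sketch of the new HQQ path, mirroring the CPU test added above; the model name and example input are illustrative:

    import torch
    from transformers import AutoModelForCausalLM

    from neural_compressor.torch.quantization import get_default_hqq_config, quantize

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")

    # Default HQQConfig: bits=4, group_size=64, quant_zero=True, quant_scale=False, skip_lm_head=True
    quant_config = get_default_hqq_config()

    # Matching torch.nn.Linear modules are replaced with HQQLinear; lm_head is skipped by default
    model = quantize(model, quant_config)
    logits = model(example_inputs)[0]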