
Commit

Enable cross-devices Half-Quadratic Quantization for LLMs (#1597)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Feb 7, 2024
1 parent c1f23ce commit 07f940c
Showing 24 changed files with 1,736 additions and 7 deletions.
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/weight_only/__init__.py
@@ -15,5 +15,6 @@
from .rtn import rtn_quantize
from .gptq import gptq_quantize
from .awq import awq_quantize
from .hqq import hqq_quantize
from .modules import WeightOnlyLinear
from .utility import *
17 changes: 17 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .quantizer import HQQuantizer
from .config import HQQModuleConfig, QTensorConfig
from .quant_api import hqq_quantize
223 changes: 223 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/auto_accelerator.py
@@ -0,0 +1,223 @@
# Copyright (c) 2023-2024 Microsoft Corporation and Intel Corporation

# This code is based on Microsoft Corporation's DeepSpeed library and
# the accelerator implementations in that library. It has been modified
# from its original form to simplify and adapt it for use in
# the Intel® Neural Compressor.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design is adapted from:
# https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
# TODO: move it into torch/utils


# To keep it simple, only add the APIs we need.

import os
from abc import ABC, abstractmethod
from typing import Any, Callable, List

import torch

from neural_compressor.torch.utils import logger

PRIORITY_CUDA = 100
PRIORITY_CPU = 90


class AcceleratorRegistry:
registered_accelerators = {}

@classmethod
def register_accelerator_impl(cls, name: str, priority: float = 0):
"""Register new accelerator implementation.
Usage example:
@AcceleratorRegistry.register_accelerator(name="cpu", priority=100)
class CPU_Accelerator:
...
Args:
name: the accelerator name.
priority: priority: the priority of the accelerator. A larger number indicates a higher priority,
"""

def decorator(accelerator_cls):
cls.registered_accelerators.setdefault(name, {})
cls.registered_accelerators[name] = (accelerator_cls, priority)
return accelerator_cls

return decorator

@classmethod
def get_sorted_accelerators(cls) -> List["Auto_Accelerator"]:
"""Get registered accelerators sorted by priority."""
accelerator_pairs = cls.registered_accelerators.values()
sorted_accelerators_pairs = sorted(accelerator_pairs, key=lambda x: x[1], reverse=True)
sorted_accelerators = [pair[0] for pair in sorted_accelerators_pairs]
return sorted_accelerators

@classmethod
def get_accelerator_cls_by_name(cls, name: str) -> "Auto_Accelerator":
"""Get accelerator by name."""
accelerator_cls, _ = cls.registered_accelerators.get(name, (None, None))
return accelerator_cls


accelerator_registry = AcceleratorRegistry()


def register_accelerator(name: str, priority: float = 0) -> Callable[..., Any]:
"""Register new accelerator.
Usage example:
@register_accelerator(name="cuda", priority=100)
class CUDA_Accelerator:
...
Args:
name: the accelerator name.
priority: the priority of the accelerator. A larger number indicates a higher priority.
"""

return accelerator_registry.register_accelerator_impl(name=name, priority=priority)


class Auto_Accelerator(ABC):
@classmethod
@abstractmethod
def is_available(cls) -> bool:
pass

@abstractmethod
def name(self) -> str:
pass

@abstractmethod
def device_name(self, device_index) -> str:
pass

@abstractmethod
def set_device(self, device_index):
pass

@abstractmethod
def current_device(self):
pass

@abstractmethod
def current_device_name(self):
pass

@abstractmethod
def device(self, device_index=None):
pass

@abstractmethod
def empty_cache(self):
pass

@abstractmethod
def synchronize(self):
pass


@register_accelerator(name="cpu", priority=PRIORITY_CPU)
class CPU_Accelerator(Auto_Accelerator):
def __init__(self) -> None:
self._name = "cpu"

def name(self) -> str:
return self._name

@classmethod
def is_available(cls) -> bool:
return True

def device_name(self, device_index) -> str:
return "cpu"

def set_device(self, device_index):
pass

def current_device(self):
return "cpu"

def current_device_name(self):
return "cpu"

def device(self, device_index=None):
pass

def empty_cache(self):
pass

def synchronize(self):
pass


@register_accelerator(name="cuda", priority=PRIORITY_CUDA)
class CUDA_Accelerator(Auto_Accelerator):
def __init__(self) -> None:
self._name = "cuda"

def name(self) -> str:
return self._name

@classmethod
def is_available(cls) -> bool:
return torch.cuda.is_available()

def device_name(self, device_index) -> str:
if device_index is None:
return "cuda"
return f"cuda:{device_index}"

def synchronize(self):
return torch.cuda.synchronize()

def set_device(self, device_index):
return torch.cuda.set_device(device_index)

def current_device(self):
return torch.cuda.current_device()

def current_device_name(self):
return "cuda:{}".format(torch.cuda.current_device())

def device(self, device_index=None):
return torch.cuda.device(device_index)

def empty_cache(self):
return torch.cuda.empty_cache()


def auto_detect_accelerator() -> Auto_Accelerator:
# if runtime_accelerator.accelerator:
# return runtime_accelerator.accelerator
FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
logger.warning("Forcing the use of the %s accelerator.", FORCE_DEVICE)
return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
for accelerator_cls in accelerator_registry.get_sorted_accelerators():
if accelerator_cls.is_available():
logger.debug("Auto-detected accelerator: %s.", accelerator_cls.__name__)
accelerator = accelerator_cls()
return accelerator


# Force the CPU accelerator even if CUDA is available:
#   FORCE_DEVICE=cpu python ...
# or
#   CUDA_VISIBLE_DEVICES="" python ...
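
Not part of the commit: a minimal usage sketch of the accelerator auto-detection above, assuming the module is importable from the path shown in the file header; the printed values are illustrative.

from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import auto_detect_accelerator

# Set FORCE_DEVICE=cpu in the environment to pick the CPU accelerator even when CUDA is available.
accelerator = auto_detect_accelerator()
print(accelerator.name())                  # "cuda" if torch.cuda.is_available() else "cpu"
print(accelerator.current_device_name())   # e.g. "cuda:0" or "cpu"
accelerator.synchronize()                  # no-op on CPU, torch.cuda.synchronize() on CUDA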
144 changes: 144 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py
@@ -0,0 +1,144 @@
# Copyright (c) 2023-2024 Mobiusml and Intel Corporation

# This code is based on Mobiusml's HQQ library. It has been modified
# from its original form to simplify and adapt it for use in
# the Intel® Neural Compressor.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Notice: Copied from https://github.com/mobiusml/hqq
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################

import numpy as np
import torch

from .utility import is_divisible

__all__ = ["Packer"]


# Bit packing logic. Format: pack/unpack_nBits_target-<uint8 or int32>
class BitPack:
# 8-bit
################################################
@staticmethod
def pack_8bit_u8(W_q):
return W_q.to(torch.uint8)

@staticmethod
def unpack_8bit_u8(W_q):
return W_q

# 4-bit
################################################
@staticmethod
def pack_4bit_u8(W_q): # uint8 > uint8/2
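# Two 4-bit values per byte: the first half of the rows fills the high nibble, the second half the low nibble.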
W_q = W_q.to(torch.uint8)
_step = int(len(W_q) / 2)
return (W_q[:_step] << 4) | W_q[_step:]

# A bit faster than the _cat version
@staticmethod
def unpack_4bit_u8(W_q): # uint8/2 > uint8
_step = W_q.shape[0]
tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
tmp[:_step] = (W_q & 0b11110000) >> 4
tmp[_step:] = W_q & 0b00001111
return tmp

# 2-bit
################################################
@staticmethod
def pack_2bit_u8(W_q): # uint8 > uint8/4
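# Four 2-bit values per byte: each quarter of the rows fills one 2-bit field, from the most to the least significant bits.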
W_q = W_q.to(torch.uint8)
_step = int(len(W_q) / 4)
return W_q[:_step] << 6 | W_q[_step : 2 * _step] << 4 | W_q[2 * _step : 3 * _step] << 2 | W_q[3 * _step :]

# A bit faster than the _cat version
@staticmethod
def unpack_2bit_u8(W_q):
_step = W_q.shape[0]
tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
tmp[:_step] = (W_q & 0b11000000) >> 6
tmp[_step : 2 * _step] = (W_q & 0b00110000) >> 4
tmp[2 * _step : 3 * _step] = (W_q & 0b00001100) >> 2
tmp[3 * _step :] = W_q & 0b00000011
return tmp

# 3-bit
################################################
@staticmethod
def pack_3bit_32(W_q_in):
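# Ten 3-bit values per int32 word (30 of the 32 bits are used); rows are zero-padded up to a multiple of 10 first.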
W_q = torch.zeros(
[int(10 * np.ceil(W_q_in.shape[0] / 10.0)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32
)
W_q[: len(W_q_in)] = W_q_in
_step = int(len(W_q) / 10)
W_q = (
(W_q[:_step] << 27)
| (W_q[_step : _step * 2] << 24)
| (W_q[_step * 2 : _step * 3] << 21)
| (W_q[_step * 3 : _step * 4] << 18)
| (W_q[_step * 4 : _step * 5] << 15)
| (W_q[_step * 5 : _step * 6] << 12)
| (W_q[_step * 6 : _step * 7] << 9)
| (W_q[_step * 7 : _step * 8] << 6)
| (W_q[_step * 8 : _step * 9] << 3)
| (W_q[_step * 9 :])
)
return W_q

# A bit faster than _cat version
@staticmethod
def unpack_3bit_32(W_q):
_step = W_q.shape[0]
tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
tmp[:_step] = (W_q & 0b00111000000000000000000000000000) >> 27
tmp[1 * _step : 2 * _step] = (W_q & 0b00000111000000000000000000000000) >> 24
tmp[2 * _step : 3 * _step] = (W_q & 0b00000000111000000000000000000000) >> 21
tmp[3 * _step : 4 * _step] = (W_q & 0b00000000000111000000000000000000) >> 18
tmp[4 * _step : 5 * _step] = (W_q & 0b00000000000000111000000000000000) >> 15
tmp[5 * _step : 6 * _step] = (W_q & 0b00000000000000000111000000000000) >> 12
tmp[6 * _step : 7 * _step] = (W_q & 0b00000000000000000000111000000000) >> 9
tmp[7 * _step : 8 * _step] = (W_q & 0b00000000000000000000000111000000) >> 6
tmp[8 * _step : 9 * _step] = (W_q & 0b00000000000000000000000000111000) >> 3
tmp[9 * _step :] = W_q & 0b00000000000000000000000000000111
return tmp


class Packer:
# TODO: Refine the packer
bit_to_packing = {8: "8bit_u8", 4: "4bit_u8", 3: "3bit_32", 2: "2bit_u8"}

pack_fn_mapping = {
"8bit_u8": BitPack.pack_8bit_u8,
"4bit_u8": BitPack.pack_4bit_u8,
"3bit_32": BitPack.pack_3bit_32,
"2bit_u8": BitPack.pack_2bit_u8,
}

unpack_fn_mapping = {
"8bit_u8": BitPack.unpack_8bit_u8,
"4bit_u8": BitPack.unpack_4bit_u8,
"3bit_32": BitPack.unpack_3bit_32,
"2bit_u8": BitPack.unpack_2bit_u8,
}

@staticmethod
def get_pack_fn(nbits: int):
return Packer.pack_fn_mapping[Packer.bit_to_packing[nbits]]

@staticmethod
def get_unpack_fn(nbits: int):
return Packer.unpack_fn_mapping[Packer.bit_to_packing[nbits]]
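
Not part of the commit: a small round-trip sketch for the Packer above, assuming the module is importable from the path shown in the file header. pack_4bit_u8 halves the first dimension, so the example uses an even number of rows.

import torch
from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer

W_q = torch.randint(0, 16, (8, 4), dtype=torch.uint8)   # simulated 4-bit values in [0, 15]
packed = Packer.get_pack_fn(4)(W_q)                      # shape (4, 4): two values per byte
restored = Packer.get_unpack_fn(4)(packed)               # shape (8, 4)
assert torch.equal(restored, W_q)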