Enable cross-device Half-Quadratic Quantization for LLMs #1597

Merged
56 commits, merged Feb 7, 2024

Commits (all authored by yiliu30)
d29105b  add hqq module (Jan 28, 2024)
cad4942  add hqq quantizer (Jan 28, 2024)
a292bac  fixed some UTs (Jan 28, 2024)
2bcd16a  add more test (Jan 28, 2024)
cb79fff  add more test (Jan 28, 2024)
f2a37f8  add cuda support (Jan 29, 2024)
4d5518c  refactor code (Jan 29, 2024)
0131d27  add packing (Jan 30, 2024)
f497e18  add hqq api (Jan 31, 2024)
4025f96  refine the code (Jan 31, 2024)
a5b8820  rename test replace to test api (Jan 31, 2024)
ac1cc3c  rename QuantTensorCOnfig into QTensorConfig (Jan 31, 2024)
2ba4ccb  Merge branch 'master' into hqq (Jan 31, 2024)
444c93c  Merge branch 'master' into hqq (Jan 31, 2024)
28089c2  Merge branch 'ly/hqq' of https://github.com/intel/neural-compressor i… (Jan 31, 2024)
22b17af  enhance accelerator (Jan 31, 2024)
201dc9a  clean code (Jan 31, 2024)
99acbc9  fix half issue (Jan 31, 2024)
4944fc2  separate qtensor (Feb 1, 2024)
e89c0dc  refine packer (Feb 1, 2024)
989332a  Add HQQ config (Feb 1, 2024)
bbf1173  add hqq to algo entry (Feb 1, 2024)
5a8ca50  force half for cpu half (Feb 1, 2024)
3a7e149  separate cpu and cuda UTs (Feb 1, 2024)
068195b  add inc example (Feb 1, 2024)
00c33cd  update api (Feb 1, 2024)
78dad4e  refine code (Feb 1, 2024)
f9feb14  disable some tests (Feb 1, 2024)
e296332  disable some tests (Feb 1, 2024)
0a6f8c9  resolve confilcts (Feb 2, 2024)
0f8b535  fixed some bugs (Feb 2, 2024)
171457b  add more test (Feb 2, 2024)
f20c992  merge with masetr (Feb 2, 2024)
24212ba  clean code (Feb 2, 2024)
3debc1c  clean code (Feb 2, 2024)
9a50ce3  clean code (Feb 2, 2024)
186369f  remove unused import (Feb 2, 2024)
414658f  add q tensor test (Feb 2, 2024)
20bfa0b  fix pylint and add packages (Feb 2, 2024)
36ce8bb  fixed UTs bugs (Feb 2, 2024)
7856057  Merge branch 'master' into hqq3 (Feb 2, 2024)
b73e761  refine the UTs (Feb 2, 2024)
1495a64  fixed the import erro (Feb 2, 2024)
de022b9  fixed the UTs (Feb 2, 2024)
14a0c5f  remove unused code (Feb 2, 2024)
9768611  remove some unused code (Feb 2, 2024)
374d10b  enhance UTs (Feb 2, 2024)
4498ee8  cast bias (Feb 3, 2024)
a546cf5  update the license (Feb 5, 2024)
62607d4  update some func names (Feb 5, 2024)
f6d1cf0  resolve the conflicts (Feb 6, 2024)
0623ab3  align quantize entry (Feb 6, 2024)
c5d9d7a  update logger import path (Feb 6, 2024)
b6812af  rename nbits to bits (Feb 6, 2024)
7e6f916  fixed typo (Feb 6, 2024)
7deea17  Merge branch 'master' into ly/hqq3 (Feb 7, 2024)
3 changes: 3 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/__init__.py
@@ -15,3 +15,6 @@
from .utility import *
from .rtn import rtn_quantize
from .gptq import gptq_quantize

from .hqq.quantizer import HQQuantizer
from .hqq.config import HQQModuleConfig, QTensorConfig
16 changes: 16 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .quantizer import HQQuantizer
from .config import HQQModuleConfig, QTensorConfig
223 changes: 223 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/auto_accelerator.py
@@ -0,0 +1,223 @@
# Copyright (c) 2023-2024 Microsoft Corporation and Intel Corporation

# This code is based on Microsoft Corporation's DeepSpeed library and
# the accelerators implementation in this library. It has been modified
# from its original forms to simplify and adapt it for use in
# the Intel® Neural Compressor.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design adapted from:
# https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
# TODO: move it into torch/utils


# To keep it simple, only add the APIs we need.

import os
from abc import ABC, abstractmethod
from typing import Any, Callable, List

import torch

from neural_compressor.common import logger

PRIORITY_CUDA = 100
PRIORITY_CPU = 90


class AcceleratorRegistry:
registered_accelerators = {}

@classmethod
def register_accelerator_impl(cls, name: str, priority: float = 0):
"""Register new accelerator implementation.

Usage example:
@AcceleratorRegistry.register_accelerator_impl(name="cpu", priority=100)
class CPU_Accelerator:
...

Args:
name: the accelerator name.
priority: the priority of the accelerator. A larger number indicates a higher priority.
"""

def decorator(accelerator_cls):
cls.registered_accelerators.setdefault(name, {})
cls.registered_accelerators[name] = (accelerator_cls, priority)
return accelerator_cls

return decorator

@classmethod
def get_sorted_accelerators(cls) -> List["Auto_Accelerator"]:
"""Get registered accelerators sorted by priority."""
accelerator_pairs = cls.registered_accelerators.values()
sorted_accelerators_pairs = sorted(accelerator_pairs, key=lambda x: x[1], reverse=True)
sorted_accelerators = [pair[0] for pair in sorted_accelerators_pairs]
return sorted_accelerators

@classmethod
def get_accelerator_cls_by_name(cls, name: str) -> "Auto_Accelerator":
"""Get accelerator by name."""
accelerator_cls, _ = cls.registered_accelerators.get(name, (None, None))
return accelerator_cls


accelerator_registry = AcceleratorRegistry()


def register_accelerator(name: str, priority: float = 0) -> Callable[..., Any]:
"""Register new accelerator.

Usage example:
@register_accelerator(name="cuda", priority=100)
class CUDA_Accelerator:
...

Args:
name: the accelerator name.
priority: the priority of the accelerator. A larger number indicates a higher priority.
"""

return accelerator_registry.register_accelerator_impl(name=name, priority=priority)


class Auto_Accelerator(ABC):
@classmethod
@abstractmethod
def is_available(cls) -> bool:
pass

@abstractmethod
def name(self) -> str:
pass

@abstractmethod
def device_name(self, device_indx) -> str:
pass

@abstractmethod
def set_device(self, device_index):
pass

@abstractmethod
def current_device(self):
pass

@abstractmethod
def current_device_name(self):
pass

@abstractmethod
def device(self, device_index=None):
pass

@abstractmethod
def empty_cache(self):
pass

@abstractmethod
def synchronize(self):
pass


@register_accelerator(name="cpu", priority=PRIORITY_CPU)
class CPU_Accelerator(Auto_Accelerator):
def __init__(self) -> None:
self._name = "cpu"

def name(self) -> str:
return self._name

@classmethod
def is_available(cls) -> bool:
return True

def device_name(self, device_indx) -> str:
return "cpu"

def set_device(self, device_index):
pass

def current_device(self):
return "cpu"

def current_device_name(self):
return "cpu"

def device(self, device_index=None):
pass

def empty_cache(self):
pass

def synchronize(self):
pass


@register_accelerator(name="cuda", priority=PRIORITY_CUDA)
class CUDA_Accelerator(Auto_Accelerator):
def __init__(self) -> None:
self._name = "cuda"

def name(self) -> str:
return self._name

@classmethod
def is_available(cls) -> bool:
return torch.cuda.is_available()

def device_name(self, device_indx) -> str:
if device_indx is None:
return "cuda"
return f"cuda:{device_indx}"

def synchronize(self):
return torch.cuda.synchronize()

def set_device(self, device_index):
return torch.cuda.set_device(device_index)

def current_device(self):
return torch.cuda.current_device()

def current_device_name(self):
return "cuda:{}".format(torch.cuda.current_device())

def device(self, device_index=None):
return torch.cuda.device(device_index)

def empty_cache(self):
return torch.cuda.empty_cache()
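

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of this PR: a hypothetical extra backend
# showing how another device would plug into the registry above. The "xpu"
# name, its priority, and the torch.xpu calls are assumptions for illustration;
# auto_detect_accelerator() below would pick such a backend up by priority.
# ---------------------------------------------------------------------------
PRIORITY_XPU = 95  # hypothetical: between PRIORITY_CUDA (100) and PRIORITY_CPU (90)


@register_accelerator(name="xpu", priority=PRIORITY_XPU)
class XPU_Accelerator(Auto_Accelerator):  # hypothetical example backend
    def __init__(self) -> None:
        self._name = "xpu"

    def name(self) -> str:
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        # Guarded so the sketch is harmless on builds without XPU support.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    def device_name(self, device_indx) -> str:
        return "xpu" if device_indx is None else f"xpu:{device_indx}"

    def set_device(self, device_index):
        return torch.xpu.set_device(device_index)

    def current_device(self):
        return torch.xpu.current_device()

    def current_device_name(self):
        return "xpu:{}".format(torch.xpu.current_device())

    def device(self, device_index=None):
        return torch.xpu.device(device_index)

    def empty_cache(self):
        return torch.xpu.empty_cache()

    def synchronize(self):
        return torch.xpu.synchronize()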


def auto_detect_accelerator() -> Auto_Accelerator:
# if runtime_accelerator.accelerator:
# return runtime_accelerator.accelerator
FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
logger.warning("Force use %s accelerator.", FORCE_DEVICE)
return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
for accelerator_cls in accelerator_registry.get_sorted_accelerators():
if accelerator_cls.is_available():
logger.debug("Auto detect accelerator: %s.", accelerator_cls.__name__)
accelerator = accelerator_cls()
return accelerator


# Force the CPU accelerator even if CUDA is available:
#   FORCE_DEVICE=cpu python ...
# or
#   CUDA_VISIBLE_DEVICES="" python ...
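
A minimal usage sketch of the auto-detection above (the import path follows the file location; the device string comes from current_device_name()):

import torch
from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import auto_detect_accelerator

accelerator = auto_detect_accelerator()      # CUDA_Accelerator when CUDA is available, otherwise CPU_Accelerator
device = accelerator.current_device_name()   # e.g. "cuda:0" or "cpu"
weight = torch.randn(16, 16, device=device)  # place tensors on the detected device
accelerator.synchronize()                    # no-op on the CPU backend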
144 changes: 144 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py
@@ -0,0 +1,144 @@
# Copyright (c) 2023-2024 Mobiusml and Intel Corporation

# This code is based on Mobiusml's HQQ library. It has been modified
# from its original forms to simplify and adapt it for use in
# the Intel® Neural Compressor.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Notice: Copied from https://github.com/mobiusml/hqq
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################

import numpy as np
import torch

from .utility import is_divisible

__all__ = ["Packer"]


# Bit packing logic. format: pack/unpack_nBits_target-<uint8 or int32>
class BitPack:
# 8-bit
################################################
@staticmethod
def pack_8bit_u8(W_q):
return W_q.to(torch.uint8)

@staticmethod
def unpack_8bit_u8(W_q):
return W_q

# 4-bit
################################################
@staticmethod
def pack_4bit_u8(W_q): # uint8 > uint8/2
W_q = W_q.to(torch.uint8)
_step = int(len(W_q) / 2)
return (W_q[:_step] << 4) | W_q[_step:]

# A bit faster than the _cat version
@staticmethod
def unpack_4bit_u8(W_q): # uint8/2 > uint8
_step = W_q.shape[0]
tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
tmp[:_step] = (W_q & 0b11110000) >> 4
tmp[_step:] = W_q & 0b00001111
return tmp

# 2-bit
################################################
@staticmethod
def pack_2bit_u8(W_q): # uint8 > uint8/4
W_q = W_q.to(torch.uint8)
_step = int(len(W_q) / 4)
return W_q[:_step] << 6 | W_q[_step : 2 * _step] << 4 | W_q[2 * _step : 3 * _step] << 2 | W_q[3 * _step :]

# A bit faster than the _cat version
@staticmethod
def unpack_2bit_u8(W_q):
_step = W_q.shape[0]
tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
tmp[:_step] = (W_q & 0b11000000) >> 6
tmp[_step : 2 * _step] = (W_q & 0b00110000) >> 4
tmp[2 * _step : 3 * _step] = (W_q & 0b00001100) >> 2
tmp[3 * _step :] = W_q & 0b00000011
return tmp

# 3bit
################################################
@staticmethod
def pack_3bit_32(W_q_in):
W_q = torch.zeros(
[int(10 * np.ceil(W_q_in.shape[0] / 10.0)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32
)
W_q[: len(W_q_in)] = W_q_in
_step = int(len(W_q) / 10)
W_q = (
(W_q[:_step] << 27)
| (W_q[_step : _step * 2] << 24)
| (W_q[_step * 2 : _step * 3] << 21)
| (W_q[_step * 3 : _step * 4] << 18)
| (W_q[_step * 4 : _step * 5] << 15)
| (W_q[_step * 5 : _step * 6] << 12)
| (W_q[_step * 6 : _step * 7] << 9)
| (W_q[7 * _step : _step * 8] << 6)
| (W_q[_step * 8 : _step * 9] << 3)
| (W_q[_step * 9 :])
)
return W_q

# A bit faster than _cat version
@staticmethod
def unpack_3bit_32(W_q):
_step = W_q.shape[0]
tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
tmp[:_step] = (W_q & 0b00111000000000000000000000000000) >> 27
tmp[1 * _step : 2 * _step] = (W_q & 0b00000111000000000000000000000000) >> 24
tmp[2 * _step : 3 * _step] = (W_q & 0b00000000111000000000000000000000) >> 21
tmp[3 * _step : 4 * _step] = (W_q & 0b00000000000111000000000000000000) >> 18
tmp[4 * _step : 5 * _step] = (W_q & 0b00000000000000111000000000000000) >> 15
tmp[5 * _step : 6 * _step] = (W_q & 0b00000000000000000111000000000000) >> 12
tmp[6 * _step : 7 * _step] = (W_q & 0b00000000000000000000111000000000) >> 9
tmp[7 * _step : 8 * _step] = (W_q & 0b00000000000000000000000111000000) >> 6
tmp[8 * _step : 9 * _step] = (W_q & 0b00000000000000000000000000111000) >> 3
tmp[9 * _step :] = W_q & 0b00000000000000000000000000000111
return tmp


class Packer:
# TODO: Refine the packer
bit_to_packing = {8: "8bit_u8", 4: "4bit_u8", 3: "3bit_32", 2: "2bit_u8"}

pack_fn_mapping = {
"8bit_u8": BitPack.pack_8bit_u8,
"4bit_u8": BitPack.pack_4bit_u8,
"3bit_32": BitPack.pack_3bit_32,
"2bit_u8": BitPack.pack_2bit_u8,
}

unpack_fn_mapping = {
"8bit_u8": BitPack.unpack_8bit_u8,
"4bit_u8": BitPack.unpack_4bit_u8,
"3bit_32": BitPack.unpack_3bit_32,
"2bit_u8": BitPack.unpack_2bit_u8,
}

@staticmethod
def get_pack_fn(nbits: int):
return Packer.pack_fn_mapping[Packer.bit_to_packing[nbits]]

@staticmethod
def get_unpack_fn(nbits: int):
return Packer.unpack_fn_mapping[Packer.bit_to_packing[nbits]]
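
A hedged round-trip sketch of the packing helpers above (shapes are illustrative; the 4-bit path expects an even row count, and the 3-bit path zero-pads the row count up to a multiple of 10, so the unpacked result is sliced back):

import torch
from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer

# 4-bit: two values per uint8 byte, packed along dim 0.
W4 = torch.randint(0, 16, (8, 3), dtype=torch.uint8)
packed4 = Packer.get_pack_fn(4)(W4)                       # shape (4, 3), uint8
assert torch.equal(Packer.get_unpack_fn(4)(packed4), W4)

# 3-bit: ten values per int32; 7 rows are zero-padded to 10 before packing.
W3 = torch.randint(0, 8, (7, 2), dtype=torch.int32)
packed3 = Packer.get_pack_fn(3)(W3)                       # shape (1, 2), int32
assert torch.equal(Packer.get_unpack_fn(3)(packed3)[:7], W3.to(torch.uint8))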