Fix asym kernel issue by following AutoGPTQ's PR #137

Merged: 53 commits from fix_asym into main on Jun 3, 2024
Changes from 46 commits

Commits (53)
7d020db
fix a bug in example
wenhuach21 Mar 18, 2024
fbe69d5
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 19, 2024
596a18f
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 20, 2024
10add8c
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 21, 2024
003b60a
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 24, 2024
9d49514
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 25, 2024
3b7f386
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 27, 2024
d3f14df
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Mar 27, 2024
76e4d90
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 3, 2024
08e46ac
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 9, 2024
15b756b
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 9, 2024
40425bb
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 11, 2024
f0b9ad0
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 17, 2024
04e70ec
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 20, 2024
43811bb
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 24, 2024
c881719
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 25, 2024
54920e5
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 Apr 30, 2024
e2c2f56
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 1, 2024
4f718d4
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 1, 2024
022988a
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 9, 2024
b4eb679
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 10, 2024
0abbde8
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 16, 2024
aed8be0
fix gradient_accmulate bug in lm-head
wenhuach21 May 20, 2024
018eeb8
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 23, 2024
0fd4a92
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 24, 2024
47f5efe
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 24, 2024
f4814f2
correct the doc
wenhuach21 May 24, 2024
d92be56
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 25, 2024
3847bff
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 27, 2024
f102931
remove fp32 conversion as no need now
wenhuach21 May 27, 2024
77dca2c
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 27, 2024
d201814
update phi2 recipe and remove falcon data as we don't trust the qdq a…
wenhuach21 May 27, 2024
d25cfeb
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 28, 2024
bda3da9
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 28, 2024
a4de240
Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 May 29, 2024
0f48813
fix asym issue by following autogptq's pr
wenhuach21 May 29, 2024
839674a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
76b1254
add gptq license
wenhuach21 May 29, 2024
3260dfe
tmp commit
wenhuach21 May 30, 2024
1965a0e
support exllamav2
wenhuach21 May 30, 2024
1745d09
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
6bb563d
updated
wenhuach21 May 31, 2024
d0fde97
Merge branch 'fix_asym' of https://github.com/intel/auto-round into f…
wenhuach21 May 31, 2024
c445fbe
revert the change
wenhuach21 May 31, 2024
d1305d4
Merge branch 'main' into fix_asym
wenhuach21 May 31, 2024
787fa01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2024
5c94cea
revert the hook change
wenhuach21 May 31, 2024
5025bd4
fix bugs
wenhuach21 May 31, 2024
debc852
fix a bug
wenhuach21 May 31, 2024
47e501b
tiny change
wenhuach21 May 31, 2024
9045951
fix issues
wenhuach21 May 31, 2024
ee65eda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2024
e1a7233
reorg the code of cuda kernel
wenhuach21 Jun 3, 2024
2 changes: 2 additions & 0 deletions README.md
@@ -42,6 +42,8 @@ image presents an overview of AutoRound.
```bash
pip install -r requirements.txt
python setup.py install
# or
pip install -vvv --no-build-isolation -e .
```
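A quick post-install sanity check (a minimal sketch; it only relies on `auto_round.__version__`, the same attribute the quantizer code in this PR reads):

```python
# Confirm the editable/source install resolves and report its version.
import auto_round

print(auto_round.__version__)
```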

### Install from pypi
112 changes: 54 additions & 58 deletions auto_round/auto_quantizer.py
@@ -52,11 +52,17 @@
else:
import importlib.metadata as importlib_metadata

AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
AUTOROUND_MINIMUM_VERSION = version.parse("0.2")


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
try:  # TODO: remove it later
import auto_round
return True, auto_round.__version__
except:
pass

package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
@@ -71,26 +77,32 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
return package_exists


_auto_gptq_available = _is_package_available("auto_gptq")
_auto_round_available = _is_package_available("auto_round")


def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
def is_auto_round_available():
if _auto_round_available:
version_autoround = version.parse(importlib_metadata.version("auto_round"))
if AUTOROUND_MINIMUM_VERSION < version_autoround:
return True
else:
raise ImportError(
f"Found an incompatible version of auto-gptq. Found version {version_autogptq},"
f" but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported"
f"Found an incompatible version of auto-round. Found version {version_autoround},"
f" but only version above {AUTOROUND_MINIMUM_VERSION} are supported"
)


if is_auto_gptq_available():
from auto_gptq import exllama_set_max_input_length
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
def is_autoround_exllamav2_available():
res = True
try:
from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix
except ImportError as e:
res = False
return res


if is_auto_round_available():
from auto_round.export.export_to_autoround.post_init import autoround_post_init


#
@@ -201,7 +213,7 @@ def __init__(
dataset: str = None,
group_size: int = 128,
sym: bool = False,
backend="gptq:exllamav2",
backend="autoround:exllamav2",
iters: int = 200,
weight_config: dict = None,
enable_quanted_input=True,
@@ -233,16 +245,12 @@ def __init__(
self.post_init()

def get_loading_attributes(self):
pass
# attibutes_dict = copy.deepcopy(self.__dict__)
# loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
# loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
# return loading_attibutes_dict
return {}

def post_init(self):
r"""Safety checker that arguments are correct."""
if self.bits not in [2, 3, 4, 8]:
raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
if self.bits not in [2, 4, 8]:
raise ValueError(f"Only support quantization to [2,4,8] bits but found {self.bits}")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("group_size must be greater than 0 or equal to -1")
##TODO add more check
@@ -254,23 +262,21 @@ def to_dict(self):


class AutoRoundQuantizer(HfQuantizer):
"""Quantizer of the Autoround method, currently only gptq backend has been supported."""
"""Quantizer of the AutoRound method, currently only triton and exllamav2 backend has been supported."""

requires_calibration = False
required_packages = ["auto_gptq"]
required_packages = ["auto_round"]
optimum_quantizer = None

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.exllama2_available = is_autoround_exllamav2_available()

def validate_environment(self, *args, **kwargs):
gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError("GPU is required to quantize or run quantize model.")
elif not is_auto_gptq_available():
raise ImportError("Loading a GPTQ quantized model requires auto-gptq library (`pip install auto-gptq`)")
elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
raise ImportError("You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`")
if not is_auto_round_available():
raise ImportError("Loading a AutoRound quantized model requires auto-round library (`pip install auto-round`)")
elif version.parse(importlib.metadata.version("auto_round")) < version.parse("0.2.0"):
raise ImportError("You need a version of auto_round > 0.2.0 to use AutoRound: `pip install --upgrade auto-round`")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
@@ -280,7 +286,7 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
return torch_dtype

def convert_model(self, model: nn.Module):
"""Convert the model to a GPTQ model by getting and replacing the layers.
"""Convert the model to an AutoRound model by getting and replacing the layers.

Args:
model (`nn.Module`):
@@ -308,15 +314,22 @@ def convert_model(self, model: nn.Module):
layer_configs[layer_name]["data_type"] = data_type
layer_configs[layer_name]["sym"] = sym
else:
layer_configs[layer_name]["bits"] = extra_config.get("bits", bits)
layer_configs[layer_name]["group_size"] = extra_config.get("group_size", group_size)
layer_configs[layer_name]["data_type"] = extra_config.get("data_type", data_type)
layer_configs[layer_name]["sym"] = extra_config.get("sym", sym)
layer_configs[layer_name]["bits"] = extra_config[layer_name].get("bits", bits)
layer_configs[layer_name]["group_size"] = extra_config[layer_name].get("group_size", group_size)
layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", data_type)
layer_configs[layer_name]["sym"] = extra_config[layer_name].get("sym", sym)
backend = quantization_config.backend

self._replace_by_quant_layers(model, layer_configs, backend)
return model
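The `extra_config[layer_name].get(...)` fix in `convert_model` above matters because `extra_config` is keyed by layer name. A sketch of the expected shape (layer names and values below are hypothetical, not taken from this PR):

```python
# Hypothetical extra_config as consumed by convert_model: a dict keyed by layer name,
# each value holding per-layer overrides for bits / group_size / data_type / sym.
extra_config = {
    "lm_head": {"bits": 4, "group_size": 32, "data_type": "int", "sym": False},
    "model.layers.0.self_attn.q_proj": {"bits": 8},  # unspecified keys fall back to the global values
}
```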

def _dynamic_import_inference_linear(self, bits):
if bits == 4 and self.exllama2_available:
from auto_round.export.export_to_autoround.qliner_exllamav2 import QuantLinear
else:
from auto_round.export.export_to_autoround.qliner_triton import QuantLinear
return QuantLinear
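Taken together, `is_autoround_exllamav2_available` and `_dynamic_import_inference_linear` implement a simple capability fallback: use the ExLlamaV2 CUDA kernel for 4-bit inference when its extension module imports, otherwise fall back to the Triton QuantLinear. A standalone sketch of that selection logic (module paths copied from the diff; the helper name `pick_quant_linear` is illustrative, not part of this PR):

```python
def pick_quant_linear(bits: int):
    """Sketch of the kernel-selection logic added in this PR (illustrative helper)."""
    try:
        # CUDA extension shipped with the exllamav2 kernels; absent on CPU-only builds
        from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix  # noqa: F401
        exllamav2_available = True
    except ImportError:
        exllamav2_available = False

    if bits == 4 and exllamav2_available:
        from auto_round.export.export_to_autoround.qliner_exllamav2 import QuantLinear
    else:
        from auto_round.export.export_to_autoround.qliner_triton import QuantLinear
    return QuantLinear
```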

def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
"""Replaces linear layers in `module` by `QuantLinear`

@@ -335,21 +348,7 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
data_type = config["data_type"]
if not (bits <= 8 and data_type == "int"):
continue
from auto_round.export.export_to_autoround.export_to_autoround import get_autogptq_backend_config

use_triton, disable_exllama, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
backend, bits
)
QuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=False,
group_size=group_size,
bits=bits,
disable_exllama=True,
disable_exllamav2=False,
use_qigen=use_qigen,
disable_marlin=disable_marlin,
)
QuantLinear = self._dynamic_import_inference_linear(bits)
layer = get_module(module, layer_name)
device = get_device(layer)
if isinstance(layer, nn.Linear):
@@ -381,24 +380,21 @@ def post_init_model(self, model):
model (`nn.Module`):
The input model
"""

# if self.bits == 4 and not self.disable_exllama:
#
# if self.bits == 4:
# if get_device(model) == torch.device("cpu") or (
# hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
# ):
# raise ValueError(
# "Found modules on cpu/disk. Using Exllama
# or Exllamav2 backend requires all the modules to be on GPU."
# "You can deactivate exllama backend by
# setting `disable_exllama=True` in the quantization config object"
# "Found modules on cpu/disk. Using Exllamav2 backend requires all the modules to be on GPU."
# "You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object"
# )

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = False
model = autogptq_post_init(model, use_act_order=False)
model = autoround_post_init(model)
return model

def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
2 changes: 1 addition & 1 deletion auto_round/export/__init__.py
@@ -15,6 +15,6 @@
from .register import EXPORT_FORMAT
from .export_to_autogptq import save_quantized_as_autogptq
from .export_to_itrex import save_quantized_as_itrex, QuantConfig
from .export_to_autoround.export_to_autoround import save_quantized_as_autoround
from .export_to_autoround.export import save_quantized_as_autoround


2 changes: 1 addition & 1 deletion auto_round/export/export_to_autoround/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .export_to_autoround import save_quantized_as_autoround
from .export import save_quantized_as_autoround

auto_round/export/export_to_autoround/export.py (renamed from export_to_autoround.py)
@@ -25,7 +25,6 @@
from auto_round.utils import get_layer_names_in_block, get_block_names, get_module, logger, set_module



def check_neq_config(config, data_type, bits, group_size, sym):
res = []
if data_type != config["data_type"]:
@@ -53,7 +52,7 @@ def get_autogptq_backend_config(backend, bits=4):
if backend == "gptq:marlin":
use_triton = False
disable_marlin = True
if backend == "gptq:exllamav2":
if backend == "gptq:exllamav2": ##need v1 code to export
use_triton = False
disable_marlin = True
if backend == "gptq:exllamav1":
@@ -71,10 +70,33 @@
return use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin


@register_format("autoround")
def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav2", **kwargs):
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
def dynamic_QuantLienar_for_packing(backend, bits, group_size):
if "gptq" in backend:
use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
backend, bits
)
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
desc_act=False,
group_size=group_size,
bits=bits,
disable_exllama=disable_exllamav1,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen,
disable_marlin=disable_marlin,
)
return QuantLinear
elif "autoround" in backend or "auto-round" in backend or "auto_round" in backend: ##export all use trition,inferce use exllamav2
from .qliner_triton import QuantLinear
return QuantLinear

else:
assert False, "only gptq and autoround backends are supported"
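A hypothetical call illustrating how the packing helper above dispatches on the backend string (the import path is an assumption based on the file rename in this PR; as the comment above notes, auto-round backends always pack with the Triton layout, while the ExLlamaV2 kernel is only selected at inference time by `AutoRoundQuantizer`):

```python
# Illustrative only; module path assumed from the rename to export.py in this PR.
from auto_round.export.export_to_autoround.export import dynamic_QuantLienar_for_packing

# "autoround:*" backends pack with the Triton QuantLinear ...
QuantLinear_pack = dynamic_QuantLienar_for_packing("autoround:exllamav2", bits=4, group_size=128)

# ... while "gptq:*" backends defer to auto-gptq's dynamically imported kernels
# (requires auto-gptq to be installed).
QuantLinear_gptq = dynamic_QuantLienar_for_packing("gptq:exllamav2", bits=4, group_size=128)
```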


@register_format("autoround")
def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exllamav2", **kwargs):
model = kwargs["model"]
if not inplace:
model = copy.deepcopy(model.to("cpu"))
@@ -90,22 +112,11 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav

bits = config["bits"]
group_size = config["group_size"]
use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
backend, bits
)

layer = get_module(model, name)
device = "cpu"
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
desc_act=False,
group_size=group_size,
bits=bits,
disable_exllama=disable_exllamav1,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen,
disable_marlin=disable_marlin,
)
device = layer.weight.device

QuantLinear = dynamic_QuantLienar_for_packing(backend, bits, group_size)

if isinstance(layer, nn.Linear):
in_features = layer.in_features
@@ -138,7 +149,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav
quantization_config["backend"] = backend
extra_config = {}
for layer_name in weight_config:
if weight_config[layer_name]["data_type"] != "int" and weight_config[layer_name]["bits"] >= 16:
if weight_config[layer_name]["bits"] >= 16:
continue
if layer_name not in layer_names_in_block:
extra_config[layer_name] = {}
@@ -190,7 +201,7 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_ser
"""
os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
config_file = "quantize_config.json"
if hasattr(model, "config") and hasattr(model.config, "quantize_config"):
config_file = "quantization_config.json"
if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
json.dump(model.config.quantization_config, f, indent=2)
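For reference, a small sketch of inspecting the artefact written by `save()` above (the directory name is hypothetical; the config file name and the `backend` key come from this diff, the remaining keys depend on the serialized quantization config):

```python
import json
import os

save_dir = "./tmp_autoround"  # hypothetical path where save() wrote the quantized model
with open(os.path.join(save_dir, "quantization_config.json"), encoding="utf-8") as f:
    cfg = json.load(f)

print(cfg.get("backend"))  # e.g. "autoround:exllamav2", set during export above
print(sorted(cfg.keys()))  # other keys come from the serialized quantization config
```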