quant_block_names enhancement #369

Merged (2 commits, Dec 5, 2024)
14 changes: 12 additions & 2 deletions auto_round/auto_quantizer.py
@@ -42,7 +42,8 @@
from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod

from auto_round.utils import get_module, set_module, is_hpu_supported
from auto_round.utils import (get_module, set_module, is_hpu_supported, get_block_names,
get_multimodal_block_names, find_matching_blocks)

from auto_round.backend import get_layer_backend, dynamic_import_inference_linear

@@ -409,7 +410,15 @@ def convert_model(self, model: nn.Module):
sym = quantization_config.sym
to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
"to_quant_block_names") else None
layer_names = get_layer_names_in_block(model, to_quant_block_names=to_quant_block_names)
quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config,
"quant_block_list") else None
if to_quant_block_names is None: # TODO check compatibility
all_blocks = get_block_names(model)
else:
all_blocks = get_multimodal_block_names(model, quant_vision=True)
if quant_block_list is None:
quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
layer_names = get_layer_names_in_block(model, quant_block_list=quant_block_list)

extra_config = {}
if hasattr(quantization_config, "extra_config"):
@@ -734,3 +743,4 @@ def is_serializable(self):

transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer

19 changes: 11 additions & 8 deletions auto_round/autoround.py
@@ -198,11 +198,12 @@ def __init__(
self.device = detect_device(device)
self.scale_dtype = convert_dtype_str2torch(scale_dtype)
self.set_amp_dtype()

self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
if not hasattr(self, 'to_quant_block_names'):
self.to_quant_block_names = to_quant_block_names
if not hasattr(self, 'quant_block_list'):
all_blocks = get_block_names(model)
self.to_quant_block_names = find_matching_blocks(model, all_blocks, to_quant_block_names)
self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names)
self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device


##activation
self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size
@@ -281,8 +282,8 @@ def quantize(self):
The quantized model and layer configurations.
"""

if bool(self.to_quant_block_names):
all_blocks = self.to_quant_block_names
if bool(self.quant_block_list):
all_blocks = self.quant_block_list
else:
all_blocks = get_block_names(self.model)

@@ -434,7 +435,7 @@ def set_layerwise_config(self, layer_config):
Returns:
None
"""
layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.to_quant_block_names)
layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list)
keys = ["data_type", "bits", "group_size", "sym", "scale_dtype", "act_bits", "act_group_size", "act_sym",
"act_dynamic", "act_data_type"]
for n, m in self.model.named_modules():
@@ -1333,6 +1334,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
serialization_dict=serialization_dict,
backend=backend,
to_quant_block_names=self.to_quant_block_names,
quant_block_list=self.quant_block_list,
**kwargs
)
return compressed_model
@@ -1347,7 +1349,7 @@ def get_quantized_layer_names_outside_blocks(self):
return []

layer_names = []
all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.to_quant_block_names)
all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list)

for key in self.layer_config.keys():
if key in all_layers_in_block:
@@ -1735,3 +1737,4 @@ def __init__(
optimizer=optimizer,
**kwargs,
)

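Below is a minimal usage sketch (not part of the diff) of how the renamed bookkeeping behaves after this change, assuming to_quant_block_names remains the public constructor argument as the __init__ hunk above suggests; the model name and block pattern are placeholders. The supplied pattern is resolved to self.quant_block_list via find_matching_blocks(), while self.to_quant_block_names keeps the user-facing value for serialization.

from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model, used only for illustration
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    bits=4,
    group_size=128,
    nsamples=1,  # tiny run, illustration only
    iters=1,
    to_quant_block_names="model.decoder.layers",  # assumed pattern; matched against the model's block names
)
autoround.quantize()
# After __init__, autoround.quant_block_list holds the matched block-name lists and
# autoround.to_quant_block_names still holds the pattern string passed in above.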
6 changes: 4 additions & 2 deletions auto_round/export/export_to_autogptq/export.py
@@ -121,6 +121,7 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
supported_types = kwargs["supported_types"]
safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
to_quant_block_names = kwargs["to_quant_block_names"]
quant_block_list = kwargs.get("quant_block_list", None)
logger.info("Saving quantized model to autogptq format, this may take a while...")
tokenizer = kwargs.get("tokenizer", None)
processor = kwargs.get("processor", None)
@@ -130,8 +131,8 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
processor.save_pretrained(output_dir)
##check module quantized in block, this may have bug for mixed precision quantization
quantization_config = kwargs["serialization_dict"]
if bool(to_quant_block_names):
all_blocks = to_quant_block_names
if bool(quant_block_list):
all_blocks = quant_block_list
flattened_list = [item for sublist in all_blocks for item in sublist]
common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
if common_prefix not in BLOCK_PATTERNS:
@@ -220,3 +221,4 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf
with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
json.dump(model.config.quantization_config, f, indent=2)


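A small standalone sketch (not part of the diff) of the common-prefix check used in save_quantized_as_autogptq above; the block names are illustrative. os.path.commonprefix works on raw strings, so the flattened block names are reduced to their shared leading characters and the trailing '.' is stripped before the exporter compares the result against auto-gptq's BLOCK_PATTERNS.

import os

# Illustrative quant_block_list: one sublist of decoder-block module paths.
quant_block_list = [["model.layers.0", "model.layers.1", "model.layers.2"]]
flattened_list = [item for sublist in quant_block_list for item in sublist]
common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
print(common_prefix)  # -> "model.layers"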
4 changes: 3 additions & 1 deletion auto_round/export/export_to_autoround/export.py
@@ -189,10 +189,11 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex

model = kwargs["model"]
to_quant_block_names = kwargs["to_quant_block_names"]
quant_block_list = kwargs.get("quant_block_list", None)
safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
if not inplace:
model = copy.deepcopy(model.to("cpu"))
layer_names_in_block = get_layer_names_in_block(model, to_quant_block_names=to_quant_block_names)
layer_names_in_block = get_layer_names_in_block(model, quant_block_list=quant_block_list)

layer_config = kwargs["layer_config"]
quantization_config = kwargs["serialization_dict"]
@@ -279,3 +280,4 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
json.dump(model.config.quantization_config, f, indent=2)



9 changes: 7 additions & 2 deletions auto_round/mllm/autoround_mllm.py
@@ -22,7 +22,8 @@
to_device,
to_dtype,
get_multimodal_block_names,
find_matching_blocks
find_matching_blocks,
extract_block_names_to_str
)
from ..autoround import AutoRound
from .template import get_template, Template
@@ -143,7 +144,10 @@ def __init__(
**kwargs,
):
all_blocks = get_multimodal_block_names(model, quant_nontext_module)
self.to_quant_block_names = find_matching_blocks(model, all_blocks, to_quant_block_names)
self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
if to_quant_block_names is None:
to_quant_block_names = extract_block_names_to_str(self.quant_block_list)
self.to_quant_block_names = to_quant_block_names
self.extra_data_dir = extra_data_dir
self.quant_nontext_module = quant_nontext_module
self.image_processor = image_processor
@@ -368,3 +372,4 @@ def calib(self, nsamples, bs):
m = m.to("meta")
# torch.cuda.empty_cache()


29 changes: 23 additions & 6 deletions auto_round/utils.py
@@ -318,6 +318,22 @@ def validate_modules(module_names, quant_vision=False, vison_blocks_names=None):
"or raise an issue at https://github.com/intel/auto-round/issues.")
return

def get_common_prefix(paths):
# Split each path into components and find the common prefix
split_paths = [path.split('.') for path in paths]
common_prefix = split_paths[0]
for path in split_paths[1:]:
common_prefix = [comp for comp, other in zip(common_prefix, path) if comp == other]
return '.'.join(common_prefix)

def extract_block_names_to_str(quant_block_list):
if not isinstance(quant_block_list, (list,tuple)):
return None
# Extract common prefix for each list
prefixes = [get_common_prefix(blocks) for blocks in quant_block_list]
# Join prefixes into a single string
return ','.join(prefixes)


def find_matching_blocks(model, all_blocks, to_quant_block_names):
"""
@@ -347,9 +363,9 @@ def find_matching_blocks(model, all_blocks, to_quant_block_names):
matched_sublist.extend(matches)
if matched_sublist:
target_blocks.append(matched_sublist)
if not target_blocks:
raise ValueError("No block names matched. Please check the input for to_quant_block_name," \
"or set to_quant_block_name to None to automatically match quantizable blocks.")
if not target_blocks:
raise ValueError("No block names matched. Please check the input for to_quant_block_name," \
"or set to_quant_block_name to None to automatically match quantizable blocks.")
return target_blocks


@@ -776,7 +792,7 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs):


def get_layer_names_in_block(model, supported_types=[torch.nn.Linear,
transformers.modeling_utils.Conv1D], to_quant_block_names=None):
transformers.modeling_utils.Conv1D], quant_block_list=None):
"""Retrieves the names of layers within each block of the model.

Returns:
@@ -787,8 +803,8 @@ def get_layer_names_in_block(model, supported_types=[torch.nn.Linear,
if isinstance(m, tuple(supported_types)):
m.tmp_name = n
layers_in_block = []
if bool(to_quant_block_names):
all_blocks = to_quant_block_names
if bool(quant_block_list):
all_blocks = quant_block_list
else:
all_blocks = get_block_names(model)
for block_names in all_blocks:
@@ -1081,3 +1097,4 @@ def get_fp_layer_names(model, fp_layers):
not_to_quantized_layers.append(name)

return not_to_quantized_layers

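A self-contained sketch (not part of the diff) showing how the two helpers added in this file behave; the function bodies are copied from the hunk above and the sample quant_block_list is hypothetical. get_common_prefix finds the shared dotted prefix of a group of module paths, and extract_block_names_to_str collapses a quant_block_list back into the comma-separated string form used for serialization.

def get_common_prefix(paths):
    # Split each path into components and find the common prefix
    split_paths = [path.split('.') for path in paths]
    common_prefix = split_paths[0]
    for path in split_paths[1:]:
        common_prefix = [comp for comp, other in zip(common_prefix, path) if comp == other]
    return '.'.join(common_prefix)

def extract_block_names_to_str(quant_block_list):
    if not isinstance(quant_block_list, (list, tuple)):
        return None
    # Extract common prefix for each list
    prefixes = [get_common_prefix(blocks) for blocks in quant_block_list]
    # Join prefixes into a single string
    return ','.join(prefixes)

# Hypothetical block lists: text blocks plus vision blocks of a multimodal model.
quant_block_list = [
    ["model.layers.0", "model.layers.1", "model.layers.2"],
    ["model.vision_model.blocks.0", "model.vision_model.blocks.1"],
]
print(extract_block_names_to_str(quant_block_list))  # -> "model.layers,model.vision_model.blocks"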
18 changes: 18 additions & 0 deletions test/test_cuda_before_release.py
@@ -119,3 +119,21 @@ def test_undivided_group_size_tuning(self):

autoround = AutoRound(model, tokenizer, bits=4, group_size=128, nsamples=1, iters=1)
autoround.quantize()

@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
def test_vision_generation(self):
quantized_model_path = "OPEA/Phi-3.5-vision-instruct-qvision-int4-sym-inc"
from auto_round import AutoRoundConfig
device = "auto" ##cpu, hpu, cuda
quantization_config = AutoRoundConfig(
backend=device
)
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True,
device_map=device, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])
print(res)
assert (
res == """<s> There is a girl who likes adventure, and she is looking for a partner to go on a treasure hunt. She has found a map that leads to a hidden treasure, but she needs a partner to help her decipher the clues and find the treasure. You""")
12 changes: 4 additions & 8 deletions test/test_generation.py
@@ -33,9 +33,8 @@ def tearDownClass(self):
shutil.rmtree("./saved", ignore_errors=True)
shutil.rmtree("runs", ignore_errors=True)

@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
def test_llm_generation_sym_gpu_gptq(self):
if not torch.cuda.is_available():
return
bits = 4
group_size = 32
autoround = AutoRound(
@@ -67,11 +66,9 @@ def test_llm_generation_sym_gpu_gptq(self):
assert (
res == """</s>There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""")

# #
#
# @unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
# def test_llm_generation_sym_gpu_gptq_marlin(self): ##need auto_gptq >0.7.1
# if not torch.cuda.is_available():
# return
# bits = 4
# group_size = 128
# autoround = AutoRound(
@@ -103,10 +100,8 @@ def test_llm_generation_sym_gpu_gptq(self):
# assert (
# res == """</s>There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""")


@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
def test_llm_generation_asym_gpu_awq(self):
if not torch.cuda.is_available():
return
bits = 4
group_size = 32
autoround = AutoRound(
Expand Down Expand Up @@ -175,3 +170,4 @@ def test_llm_generation_asym_qbits(self):



