refine forward hook #290

Merged
merged 20 commits on Nov 1, 2024
Changes from 15 commits
204 changes: 102 additions & 102 deletions auto_round/autoround.py

Large diffs are not rendered by default.

116 changes: 80 additions & 36 deletions auto_round/special_model_handler.py
@@ -13,37 +13,48 @@
# limitations under the License.

import torch
from collections import UserDict
special_states_dim_tuple = ("chatglm",) # input_dim is not the default dimension 0
shareable_keywords = ("position_ids", "cache_position", "position_embeddings")
mllms_with_limited_bs = ("llava", "qwen2-vl", "phi3_v", "mllama") # Limitations on batch_size

share_attention_mask_tuple = ("baichuan",)
special_states_dim_tuple = ("chatglm",)
not_share_position_ids_tuple = ("llava", "phi3_v", "qwen2_vl",)
not_share_rotary_pos_emb_tuple = ("qwen2_vl",)
def check_share_attention_mask(model, hidden_states, attention_mask=None, **kwargs):
"""Checks if the attention mask states of the hidden states are shared in the model.

def to_device(input, device=torch.device("cpu")):
"""Moves input data to the specified device.

Args:
hidden_states (torch.Tensor): The hidden states of the model.
attention_mask (torch.Tensor, optional): The attention mask tensor. Defaults to None.
**kwargs: Additional keyword arguments.
input: The input data to be moved.
device: The target device.

Returns:
bool: True if attention mask is shared in the model, False otherwise.
The input data on the specified device.
"""
if attention_mask is None or not isinstance(hidden_states, torch.Tensor):
return False
is_special = False
for key in share_attention_mask_tuple:
if hasattr(model, "config") and key in model.config.model_type:
is_special = True
break
return bool(is_special and attention_mask.shape[0] != hidden_states.shape[0])
if input is None:
return None
if isinstance(input, torch.Tensor):
return input.to(device)
if isinstance(input, dict) or isinstance(input, UserDict):
for inp in input.keys():
input[inp] = to_device(input[inp], device)

elif isinstance(input, list) or isinstance(input, tuple):
if len(input) == 0:
return input
input_res = []
for inp in input:
input_res.append(to_device(inp, device))
if isinstance(input, tuple):
input_res = tuple(input_res)
input = input_res

return input


def check_hidden_state_dim(model, positional_args):
"""Checks the dimensionality of the hidden states.
def check_hidden_state_dim(model, positional_inputs):
"""Check the concatenable dimension of hidden states.

Args:
positional_args: The positional arguments.
positional_inputs: The positional arguments.

Returns:
int: 1 if the model type is 'chatglm' and positional arguments are not None, 0 otherwise.
@@ -53,23 +64,56 @@ def check_hidden_state_dim(model, positional_args):
if hasattr(model, "config") and key in model.config.model_type:
is_special = True
break
return int(is_special and positional_args is not None)
return int(is_special and positional_inputs is not None)


def check_not_share_position_ids(model, **kwargs):
is_special = False
for key in not_share_position_ids_tuple:
if hasattr(model, "config") and key in model.config.model_type:
is_special = True
break
return bool(is_special)
def special_model_init(model, positional_inputs, inputs):
"""
Initializes special model inputs by adding positional inputs if missing.

Args:
model: The model instance being initialized.
positional_inputs (list): List of positional inputs to add to inputs.
inputs (dict): Dictionary of model inputs.

Modifies:
inputs (dict): Adds "positional_inputs" key if not present.
"""
if "positional_inputs" not in inputs: # for chatglm Series
inputs["positional_inputs"] = []
for idx, item in enumerate(positional_inputs):
inputs["positional_inputs"] = to_device(positional_inputs)

def check_not_share_rotary_pos_emb(model, **kwargs):
is_special = False
for key in not_share_rotary_pos_emb_tuple:
if hasattr(model, "config") and key in model.config.model_type:
is_special = True
break
return bool(is_special)

def reset_params(inputs):
"""
Resets specific input parameters to avoid saving the key-value cache during fine-tuning.

Args:
inputs (dict): Dictionary of model inputs.

Modifies:
inputs (dict): Sets "use_cache" to False if the key is present.
"""
if "use_cache" in inputs.keys(): # Not storing kv cache
inputs['use_cache'] = False


def skip_keywards_hint(key):
"""
Returns a reminder message when a key is not used during quantization fine-tuning.
"""
if 'past_key_value' not in key:
return (f"Please note that this '{key}' key is not currently used in quantization fine-tuning.")


def check_model_batch(model, batch_size, gradient_accumulate_steps):
"""
Checks the model configuration to determine whether the batch size must be limited to avoid potential input shape mismatches.
"""
for key in mllms_with_limited_bs:
if hasattr(model, "config") and key in model.config.model_type and batch_size != 1:
accumulate_steps = batch_size * gradient_accumulate_steps
raise RuntimeError("To avoid the tensor concat mismatch problem, please modify parameters to " \
f"batch_size=1. As an alternative, you can set the gradient_accumulate_steps={accumulate_steps}")

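For reviewers who want to try the new helpers in this file locally, here is a minimal sketch, assuming to_device, special_model_init, and reset_params are importable from auto_round.special_model_handler as added in this PR. The cached-inputs dict and the dummy positional inputs below are illustrative only, and None is passed for the model argument purely because the body shown above does not consult it.

import torch
from auto_round.special_model_handler import to_device, special_model_init, reset_params

# Illustrative cache entry, roughly what a block forward hook might collect.
inputs = {
    "hidden_states": torch.randn(2, 8, 16),
    "attention_mask": torch.ones(2, 1, 8, 8),
    "use_cache": True,
    "nested": [torch.zeros(3), (torch.ones(2), None)],  # nested containers are walked recursively
}

# Move every tensor inside dicts/lists/tuples to the target device (CPU here).
inputs = to_device(inputs, torch.device("cpu"))

# Attach positional args (e.g. for the chatglm series) and drop kv-cache storage.
special_model_init(None, (torch.arange(8),), inputs)
reset_params(inputs)

print(inputs["use_cache"])               # False
print(len(inputs["positional_inputs"]))  # 1
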
25 changes: 13 additions & 12 deletions auto_round/utils.py
@@ -29,6 +29,7 @@
from functools import lru_cache
from packaging import version
import gc
from .special_model_handler import shareable_keywords

@lru_cache(None)
def warning_once(self, msg: str):
@@ -347,9 +348,7 @@ def collect_best_params(block):

@torch.no_grad()
def sampling_inputs(input_ids, input_others, indices, seqlen,
share_attention_mask_flag=False,
not_share_position_ids_flag=False,
not_share_rotary_pos_emb_flag=False, input_dim=0):
input_dim=0):
"""Samples inputs based on the given indices and sequence length.

Args:
@@ -366,19 +365,20 @@ def sampling_inputs(input_ids, input_others, indices, seqlen,
current_input_ids = torch.cat(current_input_ids, dim=input_dim)
current_input_others = {"positional_inputs": input_others["positional_inputs"]}
for key in input_others.keys():
if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key) \
or (not_share_position_ids_flag and ("position_ids" in key or \
"cache_position" in key or "position_embeddings" in key)) \
or (not_share_rotary_pos_emb_flag and ("rotary_pos_emb" in key or 'cu_seqlens' in key)) \
or "cross_attention_states" in key:
if "positional_inputs" in key:
continue
if (key not in shareable_keywords or len(indices) == 1) \
and not isinstance(input_others[key], (str, bool, type(None))):
current_input_others[key] = None
if input_others[key] is not None:
current_input_others[key] = [input_others[key][i] for i in indices]
if not isinstance(current_input_others[key], torch.Tensor):
if len(current_input_others[key]) == 1:
current_input_others[key] = current_input_others[key][0]
else:
if len(indices) == 1:
current_input_others[key] = current_input_others[key][0]
else:
try:
current_input_others[key] = torch.cat(current_input_others[key], dim=0)
except TypeError as err:
logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.")
else:
current_input_others[key] = input_others[key]

@@ -864,3 +864,4 @@ def clear_memory(tensor=None):
del tensor
gc.collect()
torch.cuda.empty_cache()

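The rewrite of sampling_inputs above replaces the per-model share_*/not_share_* flags with a single key-based rule: values whose key appears in shareable_keywords are reused across the sampled batch, while everything else is re-indexed per sample and concatenated. Below is a rough standalone sketch of that rule; it is not the library API, and the cache dict is made up for illustration.

import torch

# Mirrors the keyword list introduced in special_model_handler.py.
shareable_keywords = ("position_ids", "cache_position", "position_embeddings")

def select_for_indices(input_others, indices):
    """Toy re-implementation of the key-based sharing rule used in sampling_inputs."""
    out = {"positional_inputs": input_others.get("positional_inputs", [])}
    for key, value in input_others.items():
        if key == "positional_inputs":
            continue
        if (key in shareable_keywords and len(indices) > 1) or isinstance(value, (str, bool, type(None))):
            out[key] = value  # shared across the sampled batch, or passed through unchanged
        else:
            picked = [value[i] for i in indices]  # per-sample values are re-indexed ...
            out[key] = picked[0] if len(indices) == 1 else torch.cat(picked, dim=0)  # ... and concatenated
    return out

cache = {
    "positional_inputs": [],
    "position_ids": torch.arange(8).unsqueeze(0),            # shared
    "attention_mask": [torch.ones(1, 8) for _ in range(4)],  # per-sample
    "use_cache": False,                                      # passed through
}
picked = select_for_indices(cache, indices=[0, 2])
print(picked["attention_mask"].shape)                    # torch.Size([2, 8])
print(picked["position_ids"] is cache["position_ids"])   # True
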
5 changes: 3 additions & 2 deletions examples/language-modeling/main.py
@@ -40,7 +40,7 @@
parser.add_argument("--group_size", default=128, type=int,
help="group size")

parser.add_argument("--train_bs", default=8, type=int,
parser.add_argument("--batch_size", default=8, type=int,
help="train batch size")

parser.add_argument("--eval_bs", default=None, type=int,
@@ -323,7 +323,7 @@
error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization."
raise EnvironmentError(error_message)

autoround = round(model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.train_bs,
autoround = round(model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.batch_size,
dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr,
minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str,
amp=not args.disable_amp, nsamples=args.nsamples,
@@ -466,3 +466,4 @@
from lm_eval.utils import make_table

print(make_table(res))

3 changes: 3 additions & 0 deletions examples/language-modeling/requirements.txt
@@ -19,3 +19,6 @@ wandb
py-cpuinfo
numpy < 2.0
threadpoolctl
numexpr
bitsandbytes # for baichuan Series

22 changes: 12 additions & 10 deletions examples/multimodal-modeling/Common_model/main.py
@@ -277,7 +277,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
parser.add_argument("--group_size", default=128, type=int,
help="group size")

parser.add_argument("--train_bs", default=1, type=int,
parser.add_argument("--batch_size", default=1, type=int,
help="train batch size")

parser.add_argument("--eval_bs", default=4, type=int,
@@ -288,7 +288,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
"allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.")

parser.add_argument("--sym", action='store_true',
help=" sym quantization")
help="sym quantization")

parser.add_argument("--iters", default=200, type=int,
help=" iters")
@@ -339,6 +339,9 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat

parser.add_argument("--disable_trust_remote_code", action='store_true',
help="Whether to disable trust_remote_code")

parser.add_argument("--not_use_best_mse", action='store_true',
help="To determine whether the quantization should handle vision component.")

parser.add_argument("--disable_quanted_input", action='store_true',
help="whether to disuse the output of quantized block to tune the next block")
@@ -381,8 +384,8 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
if args.act_bits <= 8 and args.deployment_device != "fake":
assert False, "only support fake mode for activation quantization currently"

if "marlin" in args.deployment_device and args.sym == False:
assert False, "marlin backend only supports sym quantization, please set --sym"
if "marlin" in args.deployment_device and args.asym == True:
assert False, "marlin backend only supports sym quantization, please enable --sym"

model_name = args.model_name
if model_name[-1] == "/":
@@ -420,9 +423,10 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
questions = json.load(open(args.question_file, "r"))
config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
model_type = config.model_type
processor = None
if "mllama" in model_type:
from transformers import MllamaForConditionalGeneration
model = MllamaForConditionalGeneration.from_pretrained(args.model_name, attn_implementation="eager",
model = MllamaForConditionalGeneration.from_pretrained(args.model_name,
trust_remote_code=not args.disable_trust_remote_code) # torch_dtype=torch.bfloat16
processor = AutoProcessor.from_pretrained(args.model_name)
tokenizer.processor = processor
@@ -448,7 +452,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
raw_data = DataFormating(questions, args.image_folder, model_type=model_type)
dataset = LazySupervisedDataset(raw_data, tokenizer,
max_len=min(args.seqlen, tokenizer.model_max_length), image_folder=args.image_folder)
dataloader = get_train_dataloader(dataset, model, data_collator=default_collator, train_batch_size=args.train_bs)
dataloader = get_train_dataloader(dataset, model, data_collator=default_collator, train_batch_size=args.batch_size)

from auto_round import (AutoRound,
AutoRoundAdam)
@@ -497,10 +501,10 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat

quant_block_list = get_multimodal_block_names(model, args.quant_vision)

autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs,
autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size,
dataset=dataloader, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr,
minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input,
amp=not args.disable_amp, nsamples=args.nsamples,
amp=not args.disable_amp, nsamples=args.nsamples, not_use_best_mse=args.not_use_best_mse,
low_gpu_mem_usage=args.low_gpu_mem_usage, device=device_str,
seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
scale_dtype=args.scale_dtype, layer_config=layer_config,
@@ -579,5 +583,3 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
)




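Since the example scripts now expose --batch_size instead of --train_bs, the batch limitation for the multimodal models listed in mllms_with_limited_bs is enforced by check_model_batch from special_model_handler.py. A small sketch of the behaviour it enforces, assuming the function is importable as added in this PR; the config stub is made up, since only model.config.model_type is read.

import types
from auto_round.special_model_handler import check_model_batch

# Stand-in for a loaded multimodal model; only model.config.model_type is consulted.
fake_model = types.SimpleNamespace(config=types.SimpleNamespace(model_type="llava"))

try:
    # batch_size > 1 is rejected for the batch-limited multimodal models ...
    check_model_batch(fake_model, batch_size=4, gradient_accumulate_steps=2)
except RuntimeError as err:
    # ... and the message suggests batch_size=1 with gradient_accumulate_steps=8,
    # i.e. the old effective batch of 4 * 2.
    print(err)

# batch_size=1 with a larger gradient_accumulate_steps passes for the same model.
check_model_batch(fake_model, batch_size=1, gradient_accumulate_steps=8)
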
4 changes: 2 additions & 2 deletions requirements.txt
@@ -3,12 +3,12 @@ datasets
py-cpuinfo
sentencepiece
torch
transformers<=4.45.2
transformers
triton
numpy < 2.0
threadpoolctl
lm-eval>=0.4.2,<=0.4.5
tqdm
packaging
auto-gptq>=0.7.1
pillow
pillow