Auto-tokenizer will be called within load() #996

Merged (20 commits) on Jan 3, 2025
README.md: 31 changes (15 additions, 16 deletions)
@@ -137,28 +137,31 @@ git clone https://github.com/ModelCloud/GPTQModel.git && cd GPTQModel
pip install -v . --no-build-isolation
```

### Quantization and Inference
### Inference
Two-line API to use `GPTQModel` for GPTQ model inference:

Below is a basic sample using `GPTQModel` to quantize an LLM and perform post-quantization inference:
```py
from gptqmodel import GPTQModel

model = GPTQModel.load("ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5")
result = model.generate("Uncovering deep insights begins with")[0]
```

### Quantization
Basic example of using `GPTQModel` to quantize an LLM:

```py
from datasets import load_dataset
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"

tokenizer = AutoTokenizer.from_pretrained(model_id)

calibration_dataset = [
tokenizer(example["text"])
for example in load_dataset(
calibration_dataset = load_dataset(
"allenai/c4",
data_files="en/c4-train.00001-of-01024.json.gz",
split="train"
).select(range(1024))
]
).select(range(1024))["text"]

quant_config = QuantizeConfig(bits=4, group_size=128)

@@ -169,13 +172,9 @@ model.quantize(calibration_dataset, batch_size=2)

model.save(quant_path)

# test post-quant inference
model = GPTQModel.load(quant_path)

result = model.generate(
**tokenizer(
"Uncovering deep insights begins with", return_tensors="pt"
).to(model.device)
)[0]
result = model.generate("Uncovering deep insights begins with")[0]
```

For more advanced model quantization features, please refer to [this script](https://github.com/ModelCloud/GPTQModel/blob/main/examples/quantization/basic_usage_wikitext2.py)
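With `load()` now attaching the tokenizer automatically (see the loader.py and base.py diffs below), the README examples can hand raw strings straight to `generate()`. A minimal usage sketch, assuming `quant_path` is the local directory saved by the quantization example above; the second batch prompt is made up for illustration:

```py
from gptqmodel import GPTQModel

quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"  # assumed: path written by model.save() above

model = GPTQModel.load(quant_path)

# a single prompt string is tokenized internally by generate()
token_ids = model.generate("Uncovering deep insights begins with")[0]
print(model.tokenizer.decode(token_ids))

# a list of prompt strings is also accepted; generate() pads them into a single batch
batch_ids = model.generate([
    "Uncovering deep insights begins with",
    "Quantization reduces model memory by",
])
print(model.tokenizer.batch_decode(batch_ids))
```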
gptqmodel/models/base.py: 64 changes (45 additions, 19 deletions)
@@ -27,6 +27,7 @@
move_to,
nested_move_to,
pack_model,
normalize_tokenizer,
MODALITY,
)
from ..utils.progress import ProgressBar
@@ -94,6 +95,7 @@ def __init__(
model: PreTrainedModel,
quantized: bool,
quantize_config: QuantizeConfig,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
qlinear_kernel: nn.Module = None,
load_quantized_model: bool = False,
trust_remote_code: bool = False,
@@ -104,6 +106,7 @@
self.model = model
self.quantized = quantized
self.load_quantized_model = load_quantized_model
self.tokenizer = tokenizer
self.quantize_config = quantize_config
self.config = self.model.config

@@ -120,10 +123,33 @@

def prepare_dataset(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[List[int]]],
batch_size: int = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
):
if isinstance(calibration_dataset[0], (str, list)) or (isinstance(calibration_dataset[0], list) and all(isinstance(x, int) for x in calibration_dataset[0])):
if self.tokenizer is None:
raise ValueError(f"tokenizer must be provided when calibration_dataset is List[str] or List[int], type: {type(calibration_dataset[0])}")

# Convert strings/ints to tokenized format
new_calibration_dataset = []
for data in calibration_dataset:
# convert to tensor directly if already in token ids format (ints)
if isinstance(data, list) and all(isinstance(x, int) for x in data):
input_ids = torch.tensor([data], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
new_calibration_dataset.append({
"input_ids": input_ids,
"attention_mask": attention_mask
})
# call tokenizer if dataset still string format (str)
else:
tokenized = self.tokenizer(data, return_tensors="pt")
new_calibration_dataset.append({
"input_ids": tokenized["input_ids"],
"attention_mask": tokenized["attention_mask"]
})
calibration_dataset = new_calibration_dataset

def _convert_tensor_to_list(tensor):
if isinstance(tensor, torch.Tensor):
if len(tensor.shape) == 1:
@@ -136,26 +162,18 @@ def _convert_tensor_to_list(tensor):
for example in calibration_dataset:
input_ids = _convert_tensor_to_list(example["input_ids"])
attention_mask = _convert_tensor_to_list(example["attention_mask"])
if "labels" in example:
labels = _convert_tensor_to_list(example["labels"])
elif "label" in example:
labels = _convert_tensor_to_list(example["label"])
elif "label_ids" in example:
labels = _convert_tensor_to_list(example["label_ids"])
else:
labels = copy.deepcopy(input_ids)

new_calibration_dataset.append(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)

pad_token_id = self.config.pad_token_id
if not pad_token_id:
if tokenizer:
vocab = tokenizer.get_vocab()
if self.tokenizer:
vocab = self.tokenizer.get_vocab()

# auto select the best pad token to use
for token in ["<|finetune_right_pad_id|>", "<|pad|>", "<pad>", "<|unk|>", "<unk>"]:
@@ -179,14 +197,12 @@ def _convert_tensor_to_list(tensor):
for start in range(0, len(new_calibration_dataset), batch_size)
]

for new_example in new_calibration_dataset_batched:
del new_example["labels"]

return new_calibration_dataset_batched

def quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]],
batch_size: int = 1,
calibration_enable_gpu_cache: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
@@ -243,6 +259,12 @@ def quantize(
format=self.quantize_config.format,
)

# Use the provided tokenizer if one is passed to quantize()
if tokenizer is not None:
self.tokenizer = tokenizer
# after tokenizer is reset, need to normalize it again
self.tokenizer = normalize_tokenizer(self.config, self.tokenizer)

min_calibration_dataset_size = 256
min_calibration_dataset_input_ids_avg_length = 256

@@ -255,7 +277,7 @@
if BITBLAS_AVAILABLE is False:
raise ValueError(BITBLAS_INSTALL_HINT)

calibration_dataset = self.prepare_dataset(calibration_dataset, batch_size, tokenizer,)
calibration_dataset = self.prepare_dataset(calibration_dataset, batch_size,)

# Calculate the average length of the average input_ids
total_input_ids_length = 0
@@ -713,9 +735,13 @@ def to(self, device: Union[str, torch.device]):
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def generate(self, **kwargs):
def generate(self, inputs=None, **kwargs):
with torch.inference_mode():
return self.model.generate(**kwargs)
if isinstance(inputs, str) or (isinstance(inputs, list) and all(isinstance(x, str) for x in inputs)):
inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device)
return self.model.generate(**inputs, **kwargs)

return self.model.generate(inputs=inputs, **kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
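Before moving on to loader.py, a short sketch of the calibration formats the updated `prepare_dataset()`/`quantize()` accept. The load-then-quantize pattern follows the README above; the sample strings and token ids are made up, and real runs should use a few hundred calibration rows:

```py
from gptqmodel import GPTQModel, QuantizeConfig

model = GPTQModel.load(
    "meta-llama/Llama-3.2-1B-Instruct",
    QuantizeConfig(bits=4, group_size=128),
)

# 1. plain strings: tokenized internally with model.tokenizer
calibration_as_text = [
    "GPTQModel now tokenizes calibration text for you.",
    "Calibration samples can be plain strings.",
]

# 2. pre-tokenized ids: wrapped into input_ids/attention_mask tensors directly
calibration_as_ids = [[1, 15043, 3186, 29991]]

# 3. the original dict format still works unchanged
calibration_as_dicts = [
    {"input_ids": [1, 15043, 3186, 29991], "attention_mask": [1, 1, 1, 1]},
]

model.quantize(calibration_as_text, batch_size=2)
```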
gptqmodel/models/loader.py: 17 changes (16 additions, 1 deletion)
@@ -8,7 +8,7 @@
import torch
import transformers
from packaging.version import InvalidVersion, Version
from transformers import AutoConfig, PretrainedConfig
from transformers import AutoConfig, PretrainedConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers

@@ -35,6 +35,7 @@
simple_dispatch_model,
verify_model_hash,
verify_sharded_model_hashes,
normalize_tokenizer,
)
from ._const import DEVICE, SUPPORTED_MODELS, normalize_device
from huggingface_hub import snapshot_download
@@ -94,6 +95,14 @@ def get_model_local_path(pretrained_model_id_or_path, **kwargs):
return pretrained_model_id_or_path
else:
return snapshot_download(pretrained_model_id_or_path, **kwargs)

def get_tokenizer(model_id_or_path, config, trust_remote_code: bool = False):
try:
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
return normalize_tokenizer(config, tokenizer)
except Exception as e:
logger.warning(f"Failed to auto-load tokenizer from pretrained_model_id_or_path: {e}. Please pass a tokenizer to `quantize()` or set model.tokenizer after `load()`.")
return None


def ModelLoader(cls):
@@ -178,10 +187,13 @@ def skip(*args, **kwargs):
model.seqlen = 4096
model.eval()

tokenizer = get_tokenizer(pretrained_model_id_or_path, config=config, trust_remote_code=trust_remote_code)

return cls(
model,
quantized=False,
quantize_config=quantize_config,
tokenizer=tokenizer,
trust_remote_code=trust_remote_code,
model_local_path=model_local_path,
)
@@ -540,10 +552,13 @@ def skip(*args, **kwargs):

model.eval()

tokenizer = get_tokenizer(model_id_or_path, config=config, trust_remote_code=trust_remote_code)

return cls(
model,
quantized=True,
quantize_config=quantize_config,
tokenizer=tokenizer,
qlinear_kernel=qlinear_kernel,
load_quantized_model=True,
trust_remote_code=trust_remote_code,
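A quick sketch of what the `get_tokenizer()` hookup means for callers: `load()` tries to auto-load and normalize the tokenizer, and the manual fallback below mirrors the warning message in this diff. The instruct model id is reused from the README:

```py
from transformers import AutoTokenizer
from gptqmodel import GPTQModel

model = GPTQModel.load("ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5")

# load() now calls get_tokenizer(), so model.tokenizer is normally attached and normalized already;
# if auto-loading failed (get_tokenizer() returned None), set one manually as the warning suggests
if model.tokenizer is None:
    model.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

print(model.tokenizer.pad_token_id)
```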
gptqmodel/utils/data.py: 7 changes (0 additions, 7 deletions)
@@ -144,26 +144,19 @@ def pad_batch(block: LongTensor, pads: Tensor):

input_ids = [LongTensor(block["input_ids"]) for block in batch]
attention_masks = [LongTensor(block["attention_mask"]) for block in batch]
label_blocks = [LongTensor(block["labels"]) for block in batch]

inp_max_len = max([block.size(-1) for block in input_ids])
label_max_len = max([block.size(-1) for block in label_blocks])

for i in range(len(batch)):
block_bsz, block_inp_len = input_ids[i].shape
block_label_len = label_blocks[i].shape[-1]
pad_num = inp_max_len - block_inp_len
if pad_num > 0:
input_ids[i] = pad_batch(input_ids[i], torch.ones((block_bsz, pad_num)) * pad_token_id)
attention_masks[i] = pad_batch(attention_masks[i], torch.zeros((block_bsz, pad_num)))
label_pad_num = label_max_len - block_label_len
if label_pad_num > 0:
label_blocks[i] = pad_batch(label_blocks[i], torch.ones((block_bsz, label_pad_num)) * -100)

return {
"input_ids": torch.cat(input_ids, dim=0).long(),
"attention_mask": torch.cat(attention_masks, dim=0).long(),
"labels": torch.cat(label_blocks, dim=0).long(),
}


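For reference, a standalone sketch (plain PyTorch, no gptqmodel imports) of the padding scheme the trimmed collate keeps: input_ids are padded with pad_token_id, attention masks with zeros, and no labels tensor is produced. The right-padding direction and pad id are assumptions for illustration:

```py
import torch

pad_token_id = 0  # assumed; the real value comes from the model config / normalize_tokenizer()

input_ids = [torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9]])]
attention_masks = [torch.ones(1, 3, dtype=torch.long), torch.ones(1, 2, dtype=torch.long)]

# pad every block to the longest input_ids length in the batch
max_len = max(t.size(-1) for t in input_ids)
for i, (ids, mask) in enumerate(zip(input_ids, attention_masks)):
    pad = max_len - ids.size(-1)
    if pad > 0:
        input_ids[i] = torch.cat([ids, torch.full((1, pad), pad_token_id)], dim=-1)
        attention_masks[i] = torch.cat([mask, torch.zeros(1, pad, dtype=torch.long)], dim=-1)

batch = {
    "input_ids": torch.cat(input_ids, dim=0).long(),             # tensor([[5, 6, 7], [8, 9, 0]])
    "attention_mask": torch.cat(attention_masks, dim=0).long(),  # tensor([[1, 1, 1], [1, 1, 0]])
}
```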
gptqmodel/utils/model.py: 30 changes (30 additions, 0 deletions)
@@ -765,6 +765,36 @@ def check_requires_version(requires_version, current_version):
return OPERATOR_MAP[op_symbol](current_version, required_version)
else:
return None

def normalize_tokenizer(config, tokenizer):
pad_token_id = config.pad_token_id
if not pad_token_id:
if tokenizer:
vocab = tokenizer.get_vocab()

# auto select the best pad token to use
for token in ["<|finetune_right_pad_id|>", "<|pad|>", "<pad>", "<|unk|>", "<unk>"]:
token_id = vocab.get(token)
if token_id is not None:
pad_token_id = token_id
break
else:
logger.warning(
"Model config does not have pad token mapped. Please pass in tokenizer to `quantize()` so GPTQModel can auto-select the best pad token.")

if not pad_token_id and isinstance(config.eos_token_id,
list): # Llama-3.1-8B-Instruct's eos_token_id is a list
pad_token_id = config.eos_token_id[0]
elif not pad_token_id:
pad_token_id = config.eos_token_id

if pad_token_id is None:
raise ValueError(
"Calibration data requires model's `pad_token_id` or `eos_token_id` to be set: actual = `None`.")

tokenizer.pad_token_id = pad_token_id

return tokenizer

class MODALITY(str, Enum):
TEXT = "text"
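Finally, a hypothetical usage sketch of the new `normalize_tokenizer()` helper; the import path follows this diff and the model id is reused from the README:

```py
from transformers import AutoConfig, AutoTokenizer
from gptqmodel.utils.model import normalize_tokenizer  # module path per this diff

model_id = "meta-llama/Llama-3.2-1B-Instruct"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# resolution order: config.pad_token_id, then a pad-like vocab token, then eos_token_id
# (first element if eos_token_id is a list); raises ValueError if nothing is set
tokenizer = normalize_tokenizer(config, tokenizer)
print(tokenizer.pad_token_id)
```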