avoid deterministic algorithm warning in inference (#285)
wenhuach21 authored Oct 22, 2024
1 parent 141c149 commit ba5be40
Showing 3 changed files with 26 additions and 25 deletions.
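In substance, the fix relocates the module-level `torch.use_deterministic_algorithms(True, warn_only=True)` call out of `auto_round/autoround.py`, where it executed on every import of the package (including inference-only use), and into `auto_round/calib_dataset.py`, which is now imported lazily inside `calib()`. Deterministic mode is therefore only switched on when calibration actually runs. As a refresher on what the flag does, here is a minimal standalone sketch (independent of this repository; the tensor and shapes are illustrative):

```python
import torch

# Ask PyTorch to prefer deterministic kernels. With warn_only=True, an op
# that has no deterministic implementation emits a UserWarning instead of
# raising a RuntimeError.
torch.use_deterministic_algorithms(True, warn_only=True)

print(torch.are_deterministic_algorithms_enabled())           # True
print(torch.is_deterministic_algorithms_warn_only_enabled())  # True

# CPU matmul is deterministic, so this stays silent; some CUDA kernels
# (e.g. certain scatter/index ops) would warn here during inference.
x = torch.randn(4, 4)
print(x @ x)
```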
README.md (6 changes: 3 additions & 3 deletions)
@@ -272,8 +272,8 @@ release most of the models ourselves.
| meta-llama/Meta-Llama-3.1-8B-Instruct | [model-kaitchup-autogptq-int4*](https://huggingface.co/kaitchup/Meta-Llama-3.1-8B-Instruct-autoround-gptq-4bit-asym), [model-kaitchup-autogptq-sym-int4*](https://huggingface.co/kaitchup/Meta-Llama-3.1-8B-Instruct-autoround-gptq-4bit-sym), [recipe](https://huggingface.co/Intel/Meta-Llama-3.1-8B-Instruct-int4-inc) |
| meta-llama/Meta-Llama-3.1-8B | [model-kaitchup-autogptq-sym-int4*](https://huggingface.co/kaitchup/Meta-Llama-3.1-8B-autoround-gptq-4bit-sym) |
| Qwen/Qwen-VL | [accuracy](./examples/multimodal-modeling/Qwen-VL/README.md), [recipe](./examples/multimodal-modeling/Qwen-VL/run_autoround.sh)
-| Qwen/Qwen2-7B | [model-autoround-int4](https://huggingface.co/Intel/Qwen2-7B-int4-inc) |
-| Qwen/Qwen2-57B-A14B-Instruct | [model-autoround-int4](https://huggingface.co/Intel/Qwen2-57B-A14B-Instruct-int4-inc) |
+| Qwen/Qwen2-7B | [model-autoround-sym-int4](https://huggingface.co/Intel/Qwen2-7B-int4-inc), [model-autogptq-sym-int4](https://huggingface.co/Intel/Qwen2-7B-int4-inc) |
+| Qwen/Qwen2-57B-A14B-Instruct | [model-autoround-sym-int4](https://huggingface.co/Intel/Qwen2-57B-A14B-Instruct-int4-inc), [model-autogptq-sym-int4](https://huggingface.co/Intel/Qwen2-57B-A14B-Instruct-int4-inc) |
| 01-ai/Yi-1.5-9B | [model-LnL-AI-autogptq-int4*](https://huggingface.co/LnL-AI/Yi-1.5-9B-4bit-gptq-autoround) |
| 01-ai/Yi-1.5-9B-Chat | [model-LnL-AI-autogptq-int4*](https://huggingface.co/LnL-AI/Yi-1.5-9B-Chat-4bit-gptq-autoround) |
| Intel/neural-chat-7b-v3-3 | [model-autogptq-int4](https://huggingface.co/Intel/neural-chat-7b-v3-3-int4-inc) |
@@ -283,7 +283,7 @@ release most of the models ourselves.
| google/gemma-2b | [model-autogptq-int4](https://huggingface.co/Intel/gemma-2b-int4-inc) |
| tiiuae/falcon-7b | [model-autogptq-int4-G64](https://huggingface.co/Intel/falcon-7b-int4-inc) |
| sapienzanlp/modello-italia-9b | [model-fbaldassarri-autogptq-int4*](https://huggingface.co/fbaldassarri/modello-italia-9b-autoround-w4g128-cpu) |
-| microsoft/phi-2 | [model-autogptq-sym-int4](https://huggingface.co/Intel/phi-2-int4-inc) |
+| microsoft/phi-2 | [model-autoround-sym-int4](https://huggingface.co/Intel/phi-2-int4-inc), [model-autogptq-sym-int4](https://huggingface.co/Intel/phi-2-int4-inc) |
| microsoft/Phi-3.5-mini-instruct | [model-kaitchup-autogptq-sym-int4*](https://huggingface.co/kaitchup/Phi-3.5-Mini-instruct-AutoRound-4bit) |
| microsoft/Phi-3-vision-128k-instruct | [recipe](./examples/multimodal-modeling/Phi-3-vision/run_autoround.sh)
| mistralai/Mistral-7B-Instruct-v0.2 | [accuracy](./docs/Mistral-7B-Instruct-v0.2-acc.md), [recipe](./examples/language-modeling/scripts/Mistral-7B-Instruct-v0.2.sh), [example](./examples/language-modeling/) |
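The rows above now link both export formats for the same checkpoints. A GPTQ-format checkpoint can usually be loaded through the stock transformers integration; a hedged sketch, assuming `optimum` and an AutoGPTQ-compatible backend are installed (the prompt is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Intel/Qwen2-7B-int4-inc"  # one of the repos linked in the table
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto")

inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```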
auto_round/autoround.py (7 changes: 3 additions & 4 deletions)
@@ -15,17 +15,14 @@
import os
import torch
import transformers
-
-torch.use_deterministic_algorithms(True, warn_only=True)
-
import copy
import time
from typing import Optional, Union

from transformers import set_seed
from torch import autocast
from tqdm import tqdm
-from .calib_dataset import get_dataloader

from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer, \
    WrapperTransformerConv1d
from .special_model_handler import (check_hidden_state_dim,
@@ -488,8 +485,10 @@ def calib(self, nsamples, bs):
            nsamples (int): The number of samples to use for calibration.
            bs (int): The batch size to use for calibration.
        """
+       from .calib_dataset import get_dataloader
        if isinstance(self.dataset, str):
            dataset = self.dataset.replace(" ", "")  ##remove all whitespaces
+
+           # slow here
            self.dataloader = get_dataloader(
                self.tokenizer,
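The pattern applied here is worth spelling out: a statement at module top level runs at import time, so every consumer of `autoround.py`, including inference-only code, used to flip the global deterministic flag. Moving the `calib_dataset` import inside `calib()` scopes that side effect to the calibration path. A runnable sketch of the same idea (the function names and toy workloads are illustrative, not this package's API):

```python
import torch

def calibrate(samples):
    # Deferred: the global flag is flipped only on the calibration path,
    # mirroring the lazy `from .calib_dataset import get_dataloader` above.
    torch.use_deterministic_algorithms(True, warn_only=True)
    return [s.float().mean().item() for s in samples]

def infer(x):
    # The inference path never touches the flag, so no deterministic-mode
    # warnings can surface here.
    return x @ x

print(infer(torch.randn(4, 4)).shape)
print(calibrate([torch.randint(0, 100, (8,)) for _ in range(2)]))
```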
auto_round/calib_dataset.py (38 changes: 20 additions & 18 deletions)
@@ -16,6 +16,8 @@
import random

import torch
+
+torch.use_deterministic_algorithms(True, warn_only=True)
from torch.utils.data import DataLoader

from .utils import is_local_path, logger
@@ -58,15 +60,15 @@ def default_tokenizer_function(examples, apply_template=apply_template):
        if not apply_template:
            example = tokenizer(examples["text"], truncation=True, max_length=seqlen)
        else:
-           from jinja2 import Template # pylint: disable=E0401
+           from jinja2 import Template  # pylint: disable=E0401
            chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
                else tokenizer.default_chat_template
            template = Template(chat_template)
            rendered_messages = []
            for text in examples["text"]:
                message = [{"role": "user", "content": text}]
                rendered_message = template.render(messages=message, add_generation_prompt=True, \
-                   bos_token=tokenizer.bos_token)
+                                                  bos_token=tokenizer.bos_token)
                rendered_messages.append(rendered_message)
            example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
        return example
@@ -103,11 +105,11 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split

@register_dataset("madao33/new-title-chinese")
def get_new_chinese_title_dataset(
-       tokenizer,
-       seqlen,
-       dataset_name="madao33/new-title-chinese",
-       split=None,
-       seed=42,
+   tokenizer,
+   seqlen,
+   dataset_name="madao33/new-title-chinese",
+   split=None,
+   seed=42,
    apply_template=False
):
    """Returns a dataloader for the specified dataset and split.
@@ -148,7 +150,7 @@ def default_tokenizer_function(examples, apply_template=apply_template):
            for text in examples["text"]:
                message = [{"role": "user", "content": text}]
                rendered_message = template.render(messages=message, add_generation_prompt=True, \
-                   bos_token=tokenizer.bos_token)
+                                                  bos_token=tokenizer.bos_token)
                rendered_messages.append(rendered_message)
            example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
        return example
@@ -267,12 +269,12 @@ def load_local_data(data_path):


def get_dataloader(
-       tokenizer,
-       seqlen,
-       dataset_name="NeelNanda/pile-10k",
-       seed=42,
-       bs=8,
-       nsamples=512,
+   tokenizer,
+   seqlen,
+   dataset_name="NeelNanda/pile-10k",
+   seed=42,
+   bs=8,
+   nsamples=512,
):
    """Generate a DataLoader for calibration using specified parameters.
@@ -293,6 +295,7 @@
"""

dataset_names = dataset_name.split(",")

def filter_func(example):
if isinstance(example["input_ids"], list):
example["input_ids"] = torch.tensor(example["input_ids"])
@@ -316,7 +319,7 @@ def concat_dataset_element(dataset):
                input_id = input_id[1:]
                os_cnt, have_bos = os_cnt + 1, True
            if input_id[-1] == eos_token_id:
-               input_id = input_id[:-1]
+                input_id = input_id[:-1]
                os_cnt, have_eos = os_cnt + 1, True

            if buffer_input_id.shape[-1] + input_id.shape[-1] + os_cnt > seqlen:
@@ -326,7 +329,7 @@ def concat_dataset_element(dataset):
                    input_id_to_append = [torch.tensor([bos_token_id])] + input_id_to_append
                if have_eos:
                    input_id_to_append.append(torch.tensor([eos_token_id]))

                concat_input_ids.append(torch.cat(input_id_to_append).to(torch.int64))
                attention_mask_list.append(attention_mask)
                buffer_input_id = input_id[idx_keep:]
@@ -405,7 +408,7 @@ def concat_dataset_element(dataset):
        name = dataset_names[i].split(':')[0]
        if name not in data_lens:
            target_cnt = (nsamples - cnt) // (len(datasets) - len(data_lens)) if data_lens \
-               else (nsamples - cnt) // (len(datasets) - i)
+                else (nsamples - cnt) // (len(datasets) - i)
            target_cnt = min(target_cnt, len(datasets[i]))
            cnt += target_cnt
        else:
@@ -447,4 +450,3 @@ def collate_batch(batch):

    calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
    return calib_dataloader
-
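Landing the flag in `calib_dataset.py` fits the file's job: calibration sampling is seeded (note the `seed=42` defaults in the signatures above), and deterministic kernels keep the rest of the calibration path repeatable, while inference imports never touch the setting. A small self-contained sketch in the same spirit (the helper name, toy corpus, and sizes are illustrative, not this module's API):

```python
import random
import torch
from torch.utils.data import DataLoader

torch.use_deterministic_algorithms(True, warn_only=True)

def make_calib_batches(token_ids, seqlen=16, bs=2, nsamples=4, seed=42):
    """Pick a repeatable subset and pack it into equal-length int64 tensors."""
    random.seed(seed)  # same seed -> same subset on every run
    picked = random.sample(token_ids, k=min(nsamples, len(token_ids)))
    data = [torch.tensor(ids[:seqlen], dtype=torch.int64) for ids in picked]
    return DataLoader(data, batch_size=bs, shuffle=False)

corpus = [[i] * 16 for i in range(10)]  # toy "tokenized" corpus
for batch in make_calib_batches(corpus):
    print(batch.shape)  # torch.Size([2, 16])
```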