
Commit

Unify GPTQ dataloader with fixed/unfixed length data (#1212)
Signed-off-by: YIYANGCAI <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YIYANGCAI and pre-commit-ci[bot] authored Sep 7, 2023
1 parent cca57d3 commit 6733681
Showing 6 changed files with 76 additions and 134 deletions.
@@ -6,6 +6,7 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.functional import pad

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -218,13 +219,6 @@ def skip(*args, **kwargs):
model.eval()

# dataset
# original method of loading data: only load sequences whose length > model.seqlen
# ================================================
# dataloader, testloader = get_loaders(
# args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model_name_or_path, seqlen=model.seqlen
# )
# dataloader = INCDataloader(dataloader)
# ================================================
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
calib_dataset = load_dataset(args.dataset, split="train") # default
# calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if you have trouble connecting to HF
@@ -244,7 +238,6 @@ def skip(*args, **kwargs):

model = model.to(DEV)

print('Starting ...')
if args.sym:
sym_opt = "sym"
else:
@@ -276,7 +269,8 @@ def skip(*args, **kwargs):
# 'act_order':args.act_order,
# 'block_size': args.block_size,
# 'nsamples': args.nsamples,
# 'use_max_length': args.use_max_length
# 'use_max_length': args.use_max_length,
# 'pad_max_length': args.pad_max_length
# },
# },
# )
@@ -296,7 +290,8 @@ def skip(*args, **kwargs):
weight_config=conf,
dataloader=calib_dataloader,
nsamples = args.nsamples,
use_max_length = args.use_max_length
use_max_length = args.use_max_length,
pad_max_length = args.pad_max_length
)

results = lm_evaluate(
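For context, here is a minimal sketch (not part of this commit) of the kind of batch-size-1 calibration dataloader the example script now hands to the quantizer; length unification is deferred to the GPTQ dataloader via use_max_length and pad_max_length. The model name, dataset name, and collate function below are illustrative placeholders.

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=True)  # placeholder model
calib_dataset = load_dataset("NeelNanda/pile-10k", split="train")  # placeholder dataset

def collate_fn(batch):
    # Tokenize one raw text per batch; length unification happens inside GPTQuantizer.
    return tokenizer(batch[0]["text"], return_tensors="pt").input_ids

calib_dataloader = DataLoader(calib_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

for input_ids in calib_dataloader:
    print(input_ids.shape)  # torch.Size([1, seq_len]), seq_len varies per sample
    break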
8 changes: 7 additions & 1 deletion neural_compressor/adaptor/pytorch.py
@@ -4597,9 +4597,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
}
nsamples = self.recipes["gptq_args"].get("nsamples", 128)
use_max_length = self.recipes["gptq_args"].get("use_max_length", False)
pad_max_length = self.recipes["gptq_args"].get("pad_max_length", 2048)
if use_max_length and "pad_max_length" not in self.recipes["gptq_args"]:
logger.warning(
"You choose to use unified sequence length for calibration, \
but you have not set length value. Default sequence length is 2048 and this might cause inference error!"
)
# tune_cfg => weight_config
model, quantization_perm = gptq_quantize(
model, weight_config, dataloader, nsamples, use_max_length, self.device
model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device
)
return model, quantization_perm

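For reference, these are the gptq_args keys read above, together with the defaults the adaptor falls back to; the dict is only an illustration of the schema (values mirror the updated unit tests and the defaults in this diff).

gptq_args = {
    "percdamp": 0.01,         # Hessian dampening, as a fraction of the average diagonal
    "act_order": False,       # quantize columns in order of decreasing Hessian significance
    "nsamples": 128,          # number of calibration samples (default 128)
    "use_max_length": True,   # unify every calibration sample to one fixed length
    "pad_max_length": 2048,   # the unified length; defaults to 2048, and a warning is
                              # logged when use_max_length is set without it
}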
160 changes: 49 additions & 111 deletions neural_compressor/adaptor/torch_utils/gptq.py
@@ -166,7 +166,16 @@ class GPTQuantizer(object):
url: https://arxiv.org/abs/2210.17323
"""

def __init__(self, model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, device=None):
def __init__(
self,
model,
weight_config={},
dataloader=None,
nsamples=128,
use_max_length=True,
pad_max_length=2048,
device=None,
):
"""
Args:
model: the fp32 model to quantize
@@ -211,44 +220,29 @@ def __init__(self, model, weight_config={}, dataloader=None, nsamples=128, use_m

# dataloader
self.use_max_length = use_max_length
self.pad_max_length = pad_max_length
self.dataloader_original = dataloader
self.dataloader = []
self.nsamples = nsamples
self.prepare_dataloader()

def prepare_dataloader(self):
if self.use_max_length:
# (Recommended) only take sequences whose length exceeds model.seqlen,
# (Recommended) only take sequences whose length exceeds self.pad_max_length,
# which ensures that all calibration tokens are valid
# This is the official GPTQ dataloader implementation
self.obtain_first_n_samples_fulllength()
# initialize buffers which are essential for gptq computation.
self.model_hidden_size = 2048
self.initialize_inp_buffersize()
try:
# Since length is unified, we can allocate a continuous space to store inputs
self.inp = torch.zeros(
(len(self.dataloader), self.model.seqlen, self.model_hidden_size),
dtype=self.dtype,
device=self.device,
)
self.cache = {"i": 0}
self.out = torch.zeros_like(self.inp)
self.is_ready = True
except:
logger.warning("GPTQ Quantizer initialization failed!")
pass
else:
# general selection, no padding, not GPTQ original implementation.
self.obtain_first_n_samples()
try:
self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
self.cache = {"i": 0}
self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
self.is_ready = True
except:
logger.warning("GPTQ Quantizer initialization failed!")
pass
try:
self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
self.cache = {"i": 0}
self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
self.is_ready = True
except:
logger.warning("GPTQ Quantizer initialization failed!")
pass

def obtain_first_n_samples(self, seed=0):
"""Get first nsample data as the real calibration dataset."""
@@ -257,12 +251,13 @@ def obtain_first_n_samples(self, seed=0):
for batch in self.dataloader_original:
# process data, depends on its data type.
if len(self.dataloader) == self.nsamples:
logger.info(f"Successfully collect {self.nsamples} calibration samples.")
break
# list, tuple
if isinstance(batch, list) or isinstance(batch, tuple):
if batch[0].shape[-1] > self.model.seqlen:
i = random.randint(0, batch[0].shape[-1] - self.model.seqlen - 1)
j = i + self.model.seqlen
if batch[0].shape[-1] > self.pad_max_length:
i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1)
j = i + self.pad_max_length
batch_final = batch[0][:, i:j]
else:
batch_final = batch[0]
@@ -274,9 +269,9 @@ def obtain_first_n_samples(self, seed=0):
logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
continue
batch_final = {}
if length > self.model.seqlen:
i = random.randint(0, length - self.model.seqlen - 1)
j = i + self.model.seqlen
if length > self.pad_max_length:
i = random.randint(0, length - self.pad_max_length - 1)
j = i + self.pad_max_length
# may have to slice every sequence related data
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
@@ -287,9 +282,9 @@ def obtain_first_n_samples(self, seed=0):
batch_final = batch
# tensor
else:
if batch.shape[-1] > self.model.seqlen:
i = random.randint(0, batch.shape[-1] - self.model.seqlen - 1)
j = i + self.model.seqlen
if batch.shape[-1] > self.pad_max_length:
i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1)
j = i + self.pad_max_length
batch_final = batch[:, i:j]
else:
batch_final = batch
@@ -301,9 +296,10 @@ def obtain_first_n_samples(self, seed=0):
def obtain_first_n_samples_fulllength(self, seed=0):
self.dataloader.clear()
random.seed(seed)
unified_length = self.model.seqlen
unified_length = self.pad_max_length
for batch in self.dataloader_original:
if len(self.dataloader) == self.nsamples:
logger.info(f"Successfully collect {self.nsamples} calibration samples.")
break
# list & tuple
if isinstance(batch, list) or isinstance(batch, tuple):
@@ -325,11 +321,11 @@ def obtain_first_n_samples_fulllength(self, seed=0):
logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
continue
batch_final = {}
if length == self.model.seqlen:
if length == self.pad_max_length:
batch_final = batch
elif length > self.model.seqlen:
i = random.randint(0, length - self.model.seqlen - 1)
j = i + self.model.seqlen
elif length > self.pad_max_length:
i = random.randint(0, length - self.pad_max_length - 1)
j = i + self.pad_max_length
# may have to slice every sequence related data
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
@@ -354,53 +350,9 @@ def obtain_first_n_samples_fulllength(self, seed=0):
if len(self.dataloader) < self.nsamples: # pragma: no cover
logger.warning(
f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
but only {len(self.dataloader)} samples satisfy your setting. You may choose smaller 'model.seqlen' value."
but only {len(self.dataloader)} samples were found. Please use a smaller 'pad_max_length' value."
)

@torch.no_grad()
def initialize_inp_buffersize(self):
# Run a forward and generate proper buffer tensor
# Thus, no need to pass hidden_states dimension parameters of model.config
# e.g. OPT's hidden_states dimension can be called by model.config.hidden_size
# but mpt's hidden_states dimension can be called by model.config.d_model
def forward(layer, hidden_states, **kwargs):
# inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1
logger.info(f"The hidden_states shape along transformers blocks is {hidden_states.shape}.")
self.model_hidden_size = hidden_states.shape[-1]
raise ValueError

# Step1: fetch the embeddings and other layers before the transformer stack.
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer = embedding_layer.to(self.device)

# Step2: modify the first transformer block's forward function to obtain inputs for calibration
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
)

# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
for batch in self.dataloader:
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0])
elif isinstance(batch, dict):
self.model(**batch)
else:
self.model(batch.to(self.device))
except ValueError:
break

# Step 4: restore original forward function, relocate layers back to cpu.
self.gptq_related_blocks["transformers"][0].forward = forward_cache
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
embedding_layer.to(self.device)
torch.cuda.empty_cache()

def get_full_layer_name(self, sub_layer_name, block_idx):
transformer_name = self.gptq_related_blocks["transformers_name"]
return ".".join([transformer_name, str(block_idx), sub_layer_name])
@@ -459,18 +411,12 @@ def forward(layer, hidden_states, **kwargs):
self.cache["i"] += 1
for arg in kwargs:
# TODO: investigate include parameters
if self.use_max_length:
if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
self.cache[arg] = kwargs[arg]
else:
continue
else:
# each output can have a different shape, hence also use a list to store them
if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
if self.cache.get(arg, None) is None:
self.cache[arg] = []
self.cache[arg].append(kwargs[arg])
continue
# each output can have a different shape, hence also use a list to store them
if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
if self.cache.get(arg, None) is None:
self.cache[arg] = []
self.cache[arg].append(kwargs[arg])
continue
raise ValueError

# Step1: fetch the embeddings and other layers before the transformer stack.
@@ -572,13 +518,9 @@ def tmp(_, inp, out):
handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
idx = self.cache.pop("i")
for j in range(len(self.dataloader)):
if self.use_max_length:
# self.inp[j] shape: [seq_len, hidden_size]
self.out[j] = transformer_block(self.inp[j].unsqueeze(0), **self.cache)[0]
else:
# self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
cache_batch = self.gather_single_batch_from_dict(self.cache, j)
self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
# self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
cache_batch = self.gather_single_batch_from_dict(self.cache, j)
self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
self.cache["i"] = idx
for h in handles:
h.remove()
@@ -607,13 +549,9 @@ def tmp(_, inp, out):
# Step 2.5: replace output data with quantized weights
idx = self.cache.pop("i")
for j in range(len(self.dataloader)):
if self.use_max_length:
# self.inp[j] shape: [seq_len, hidden_size]
self.out[j] = transformer_block(self.inp[j].unsqueeze(0), **self.cache)[0]
else:
# self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
cache_batch = self.gather_single_batch_from_dict(self.cache, j)
self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
# self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
cache_batch = self.gather_single_batch_from_dict(self.cache, j)
self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
self.cache["i"] = idx
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
del gptq_for_this_block
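Because variable-length samples are no longer packed into one padded tensor, forward kwargs such as attention_mask or alibi are cached per sample as lists (see the hook above), and gather_single_batch_from_dict picks one sample's kwargs back out before each block forward. Its body is not shown in this diff; the helper below is only a plausible stand-in for illustration.

import torch

def gather_single_batch(cache, idx):
    # Pick the idx-th entry of every per-sample list stored in the kwarg cache.
    return {arg: values[idx] for arg, values in cache.items()}

# One list entry per calibration sample, since samples may differ in length.
cache = {
    "attention_mask": [torch.ones(1, 1, 7, 7), torch.ones(1, 1, 12, 12)],
    "position_ids": [torch.arange(7).unsqueeze(0), torch.arange(12).unsqueeze(0)],
}
kwargs_for_sample_1 = gather_single_batch(cache, 1)
print({k: tuple(v.shape) for k, v in kwargs_for_sample_1.items()})
# {'attention_mask': (1, 1, 12, 12), 'position_ids': (1, 12)}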
6 changes: 4 additions & 2 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -469,13 +469,15 @@ def rtn_quantize(
return model


def gptq_quantize(model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, device=None):
def gptq_quantize(
model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None
):
"""Run weight-only quantization with."""
# TODO: unify weight_config keys, add docstring, and support default config
assert isinstance(model, torch.nn.Module), "only support torch module"
from .gptq import GPTQuantizer

gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, device)
gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device)
fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
logger.info("GPTQ quantizing done.")
return fp32_modified_model, gptq_config
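End to end, the new pad_max_length knob reaches gptq_quantize through the adaptor's recipes. Below is a hedged sketch of that workflow; the PostTrainingQuantConfig / quantization.fit entry points, the tiny test model, and the dataloader class are assumed from the public Neural Compressor API and the updated unit tests, not from this diff.

import torch
import transformers
from neural_compressor import PostTrainingQuantConfig, quantization

class LLMCalibDataloader:
    # Minimal batch-size-1 loader, similar in spirit to the tests' LLMDataLoader.
    def __init__(self):
        self.batch_size = 1
    def __iter__(self):
        for _ in range(4):
            yield torch.ones([1, 10], dtype=torch.long)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"
)
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {"weight": {"bits": 4, "group_size": 8, "scheme": "sym", "algorithm": "GPTQ"}},
    },
    recipes={
        "gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": False, "pad_max_length": 512},
    },
)
q_model = quantization.fit(model, conf, calib_dataloader=LLMCalibDataloader())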
5 changes: 2 additions & 3 deletions test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py
@@ -68,7 +68,6 @@ def setUpClass(self):
self.gptj_no_jit = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
)
self.gptj.seqlen = 512
self.llm_dataloader = LLMDataLoader()
self.lm_input = torch.ones([1, 10], dtype=torch.long)

@@ -502,7 +501,7 @@ def __iter__(self):
},
},
recipes={
"gptq_args": {"percdamp": 0.01, "act_order": False},
"gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": True, "pad_max_length": 512},
},
)

@@ -608,7 +607,7 @@ def __iter__(self):
},
},
recipes={
"gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": True},
"gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": False, "pad_max_length": 512},
},
)

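The first updated test above takes the fixed-length path (use_max_length=True, pad_max_length=512), which only keeps calibration sequences that reach pad_max_length; a loader like the illustrative one below (not from this commit) guarantees that every sample qualifies and gets cropped to a random 512-token window.

import torch

class LongSequenceDataLoader:
    # Illustrative: every sample exceeds pad_max_length=512, so the fixed-length
    # selection can crop a random 512-token window from each one.
    def __init__(self):
        self.batch_size = 1
    def __iter__(self):
        for _ in range(4):
            yield torch.ones([1, 1024], dtype=torch.long)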
