GPTQ Support Dict-type Calibration Data (#1178)
Signed-off-by: YIYANGCAI <[email protected]>
YIYANGCAI authored Aug 23, 2023
1 parent 8facee9 commit 3018319
Showing 2 changed files with 203 additions and 14 deletions.
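In short, the GPTQ calibration path now accepts dataloaders that yield plain tensors, lists/tuples, or dicts (a dict batch must carry an 'input_ids' key). A minimal sketch of the new dict case, using the PostTrainingQuantConfig/quantization.fit API exercised in the tests below; `model` stands for a user-supplied PyTorch language model and is not defined here:

import torch
from neural_compressor import PostTrainingQuantConfig, quantization

class DictCalibDataloader:
    """Yields dict batches; GPTQ moves tensors to the right device and slices them on the sequence dim."""
    def __init__(self):
        self.batch_size = 1

    def __iter__(self):
        for _ in range(10):
            yield {'input_ids': torch.ones([1, 512], dtype=torch.long),
                   'attention_mask': torch.ones([1, 512], dtype=torch.long)}

conf = PostTrainingQuantConfig(
    approach='weight_only',
    op_type_dict={'.*': {'weight': {'bits': 4, 'group_size': 8,
                                    'scheme': 'sym', 'algorithm': 'GPTQ'}}},
    recipes={'gptq_args': {'percdamp': 0.01, 'act_order': False, 'use_max_length': True}},
)
# q_model = quantization.fit(model, conf, calib_dataloader=DictCalibDataloader())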
94 changes: 86 additions & 8 deletions neural_compressor/adaptor/torch_utils/gptq.py
@@ -26,9 +26,35 @@
from functools import partial
from ...utils import logger
import random
from collections import UserDict, defaultdict

DEBUG = False

# ================ device related ===================
def move_input_to_device(input, device=torch.device('cpu')):
if isinstance(input, dict) or isinstance(input, UserDict):
for inp in input.keys():
input[inp] = input[inp].to(device) \
if isinstance(input[inp], torch.Tensor) else input[inp]
elif isinstance(input, list) or isinstance(input, tuple):
input_res, prev_size = [], None
for inp in input:
if prev_size:
if isinstance(inp, torch.Tensor):
if inp.size() == prev_size:
input_res.append(inp.to(device))
else:
if torch.tensor(inp).size() == prev_size:
input_res.append(inp)
else:
input_res.append(inp.to(device) \
if isinstance(inp, torch.Tensor) else inp)
prev_size = torch.tensor(inp).size()
input = input_res
else:
input = input.to(device) # pylint: disable=no-member
return input

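A quick usage sketch for the helper above (batch values are illustrative, not from the diff): tensors are moved to the target device, while non-tensor entries pass through unchanged.

import torch

tensor_batch = torch.ones([1, 512], dtype=torch.long)
list_batch = [torch.ones([1, 512], dtype=torch.long), torch.ones([1, 512], dtype=torch.long)]
dict_batch = {'input_ids': torch.ones([1, 512], dtype=torch.long), 'labels': [1, 0, 1]}

for b in (tensor_batch, list_batch, dict_batch):
    moved = move_input_to_device(b, torch.device('cpu'))  # non-tensors (e.g. 'labels') are kept as-is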
# ==============model structure related==============
def is_leaf(module):
"""Judge whether a module has no child-modules.
@@ -232,6 +258,7 @@ def obtain_first_n_samples(self, seed=0):
self.dataloader.clear()
random.seed(seed)
for batch in self.dataloader_original:
# process data depending on its type.
if len(self.dataloader) == self.nsamples:
break
# list, tuple
@@ -242,6 +269,25 @@ def obtain_first_n_samples(self, seed=0):
batch_final = batch[0][:, i:j]
else:
batch_final = batch[0]
# dict
elif isinstance(batch, dict):
try:
length = batch['input_ids'].shape[-1]
except KeyError:
logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
continue
batch_final = {}
if length > self.model.seqlen:
i = random.randint(0, length - self.model.seqlen - 1)
j = i + self.model.seqlen
# may have to slice every sequence-related tensor
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch_final[key] = batch[key][:, i:j] # slice on sequence length dim
else:
batch_final[key] = batch[key]
else:
batch_final = batch
# tensor
else:
if batch.shape[-1] > self.model.seqlen:
@@ -251,6 +297,7 @@ def obtain_first_n_samples(self, seed=0):
else:
batch_final = batch
self.dataloader.append(batch_final)

if len(self.dataloader) < self.nsamples:
logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.")

@@ -264,24 +311,49 @@ def obtain_first_n_samples_fulllength(self, seed=0):
# list & tuple
if isinstance(batch, list) or isinstance(batch, tuple):
if batch[0].shape[-1] == unified_length:
inp = batch[0]
batch_final = batch[0]
elif batch[0].shape[-1] > unified_length:
i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
j = i + unified_length
inp = batch[0][:, i:j]
batch_final = batch[0][:, i:j]
else:
# shorter than the max length; excluded from the target dataset
continue
self.dataloader.append(batch_final)
# dict
elif isinstance(batch, dict):
try:
length = batch['input_ids'].shape[-1]
except KeyError:
logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
continue
batch_final = {}
if length == self.model.seqlen:
batch_final = batch
elif length > self.model.seqlen:
i = random.randint(0, length - self.model.seqlen - 1)
j = i + self.model.seqlen
# may have to slice every sequence-related tensor
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch_final[key] = batch[key][:, i:j] # slice the sequence-length dim at the same positions
else:
batch_final[key] = batch[key]
else:
# shorter than the max length; excluded from the target dataset
continue
# tensor
else:
if batch.shape[-1] == unified_length:
inp = batch[0]
batch_final = batch
elif batch.shape[-1] > unified_length:
i = random.randint(0, batch.shape[-1] - unified_length - 1)
j = i + unified_length
inp = batch[:, i:j]
batch_final = batch[:, i:j]
else:
# shorter than the max length; excluded from the target dataset
continue
self.dataloader.append(inp)
self.dataloader.append(batch_final)
if len(self.dataloader) < self.nsamples: # pragma: no cover
logger.warning(f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
but only {len(self.dataloader)} samples satisfy your setting. You may choose smaller 'model.seqlen' value.")
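Unlike obtain_first_n_samples, which keeps short batches unmodified, the fulllength variant admits only batches that can supply a complete window of unified_length tokens; anything shorter is dropped, which is what the warning above reports. A tiny illustration of the filter (values illustrative):

import torch

seqlen = 512
batches = [torch.ones([1, 300], dtype=torch.long),   # too short: skipped
           torch.ones([1, 512], dtype=torch.long),   # exact length: kept as-is
           torch.ones([1, 800], dtype=torch.long)]   # longer: sliced to a random window

kept = [b for b in batches if b.shape[-1] >= seqlen]
assert len(kept) == 2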
@@ -311,9 +383,12 @@ def forward(layer, hidden_states, **kwargs):
# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
for batch in self.dataloader:
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0].to(self.device))
self.model(batch[0])
elif isinstance(batch, dict):
self.model(**batch)
else:
self.model(batch.to(self.device))
except ValueError:
@@ -410,11 +485,14 @@ def forward(layer, hidden_states, **kwargs):
# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
for batch in tqdm(self.dataloader):
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0].to(self.device))
self.model(batch[0])
elif isinstance(batch, dict):
self.model(**batch)
else:
self.model(batch.to(self.device))
self.model(batch)
except ValueError:
pass
# output inp data shape
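The `except ValueError: pass` in both calibration loops suggests the usual GPTQ early-exit trick: the first decoder block is wrapped so that it caches its inputs and then raises, stopping the forward pass once the calibration inputs are captured. A sketch of that pattern under this assumption; the Catcher wrapper below is hypothetical and not part of this diff:

import torch

class Catcher(torch.nn.Module):
    """Hypothetical wrapper: cache the wrapped block's inputs, then abort the forward pass."""
    def __init__(self, module, cache):
        super().__init__()
        self.module = module
        self.cache = cache

    def forward(self, hidden_states, **kwargs):
        self.cache.append((hidden_states, kwargs))
        raise ValueError  # caught and ignored by the calibration loop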
123 changes: 117 additions & 6 deletions test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py
@@ -375,26 +375,34 @@ def __iter__(self):
self.assertTrue(torch.allclose(values['fc1']['output'][0], values['fc2']['input'][0]))
self.assertTrue(torch.allclose(values['fc2']['output'][0], out))


def test_GPTQ_quant(self):
def test_GPTQ_fixed_length_quant(self):
class GPTQLLMDataLoader():
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(2):
for i in range(10):
yield torch.ones([1, 512], dtype=torch.long)

class GPTQLLMDataLoaderList():
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(2):
for i in range(10):
yield (torch.ones([1, 512], dtype=torch.long), torch.ones([1, 512], dtype=torch.long))

class GPTQLLMDataLoaderDict():
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
yield {'input_ids': torch.ones([1, 512], dtype=torch.long), 'attention_mask': torch.ones([1, 512], dtype=torch.long)}

dataloader = GPTQLLMDataLoader()
dataloader_list = GPTQLLMDataLoaderList()
dataloader_dict = GPTQLLMDataLoaderDict()

conf = PostTrainingQuantConfig(
approach='weight_only',
@@ -431,7 +439,7 @@ def __iter__(self):
torch.save(compressed_model.state_dict(), 'saved/compressed_model.pt')
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05))

# case 2: list or tuple
# # case 2: list or tuple
model_2 = copy.deepcopy(self.gptj)
input = torch.ones([1, 512], dtype=torch.long)
q_model = quantization.fit(model_2, conf, calib_dataloader=dataloader_list,)
@@ -441,8 +449,111 @@ def __iter__(self):
out2 = compressed_model(input)
torch.save(compressed_model.state_dict(), 'saved/compressed_model.pt')
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05))

# case 3: dict
model_3 = copy.deepcopy(self.gptj)
input = torch.ones([1, 512], dtype=torch.long)
q_model = quantization.fit(model_3, conf, calib_dataloader=dataloader_dict,)
q_model.save('saved')
out1 = q_model.model(input)
compressed_model = q_model.export_compressed_model()
out2 = compressed_model(input)
torch.save(compressed_model.state_dict(), 'saved/compressed_model.pt')
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05))

print("GPTQ with fixed length Done")

def test_GPTQ_unfixed_length_quant(self):
import random
class GPTQLLMDataLoader():
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
length = random.randint(1, 1024)
yield torch.ones([1, length], dtype=torch.long)

class GPTQLLMDataLoaderList():
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
length = random.randint(1, 1024)
yield (torch.ones([1, length], dtype=torch.long), torch.ones([1, length], dtype=torch.long))

class GPTQLLMDataLoaderDict():
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
length = random.randint(1, 1024)
yield {'input_ids': torch.ones([1, length], dtype=torch.long), 'attention_mask': torch.ones([1, length], dtype=torch.long)}

dataloader = GPTQLLMDataLoader()
dataloader_list = GPTQLLMDataLoaderList()
dataloader_dict = GPTQLLMDataLoaderDict()

conf = PostTrainingQuantConfig(
approach='weight_only',
op_type_dict={
'.*':{ # re.match
"weight": {
'bits': 4, # 1-8 bits
'group_size': 8, # -1 (per-channel)
'scheme': 'sym',
'algorithm': 'GPTQ',
},
},
},
op_name_dict={
'.*lm_head':{ # re.match
"weight": {
'dtype': 'fp32'
},
},
},
recipes={
'gptq_args':{'percdamp': 0.01, 'act_order': False, 'use_max_length': True},
},
)

# case 1: tensor
model_1 = copy.deepcopy(self.gptj)
input = torch.ones([1, 512], dtype=torch.long)
q_model = quantization.fit(model_1, conf, calib_dataloader=dataloader,)
q_model.save('saved')
out1 = q_model.model(input)
compressed_model = q_model.export_compressed_model()
out2 = compressed_model(input)
torch.save(compressed_model.state_dict(), 'saved/compressed_model.pt')
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05))

# case 2: list or tuple
model_2 = copy.deepcopy(self.gptj)
input = torch.ones([1, 512], dtype=torch.long)
q_model = quantization.fit(model_2, conf, calib_dataloader=dataloader_list,)
q_model.save('saved')
out1 = q_model.model(input)
compressed_model = q_model.export_compressed_model()
out2 = compressed_model(input)
torch.save(compressed_model.state_dict(), 'saved/compressed_model.pt')
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05))

# case 3: dict
model_3 = copy.deepcopy(self.gptj)
input = torch.ones([1, 512], dtype=torch.long)
q_model = quantization.fit(model_3, conf, calib_dataloader=dataloader_dict,)
q_model.save('saved')
out1 = q_model.model(input)
compressed_model = q_model.export_compressed_model()
out2 = compressed_model(input)
torch.save(compressed_model.state_dict(), 'saved/compressed_model.pt')
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05))

print("GPTQ Done")
print("GPTQ with fixed length Done")

def test_TEQ_quant(self):
class teq_inc_loader(object):
