From eabd0450eb80cfcc9f14dfbee4d415c6aa4ec241 Mon Sep 17 00:00:00 2001
From: YIYANGCAI
Date: Tue, 18 Jul 2023 13:47:51 +0800
Subject: [PATCH 1/6] Add gptq examples

Signed-off-by: YIYANGCAI
---
 .../ptq_weight_only/gptj-for-gptq/README.md   |  10 +
 .../gptj-for-gptq/cnn_dm_dataset.py           | 172 +++++++++++++
 .../ptq_weight_only/gptj-for-gptq/main.py     | 238 ++++++++++++++++++
 .../gptj-for-gptq/run-gptq-gptj-sym.sh        |  14 ++
 4 files changed, 434 insertions(+)
 create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md
 create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/cnn_dm_dataset.py
 create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/main.py
 create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md
new file mode 100644
index 00000000000..2e3dc852ed0
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md
@@ -0,0 +1,10 @@
+# Run GPTQ on the GPT-J-6B model for the summarization task
+
+# Step by Step
+
+## Step 1 Prepare datasets and models
+Use the following links to get the
+[**CNN Daily Mail** dataset](https://github.com/intel-innersource/frameworks.ai.benchmarking.mlperf.submission.inference-submission-v3-1/tree/master/closed/Intel/code/gpt-j/pytorch-cpu#download-and-prepare-dataset)
+and the [GPT-J-6B MLPerf model](https://github.com/mlcommons/inference/tree/master/language/gpt-j#download-gpt-j-model)
+
+## Step 2 Run GPTQ quantization
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/cnn_dm_dataset.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/cnn_dm_dataset.py
new file mode 100644
index 00000000000..63ada49fffe
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/cnn_dm_dataset.py
@@ -0,0 +1,172 @@
+import sys
+import argparse
+import os
+import time
+import json
+import fnmatch
+
+import copy
+import logging
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Sequence
+
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from datasets import load_dataset, load_from_disk
+from torch.nn.functional import pad
+from torch.utils.data import DataLoader
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import random
+random.seed(9973)
+
+# Bucketize sequence lengths
+MaxLens = range(0, 1919, 64)  # unused; the cutoffs are rebuilt below
+Buckets = dict()
+cutoff_step = 64
+min_cutoff = 64
+min_len = 1
+for cutoff in range(min_cutoff, 1921, cutoff_step):  # All input sequences
+    Buckets[cutoff] = list(range(min_len, cutoff, 1))
+    min_len = cutoff
+
+#Buckets[1920] = list(range(min_len, 1921, 1))
+
+input_buckets = dict()
+for cutoff, seq_lens in Buckets.items():
+    for seq_len in seq_lens:
+        input_buckets[seq_len] = cutoff
+
+#print("Buckets: {}".format(input_buckets))
+
+IGNORE_INDEX = -100
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
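+
+# The special-token defaults above mirror the Stanford Alpaca convention; they
+# are kept for reference and are not used below -- load_tokenizer() simply
+# reuses the tokenizer's eos token as the pad token.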
+PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
+}
+
+
+class CNNDAILYMAIL(object):
+    def __init__(self, model_path, data_path, device="cpu", is_calib=False, num_samples=20, max_len=1920):
+        self.model_path = model_path
+        self.data_path = data_path
+        self.device = device
+        self.num_samples = num_samples
+        self.is_calib = is_calib
+
+        self.padding = "max_length" if self.is_calib else False
+        self.max_len = 2048 if self.is_calib else max_len
+
+        self.calib_collator = self.collate_batch
+        self.pad_max = max_len
+        self.load_tokenizer()
+        self.load_dataset()
+
+    def load_dataset(self):
+        """Loads the dataset."""
+        with open(self.data_path, "r") as fid:
+            list_data_dict = json.load(fid)
+        self.list_data_dict = copy.deepcopy(list_data_dict)
+
+        if self.num_samples is not None:
+            self.num_samples = min(self.num_samples, len(list_data_dict))
+
+            if self.is_calib:
+                list_data_dict = list_data_dict[:self.num_samples]
+            else:
+                list_data_dict = random.choices(list_data_dict, k=self.num_samples)
+
+        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+        sources = [prompt_input.format_map(example) for example in list_data_dict]
+        targets = [f"{example['output']}" for example in list_data_dict]
+
+        self.input_ids = []
+        self.input_lens = []
+        for i in range(len(sources)):
+            tok_input = self.tokenize_function(sources[i])
+            self.input_ids.append(tok_input.input_ids)
+
+        #if self.num_samples is not None:
+        #    self.num_samples = min(self.num_samples, len(list_data_dict))
+        #    self.input_ids = random.choices(self.input_ids, k=self.num_samples)
+        #    print("Sources: {}".format(len(sources)))
+        #    print("Targets: {}".format(len(targets)))
+        #    sources = random.choices(sources, k=self.num_samples)
+        #    targets = random.choices(targets, k=self.num_samples)
+
+        self.sources = sources
+        self.targets = targets
+
+    def load_tokenizer(self):
+        """Returns the tokenizer."""
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            model_max_length=2048,
+            padding_side="right",
+            use_fast=False,
+        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    @torch.no_grad()
+    def tokenize_function(self, text):
+        example = self.tokenizer(text, truncation=True, max_length=self.max_len, return_tensors="pt", padding=self.padding)
+        return example
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        input_ids = self.input_ids[i]
+        input_len = input_ids.shape[-1]
+        #pad_size = input_buckets[input_len] - input_len
+        #input_ids = F.pad(input_ids, pad=(0, pad_size))
+        return (input_ids, input_len)
+
+    @torch.no_grad()
+    def collate_batch(self, batch):
+        input_ids_padded = []
+
+        for input_ids, input_lens in batch:  # input_ids are returned by this dataset (see __getitem__)
+            pad_len = self.pad_max - input_ids.shape[0]  # computed but unused while padding stays disabled
+            #input_ids = F.pad(input_ids, pad=(0, pad_size), value=self.tokenizer.pad_token_id)
+            input_ids_padded.append(input_ids)
+
+        input_ids_padded = torch.vstack(input_ids_padded)
+        return (input_ids_padded, input_ids_padded)
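+
+    # get_warmup_samples() below keeps one right-padded prompt per bucket
+    # cutoff in range(128, 1920, 64); it can be used to warm up every padded
+    # input shape once before benchmarking.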
+    def get_warmup_samples(self):
+        cutoff_set = set(range(128, 1920, 64))
+        warmup_samples = []
+        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+        sources = [prompt_input.format_map(example) for example in self.list_data_dict]
+        for source in sources:  #self.input_ids:
+            tok_input = self.tokenize_function(source)
+            input_ids = tok_input.input_ids
+            input_len = input_ids.shape[-1]
+            bucket = input_buckets[input_len]
+            if bucket in cutoff_set:
+                #print("inputlen: {}; Bucket: {}".format(input_len, bucket))
+                pad_size = bucket - input_len
+                input_ids = F.pad(input_ids, pad=(0, pad_size), value=0)
+                warmup_samples.append(input_ids)
+                cutoff_set.remove(bucket)
+            if len(cutoff_set) == 0:
+                break
+
+        return warmup_samples
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/main.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/main.py
new file mode 100644
index 00000000000..0670284cbcb
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/main.py
@@ -0,0 +1,238 @@
+import sys
+sys.path.append("./")
+import math
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import transformers
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from cnn_dm_dataset import CNNDAILYMAIL
+from torch.utils.data import DataLoader
+
+from tqdm import tqdm
+
+import evaluate
+import nltk
+nltk.download("punkt", quiet=False)
+metric = evaluate.load("rouge")
+import rouge_score
+
+def get_gptj(model):
+
+    def skip(*args, **kwargs):
+        pass
+    torch.nn.init.kaiming_uniform_ = skip
+    torch.nn.init.uniform_ = skip
+    torch.nn.init.normal_ = skip
+    from transformers import GPTJForCausalLM, AutoModelForCausalLM
+    model = GPTJForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16)
+    #model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
+    model.seqlen = 2048
+    return model
+
+def postprocess_text(preds, targets):
+    preds = [pred.strip() for pred in preds]
+    targets = [target.strip() for target in targets]
+
+    # rougeLSum expects newline-separated sentences
+    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+    targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
+
+    return preds, targets
+
+def benchmark(model, benchmark_dataset, tokenizer, sources, targets, check=False):
+    #input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
+    #torch.cuda.synchronize()
+
+    cache = {'past': None}
+
+    def clear_past(i):
+
+        def tmp(layer, inp, out):
+            if cache['past']:
+                cache['past'][i] = None
+
+        return tmp
+
+    for i, layer in enumerate(model.transformer.h):
+        layer.register_forward_hook(clear_past(i))
+
+    print('Benchmarking ...')
+
+    if check:
+        loss = nn.CrossEntropyLoss()
+        tot = 0.
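+
+    # sync() blocks until all queued CUDA kernels have finished, so the
+    # wall-clock timings taken around model.generate() below reflect
+    # completed device work rather than asynchronous kernel launches.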
+    def sync():
+        if hasattr(model, 'gpus'):
+            for gpu in model.gpus:
+                torch.cuda.synchronize(gpu)
+        else:
+            torch.cuda.synchronize()
+
+    max_memory = 0
+    generate_kwargs = {
+        "early_stopping": True,
+        "max_new_tokens": 128,
+        "min_new_tokens": 30,
+        "num_beams": 4,
+    }
+
+    #preds = []
+    batch_targets = []
+    predictions = []
+    ground_truths = []
+
+    with torch.no_grad(), torch.inference_mode():
+        times = []
+        #for i, (input_ids, labels) in enumerate(benchmark_dataset):# in range(input_ids.numel()):
+        for i in tqdm(range(len(sources))):  #tqdm(range(len(sources))):
+            input_ids, input_lens = benchmark_dataset[i]
+            input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
+
+            #input_lens = input_ids.shape[-1]
+            attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
+
+            tick = time.time()
+            out = model.generate(input_ids, **generate_kwargs, pad_token_id=tokenizer.pad_token_id, )
+
+            sync()
+            times.append(time.time() - tick)
+
+            out_tokens = out.cpu().numpy()  #[:,input_len:]
+            #print("Iter {}".format(i))
+            #print("Input len: {}".format(input_lens));
+            #print("Output len: {}".format(out_tokens.shape[-1] - input_ids.shape[-1]))
+            print("Inference time: {}".format(round(times[-1], 3)))
+
+            pred = out_tokens[:, input_lens:]
+            pred_batch = tokenizer.batch_decode(pred, skip_special_tokens=True)
+            targ_batch = targets[i:i+1]
+            preds, targs = postprocess_text(pred_batch, targ_batch)
+            predictions.extend(preds)
+            ground_truths.extend(targs)
+
+            #cache['past'] = list(out.past_key_values)
+            del out
+            #sync()
+        print('Median:', np.median(times))
+
+    print("Predictions: {}".format(len(predictions)))
+    print("References: {}".format(len(ground_truths)))
+    result = metric.compute(predictions=predictions, references=ground_truths, use_stemmer=True, use_aggregator=False)
+    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
+    prediction_lens = [len(pred) for pred in predictions]
+    result["gen_len"] = np.sum(prediction_lens)
+    result["gen_num"] = len(predictions)
+    print(result)
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        '--model_name_or_path', type=str,
+        help='GPT-J model to load; pass the path to the fine-tuned gpt-j-6B checkpoint.'
+    )
+    parser.add_argument(
+        '--dataset', type=str, choices=['wikitext2', 'ptb', 'c4', 'pile'],
+        help='Where to extract calibration data from.'
+    )
+    parser.add_argument(
+        '--seed',
+        type=int, default=0, help='Seed for sampling the calibration data.'
+    )
+    parser.add_argument(
+        '--nsamples', type=int, default=128,
+        help='Number of calibration data samples.'
+    )
+    parser.add_argument(
+        '--percdamp', type=float, default=.01,
+        help='Percent of the average Hessian diagonal to use for dampening.'
+    )
+    parser.add_argument(
+        '--nearest', action='store_true',
+        help='Whether to run the RTN baseline.'
+    )
+    parser.add_argument(
+        '--wbits', type=int, default=16, choices=[2, 3, 4, 16],
+        help='#bits to use for quantization; use 16 for evaluating the base model.'
+    )
+    parser.add_argument(
+        '--group_size', type=int, default=-1,
+        help='Group size to use for quantization; the default uses the full row.'
+    )
+    parser.add_argument(
+        '--sym', action='store_true',
+        help='Whether to perform symmetric quantization.'
+    )
+    parser.add_argument(
+        '--new-eval', action='store_true',
+        help='Whether to use the new PTB and C4 eval.'
+    )
+    parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation-order GPTQ heuristic.')
+    parser.add_argument('--calib-data-path', type=str, help="Path to the calibration json file.")
+    parser.add_argument('--val-data-path', type=str, help="Path to the validation json file.")
+    parser.add_argument('--calib-iters', type=int, default=128, help="Number of samples for calibration.")
+    parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
+
+    args = parser.parse_args()
+    # import pdb;pdb.set_trace()
+    # method 1: directly import AutoModelForCausalLM
+    model = get_gptj(args.model_name_or_path)
+    model.eval()
+
+    # import pdb;pdb.set_trace()
+    calib_dataset = CNNDAILYMAIL(args.model_name_or_path, args.calib_data_path, is_calib=True, num_samples=args.calib_iters)
+    dataloader = DataLoader(calib_dataset,
+        batch_size=1,
+        shuffle=False,
+        collate_fn=calib_dataset.collate_batch
+    )
+
+    DEV = torch.device('cuda:0')
+
+    # do the quantization
+    print('Starting ...')
+    weight_config = {
+        'wbits': args.wbits,
+        'group_size': args.group_size,
+        'sym': args.sym,
+        'percdamp': args.percdamp,
+        'actorder': args.act_order
+    }
+    print(weight_config)
+    from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize
+    quantizers = gptq_quantize(model, weight_config=weight_config, dataloader=dataloader, device=DEV)
+
+    import pdb;pdb.set_trace()
+
+    # benchmarking first 100 examples
+    # if args.benchmark:
+    if True:
+        # use half to accelerate inference
+        model.half()
+        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
+        if len(gpus) > 1:
+            gptj_multigpu(model, gpus, gpu_dist)
+        else:
+            model = model.to(DEV)
+
+        val_dataset = CNNDAILYMAIL(args.model_name_or_path,
+                                   args.val_data_path, is_calib=False,
+                                   num_samples=None)
+
+        tokenizer = val_dataset.tokenizer
+        sources = val_dataset.sources
+        targets = val_dataset.targets
+        benchmark_set = DataLoader(val_dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=val_dataset.collate_batch
+        )
+
+        benchmark(model, val_dataset, tokenizer, sources, targets, check=args.check)
+    print("Done")
\ No newline at end of file
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
new file mode 100644
index 00000000000..8f4c2300ed7
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
@@ -0,0 +1,14 @@
+CALIBRATION_DATA=/path/to/your/data/calibration-data/cnn_dailymail_calibration.json
+VALIDATION_DATA=/path/to/your/data/validation-data/cnn_dailymail_validation.json
+MODEL_DIR=/data4/cyy/gptq_inc/mlperf/gpt-j-mlperf/finetuned-gptj/
+
+python -u main.py \
+    --model_name_or_path ${MODEL_DIR} \
+    --wbits 4 \
+    --act-order \
+    --sym \
+    --group_size 128 \
+    --nsamples 128 \
+    --calib-data-path ${CALIBRATION_DATA} \
+    --val-data-path ${VALIDATION_DATA} \
+    --calib-iters 128 \
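The quantization path that patch 1 wires up reduces to the short sketch below. It is a condensed reading of `main.py`, not part of the patch: the paths are placeholders, while the `weight_config` keys and the `gptq_quantize` call are copied from the script above.

```python
# Sketch of the patch-1 GPTQ flow (placeholder paths; calls mirror main.py).
import torch
from torch.utils.data import DataLoader
from transformers import GPTJForCausalLM

from cnn_dm_dataset import CNNDAILYMAIL
from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize

model = GPTJForCausalLM.from_pretrained("/path/to/finetuned-gptj",
                                        torch_dtype=torch.bfloat16)
model.eval()

# Calibration samples are truncated/padded CNN/DailyMail prompts.
calib = CNNDAILYMAIL("/path/to/finetuned-gptj",
                     "/path/to/cnn_dailymail_calibration.json",
                     is_calib=True, num_samples=128)
loader = DataLoader(calib, batch_size=1, shuffle=False,
                    collate_fn=calib.collate_batch)

weight_config = {"wbits": 4, "group_size": 128, "sym": True,
                 "percdamp": 0.01, "actorder": True}
gptq_quantize(model, weight_config=weight_config, dataloader=loader,
              device=torch.device("cuda:0"))  # main.py pins cuda:0 here
```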
+ ) + parser.add_argument( + '--new-eval', action='store_true', + help='Whether to use the new PTB and C4 eval' + ) + parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') + parser.add_argument('--calib-data-path', type=str, help="Path to calibration json file") + parser.add_argument('--val-data-path', type=str, help="Path to validation json file") + parser.add_argument('--calib-iters', type=int, default=128, help="Number of samples for calibration") + parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') + + args = parser.parse_args() + # import pdb;pdb.set_trace() + # method 1: directly import AutoModelForCausalLM + model = get_gptj(args.model_name_or_path) + model.eval() + + # import pdb;pdb.set_trace() + calib_dataset = CNNDAILYMAIL(args.model_name_or_path, args.calib_data_path, is_calib=True, num_samples=args.calib_iters) + dataloader=DataLoader(calib_dataset, + batch_size=1, + shuffle=False, + collate_fn=calib_dataset.collate_batch + ) + + DEV = torch.device('cuda:0') + + # do the quantization + print('Starting ...') + weight_config = { + 'wbits': args.wbits, + 'group_size': args.group_size, + 'sym': args.sym, + 'percdamp': args.percdamp, + 'actorder': args.act_order + } + print(weight_config) + from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize + quantizers = gptq_quantize(model, weight_config=weight_config, dataloader=dataloader, device = DEV) + + import pdb;pdb.set_trace() + + # benchmarking first 100 examples + # if args.benchmark: + if True: + # use half to accerlerate inference + model.half() + gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] + if len(gpus) > 1: + gptj_multigpu(model, gpus, gpu_dist) + else: + model = model.to(DEV) + + val_dataset = CNNDAILYMAIL(args.model_name_or_path, + args.val_data_path,is_calib=False, + num_samples=None) + + tokenizer = val_dataset.tokenizer + sources = val_dataset.sources + targets = val_dataset.targets + benchmark_set = DataLoader(val_dataset, + batch_size=1, + shuffle=False, + collate_fn=val_dataset.collate_batch + ) + + benchmark(model, val_dataset, tokenizer, sources, targets, check=args.check) + print("Done") \ No newline at end of file diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh new file mode 100644 index 00000000000..8f4c2300ed7 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh @@ -0,0 +1,14 @@ +CALIBRATION_DATA=/path/to/your/data/calibration-data/cnn_dailymail_calibration.json +VALIDATION_DATA=/path/to/your/data/validation-data/cnn_dailymail_validation.json +MODEL_DIR=/data4/cyy/gptq_inc/mlperf/gpt-j-mlperf/finetuned-gptj/ + +python -u main.py \ + --model_name_or_path ${MODEL_DIR} \ + --wbits 4 \ + --act-order \ + --sym \ + --group_size 128 \ + --nsamples 128 \ + --calib-data-path ${CALIBRATION_DATA} \ + --val-data-path ${VALIDATION_DATA} \ + --calib-iters 128 \ From 48a6b2659c7bbca4e33c3ac6b41ae824453e462a Mon Sep 17 00:00:00 2001 From: YIYANGCAI Date: Tue, 18 Jul 2023 13:58:49 +0800 Subject: [PATCH 2/6] update README.md Signed-off-by: YIYANGCAI --- .../quantization/ptq_weight_only/gptj-for-gptq/README.md | 3 +++ 1 file changed, 3 

From 9ce4fe5480aef0a51bd25bfa42de62b9f6982fba Mon Sep 17 00:00:00 2001
From: YIYANGCAI
Date: Tue, 18 Jul 2023 14:52:58 +0800
Subject: [PATCH 3/6] modify script

Signed-off-by: YIYANGCAI
---
 .../ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
index 8f4c2300ed7..ee7084dbdda 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
@@ -1,6 +1,6 @@
 CALIBRATION_DATA=/path/to/your/data/calibration-data/cnn_dailymail_calibration.json
 VALIDATION_DATA=/path/to/your/data/validation-data/cnn_dailymail_validation.json
-MODEL_DIR=/data4/cyy/gptq_inc/mlperf/gpt-j-mlperf/finetuned-gptj/
+MODEL_DIR=/path/to/finetuned-gptj/
 
 python -u main.py \
     --model_name_or_path ${MODEL_DIR} \

From 96dc1de0be52bf9a7dd7bc817b4ae22ae257524e Mon Sep 17 00:00:00 2001
From: YIYANGCAI
Date: Tue, 18 Jul 2023 18:08:02 +0800
Subject: [PATCH 4/6] update examples

Signed-off-by: YIYANGCAI
---
 .../quantization/ptq_weight_only/README.md          | 12 +++++++++++-
 .../{gptj-for-gptq => }/cnn_dm_dataset.py           |  0
 .../ptq_weight_only/gptj-for-gptq/README.md         | 13 -------------
 .../main.py => run_gptj_mlperf_int4.py}             |  0
 ...run-gptq-gptj-sym.sh => run_gptj_mlperf_int4.sh} |  2 +-
 5 files changed, 12 insertions(+), 15 deletions(-)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/{gptj-for-gptq => }/cnn_dm_dataset.py (100%)
 delete mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md
 rename examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/{gptj-for-gptq/main.py => run_gptj_mlperf_int4.py} (100%)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/{gptj-for-gptq/run-gptq-gptj-sym.sh => run_gptj_mlperf_int4.sh} (92%)

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/README.md
index 35b338f9ca8..d4798e9794a 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/README.md
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/README.md
@@ -41,6 +41,16 @@ sh run_tuning.sh --topology=topology_name --input_model=model_name_or_path --wei
 > 
 > `weight_only_bits`, `weight_only_group`, `weight_only_scheme`, and `weight_only_algorithm` can be modified by user. For details, please refer to [README](../../../../../../../docs/source/quantization_weight_only.md).
 
+### Run MLPerf on GPT-J-6B
+Use the following links to get the
+[**CNN Daily Mail** dataset](https://github.com/intel-innersource/frameworks.ai.benchmarking.mlperf.submission.inference-submission-v3-1/tree/master/closed/Intel/code/gpt-j/pytorch-cpu#download-and-prepare-dataset)
+and the [GPT-J-6B MLPerf model](https://github.com/mlcommons/inference/tree/master/language/gpt-j#download-gpt-j-model)
+
+Then run the following command to do the quantization:
+```shell
+sh run_gptj_mlperf_int4.sh
+```
+
 ## 2. Benchmark
 ```bash
 # int8
@@ -102,4 +112,4 @@ from neural_compressor.utils.pytorch import load
 quantized_model = load(tuned_checkpoint, model)
 ```
 --------
-For more details, please refer to the [sample code](./run_clm.py).
+For more details, please refer to the [sample code](./run_clm.py).
\ No newline at end of file
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/cnn_dm_dataset.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/cnn_dm_dataset.py
similarity index 100%
rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/cnn_dm_dataset.py
rename to examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/cnn_dm_dataset.py
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md
deleted file mode 100644
index d60fad3d7d4..00000000000
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/README.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# Run GPTQ on the GPT-J-6B model for the summarization task
-
-# Step by Step
-
-## Step 1 Prepare datasets and models
-Use the following links to get the
-[**CNN Daily Mail** dataset](https://github.com/intel-innersource/frameworks.ai.benchmarking.mlperf.submission.inference-submission-v3-1/tree/master/closed/Intel/code/gpt-j/pytorch-cpu#download-and-prepare-dataset)
-and the [GPT-J-6B MLPerf model](https://github.com/mlcommons/inference/tree/master/language/gpt-j#download-gpt-j-model)
-
-## Step 2 Run GPTQ quantization
-```shell
-sh run-gptq-gptj-sym.sh
-```
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/main.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
similarity index 100%
rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/main.py
rename to examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh
similarity index 92%
rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
rename to examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh
index ee7084dbdda..6ef44071212 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/gptj-for-gptq/run-gptq-gptj-sym.sh
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh
@@ -2,7 +2,7 @@ CALIBRATION_DATA=/path/to/your/data/calibration-data/cnn_dailymail_calibration.j
 VALIDATION_DATA=/path/to/your/data/validation-data/cnn_dailymail_validation.json
 MODEL_DIR=/path/to/finetuned-gptj/
 
-python -u main.py \
+python -u run_gptj_mlperf_int4.py \
     --model_name_or_path ${MODEL_DIR} \
     --wbits 4 \
     --act-order \

From f7c77b3c03a2d3bab570e0134d5d36e4f6ef2f32 Mon Sep 17 00:00:00 2001
From: YIYANGCAI
Date: Wed, 19 Jul 2023 22:39:13 +0800
Subject: [PATCH 5/6] remove nsamples arg to avoid vague meaning.

Signed-off-by: YIYANGCAI
---
 .../quantization/ptq_weight_only/run_gptj_mlperf_int4.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
index 0670284cbcb..1300ae6a8a1 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
@@ -145,10 +145,6 @@ def sync():
         '--seed',
         type=int, default=0, help='Seed for sampling the calibration data.'
     )
-    parser.add_argument(
-        '--nsamples', type=int, default=128,
-        help='Number of calibration data samples.'
-    )
     parser.add_argument(
         '--percdamp', type=float, default=.01,
         help='Percent of the average Hessian diagonal to use for dampening.'
@@ -235,4 +231,4 @@ def sync():
         )
 
         benchmark(model, val_dataset, tokenizer, sources, targets, check=args.check)
-    print("Done")
\ No newline at end of file
+    print("Done")

From 361695bea608fbbd93dccd468f99d4d71b369fdf Mon Sep 17 00:00:00 2001
From: YIYANGCAI
Date: Fri, 4 Aug 2023 10:43:21 +0800
Subject: [PATCH 6/6] support option of converting to fp16 before gptq.

Signed-off-by: YIYANGCAI
---
 .../ptq_weight_only/run_gptj_mlperf_int4.py | 152 +++++++++++++---
 .../ptq_weight_only/run_gptj_mlperf_int4.sh |  10 +-
 2 files changed, 131 insertions(+), 31 deletions(-)

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
index 1300ae6a8a1..7b52fee1b69 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py
@@ -9,6 +9,7 @@
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# from cnn_daily_loader_wenhua import CNNDAILYMAIL
 from cnn_dm_dataset import CNNDAILYMAIL
 from torch.utils.data import DataLoader
 
@@ -20,6 +21,8 @@
 metric = evaluate.load("rouge")
 import rouge_score
 
+from neural_compressor import quantization, PostTrainingQuantConfig
+
 def get_gptj(model):
 
     def skip(*args, **kwargs):
@@ -45,6 +48,9 @@ def postprocess_text(preds, targets):
 def benchmark(model, benchmark_dataset, tokenizer, sources, targets, check=False):
     #input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
     #torch.cuda.synchronize()
+    # for idx in range(len(targets)):
+    #     if idx >= 5: break
+    #     print(targets[idx])
 
     cache = {'past': None}
 
@@ -85,6 +91,7 @@ def sync():
     predictions = []
     ground_truths = []
 
+    # import pdb;pdb.set_trace()
     with torch.no_grad(), torch.inference_mode():
         times = []
         #for i, (input_ids, labels) in enumerate(benchmark_dataset):# in range(input_ids.numel()):
@@ -92,6 +99,8 @@ def sync():
             input_ids, input_lens = benchmark_dataset[i]
             input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
 
+            print(input_ids)
+
             #input_lens = input_ids.shape[-1]
             attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
 
@@ -111,6 +120,7 @@ def sync():
             pred_batch = tokenizer.batch_decode(pred, skip_special_tokens=True)
             targ_batch = targets[i:i+1]
             preds, targs = postprocess_text(pred_batch, targ_batch)
+            print(f"====={targs}=====\n")
             predictions.extend(preds)
             ground_truths.extend(targs)
 
@@ -128,6 +138,67 @@ def sync():
     result["gen_num"] = len(predictions)
     print(result)
 
+def gptj_multigpu(model, gpus, gpu_dist):
+    #model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
+    model.transformer.wte = model.transformer.wte.to(gpus[0])
+    #if hasattr(model.model, 'norm') and model.model.norm:
+    #    model.model.norm = model.model.norm.to(gpus[0])
+
+    if hasattr(model.transformer, 'ln_f') and model.transformer.ln_f:
+        model.transformer.ln_f = model.transformer.ln_f.to(gpus[0])
+    import copy
+    model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
+
+    cache = {'mask': None, 'position_ids': None}
+
+    class MoveModule(nn.Module):
+
+        def __init__(self, module, invalidate_cache):
+            super().__init__()
+            self.module = module
+            self.dev = next(iter(self.module.parameters())).device
+            self.invalidate_cache = invalidate_cache
+
+        def forward(self, *inp, **kwargs):
+            inp = list(inp)
+            if inp[0].device != self.dev:
+                inp[0] = inp[0].to(self.dev)
+
+            if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
+                cache['mask'] = kwargs['attention_mask'].to(self.dev)
+            kwargs['attention_mask'] = cache['mask']
+
+            if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
+                cache['position_ids'] = kwargs['position_ids'].to(self.dev)
+            kwargs['position_ids'] = cache['position_ids']
+
+            tmp = self.module(*inp, **kwargs)
+            return tmp
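+
+    # Wrap every transformer block in MoveModule so activations (and the
+    # cached attention mask / position ids) are moved onto that block's GPU
+    # on entry.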
+    #layers = model.model.layers
+    layers = model.transformer.h
+    from math import ceil
+    if not gpu_dist:
+        pergpu = ceil(len(layers) / len(gpus))
+        for i in range(len(layers)):
+            layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) - 1 else gpus[(i-1) // pergpu]), i==0)
+    else:
+        assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0."
+        assigned_gpus = [0] * (gpu_dist[0]-1)
+        for i in range(1, len(gpu_dist)):
+            assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
+
+        remaining_assignments = len(layers)-len(assigned_gpus) - 1
+        if remaining_assignments > 0:
+            assigned_gpus = assigned_gpus + [-1] * remaining_assignments
+
+        assigned_gpus = assigned_gpus + [0]
+
+        for i in range(len(layers)):
+            layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
+
+    model.gpus = gpus
+
 if __name__ == '__main__':
     import argparse
 
@@ -145,6 +216,10 @@ def sync():
         '--seed',
         type=int, default=0, help='Seed for sampling the calibration data.'
     )
+    parser.add_argument(
+        '--nsamples', type=int, default=128,
+        help='Number of calibration data samples.'
+    )
     parser.add_argument(
         '--percdamp', type=float, default=.01,
         help='Percent of the average Hessian diagonal to use for dampening.'
@@ -174,14 +249,25 @@ def sync():
     parser.add_argument('--val-data-path', type=str, help="Path to the validation json file.")
     parser.add_argument('--calib-iters', type=int, default=128, help="Number of samples for calibration.")
     parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
+    parser.add_argument('--use_fp16', action='store_true', help='Whether to convert the model to fp16 before running GPTQ.')
+    parser.add_argument('--use_gpu', action='store_true', help='Whether to use the GPU.')
 
+    # load the gptj model
     args = parser.parse_args()
-    # import pdb;pdb.set_trace()
     # method 1: directly import AutoModelForCausalLM
     model = get_gptj(args.model_name_or_path)
    model.eval()
 
-    # import pdb;pdb.set_trace()
+    if args.use_gpu and torch.cuda.is_available():
+        DEV = torch.device('cuda:0')
+    else:
+        DEV = torch.device('cpu')
+
+    if args.use_fp16:
+        model.half()
+        model = model.to(DEV)
+
+    # load the dataset
     calib_dataset = CNNDAILYMAIL(args.model_name_or_path, args.calib_data_path, is_calib=True, num_samples=args.calib_iters)
     dataloader = DataLoader(calib_dataset,
         batch_size=1,
@@ -189,37 +275,51 @@ def sync():
         collate_fn=calib_dataset.collate_batch
     )
 
-    DEV = torch.device('cuda:0')
-
-    # do the quantization
+    # # do the quantization
     print('Starting ...')
-    weight_config = {
-        'wbits': args.wbits,
-        'group_size': args.group_size,
-        'sym': args.sym,
-        'percdamp': args.percdamp,
-        'actorder': args.act_order
-    }
-    print(weight_config)
-    from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize
-    quantizers = gptq_quantize(model, weight_config=weight_config, dataloader=dataloader, device=DEV)
-
-    import pdb;pdb.set_trace()
+    conf = PostTrainingQuantConfig(
+        approach='weight_only',
+        op_type_dict={
+            '.*': {  # re.match
+                "weight": {
+                    'bits': 4,  # 1-8 bits
+                    'group_size': 128,  # -1 (per-channel)
+                    'scheme': 'sym',
+                    'algorithm': 'GPTQ',
+                },
+            },
+        },
+        op_name_dict={
+            '.*lm_head': {  # re.match
+                "weight": {
+                    'dtype': 'fp32'
+                },
+            },
+        },
+        recipes={
+            'gptq_args': {'percdamp': 0.01, 'actorder': args.act_order},
+        },
+    )
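+
+    # op_name_dict keeps lm_head weights in fp32, so only the transformer
+    # blocks get 4-bit GPTQ weights; recipes forwards the GPTQ-specific
+    # options (Hessian dampening percentage, activation ordering).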
+
+    q_model = quantization.fit(model, conf, calib_dataloader=dataloader,)
+
+    q_model.save("./gptj-gptq-gs128-calib128-calibration-fp16/")
+    # q_model.float()
+    # q_model.save("./gptj-gptq-gs128-calib128-calibration-fp32/")
 
     # benchmarking first 100 examples
     # if args.benchmark:
     if True:
         # use half to accelerate inference
         model.half()
-        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
-        if len(gpus) > 1:
-            gptj_multigpu(model, gpus, gpu_dist)
-        else:
-            model = model.to(DEV)
+        model = model.to(DEV)
 
-        val_dataset = CNNDAILYMAIL(args.model_name_or_path,
-                                   args.val_data_path, is_calib=False,
-                                   num_samples=None)
+        val_dataset = CNNDAILYMAIL(
+            args.model_name_or_path,
+            args.val_data_path,
+            #is_calib = True,
+            num_samples=None
+        )
 
         tokenizer = val_dataset.tokenizer
         sources = val_dataset.sources
         targets = val_dataset.targets
         benchmark_set = DataLoader(val_dataset,
             batch_size=1,
             shuffle=False,
-            collate_fn=val_dataset.collate_batch
+            # collate_fn=val_dataset.collate_batch
         )
 
         benchmark(model, val_dataset, tokenizer, sources, targets, check=args.check)
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh
index 6ef44071212..11440c2d182 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh
@@ -1,14 +1,14 @@
-CALIBRATION_DATA=/path/to/your/data/calibration-data/cnn_dailymail_calibration.json
-VALIDATION_DATA=/path/to/your/data/validation-data/cnn_dailymail_validation.json
-MODEL_DIR=/path/to/finetuned-gptj/
+CALIBRATION_DATA=/your/data/calibration-data/cnn_dailymail_calibration.json
+VALIDATION_DATA=/your/data/validation-data/cnn_dailymail_validation.json
+MODEL_DIR=/your/gptj/
 
-python -u run_gptj_mlperf_int4.py \
+python -u examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py \
     --model_name_or_path ${MODEL_DIR} \
     --wbits 4 \
-    --act-order \
     --sym \
     --group_size 128 \
     --nsamples 128 \
     --calib-data-path ${CALIBRATION_DATA} \
     --val-data-path ${VALIDATION_DATA} \
     --calib-iters 128 \
+    --use_fp16
\ No newline at end of file
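
The directory written by `q_model.save(...)` in patch 6 can later be restored with the loader already quoted in `ptq_weight_only/README.md` above. A minimal sketch, assuming the fp16 checkpoint path from the script:

```python
# Sketch: reload the GPTQ checkpoint saved by run_gptj_mlperf_int4.py.
# The load(checkpoint_dir, model) pattern is the one quoted in the README;
# the directory name matches patch 6's q_model.save(...).
import torch
from transformers import GPTJForCausalLM
from neural_compressor.utils.pytorch import load

model = GPTJForCausalLM.from_pretrained("/path/to/finetuned-gptj",
                                        torch_dtype=torch.float16)
quantized_model = load("./gptj-gptq-gs128-calib128-calibration-fp16/", model)
quantized_model.eval()
```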