From 4d4c2957d3f446354e7b1585763eb5d57d265aa1 Mon Sep 17 00:00:00 2001
From: WeiweiZhang1 <109071285+WeiweiZhang1@users.noreply.github.com>
Date: Mon, 3 Jul 2023 14:11:13 +0800
Subject: [PATCH] Adding jit trace to llm example (#1046)

Signed-off-by: Zhang, Weiwei1
---
 .../pruning/eager/requirements.txt          |   2 +
 .../pruning/eager/run_clm_no_trainer.py     | 140 +++++++++++-------
 2 files changed, 88 insertions(+), 54 deletions(-)

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt
index a76f434a167..b3dacc5de81 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt
@@ -4,3 +4,5 @@ sentencepiece
 transformers
 torch
 tqdm
+cupy
+optimum
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py
index ea425de4204..7cb301f88e1 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py
@@ -20,7 +20,10 @@
 https://huggingface.co/models?filter=text-generation
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
-
+from accelerate.utils import set_seed
+set_seed(42)
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
 import argparse
 import json
 import logging
@@ -29,16 +32,13 @@
 import sys
 sys.path.insert(0, './neural-compressor')
 sys.path.insert(0, './')
-
 import random
 from itertools import chain
 from pathlib import Path

 import datasets
 import torch
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+torch.use_deterministic_algorithms(True, warn_only=True)
 from datasets import load_dataset
 from huggingface_hub import Repository, create_repo
 from torch.utils.data import DataLoader
@@ -55,7 +55,6 @@
     SchedulerType,
     default_data_collator,
     get_scheduler,
-    T5ForConditionalGeneration
 )
 from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
 from transformers.utils.versions import require_version
@@ -64,7 +63,7 @@
 from timers import CPUTimer, GPUTimer
 from neural_compressor.compression.pruner import model_slim
 from neural_compressor.compression.pruner import parse_auto_slim_config
-set_seed(42)
+
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.23.0.dev0")

@@ -97,8 +96,6 @@ def evaluate(self, model):
         step = 0
         for input_ids, label, label_indices in tqdm(self.dataloader):
             with torch.no_grad():
-                # if step == 0:
-                #     model = torch.jit.trace(model, input_ids)
                 step += 1
                 # timing
                 if step > warmup_steps: my_timer.__enter__()
@@ -169,6 +166,49 @@ def __iter__(self):
     def __len__(self):
         return self.length

+
+
+class Net(torch.nn.Module):
+    def __init__(self, ori_model):
+        super(Net, self).__init__()
+        self.model = ori_model
+    def forward(self, input_ids, pastkv, mask):
+        return self.model(input_ids=input_ids, attention_mask=mask, past_key_values=pastkv, return_dict=False)
+
+def trace_model(model, tokenizer):
+    from optimum.utils import NormalizedConfigManager
+    normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
+    num_layers = normalized_config.num_layers
+    num_attention_heads = normalized_config.num_attention_heads
+    hidden_size = normalized_config.hidden_size
+    d_k = hidden_size // num_attention_heads
+    model_type = model.config.model_type
+    model = model.cpu()
+    model.eval()
+    prompt = "Once upon a time, there existed a little girl, who liked to have adventures." + \
+        " She wanted to go to places and meet new people, and have fun."
+    init_input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
+    traced_model = None
+    if 'llama' in model_type:
+        input_ids = init_input_ids.clone()
+        attention_mask = torch.ones(len(input_ids)+1)
+        attention_mask[0] = 0
+        input_ids = input_ids[0:1].unsqueeze(0)
+        attention_mask = attention_mask.unsqueeze(0)
+        past_key_value = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
+        if 'llama_13b' in model_type:
+            past_key_value = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])
+        net = model
+        traced_model = torch.jit.trace(net, (input_ids, attention_mask, past_key_value))
+    else:
+        input_ids = init_input_ids.clone().unsqueeze(0)
+        attention_mask = torch.ones(len(input_ids)).unsqueeze(0)
+        past_key_value = tuple([(torch.zeros([1,num_attention_heads,0,d_k]),
+                                 torch.zeros([1,num_attention_heads,0,d_k])) for i in range(num_layers)])
+        net = Net(model)
+        traced_model = torch.jit.trace(net, (input_ids, past_key_value, attention_mask))
+    return traced_model
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
@@ -372,6 +412,10 @@ def parse_args():
         "--auto_slim", action="store_true",
         help="Whether or not to auto slim the model after pruning."
     )
+    parser.add_argument(
+        "--auto_config", action="store_true",
+        help="Whether to automatically generate pruning configs."
+    )
     parser.add_argument(
         "--max_length", type=int, default=2048,
@@ -428,8 +472,8 @@ def main():
         transformers.utils.logging.set_verbosity_error()

     # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
+    # if args.seed is not None:  # Already set at the beginning of the file
+    #     set_seed(args.seed)

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -517,13 +561,25 @@ def main():
             config = CONFIG_MAPPING[args.model_type]()
             logger.warning("You are instantiating a new config instance from scratch.")

-    is_llama = bool("llama" in args.model_name_or_path)
-    is_t5 = bool("t5" in args.model_name_or_path)
+    if args.model_name_or_path:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+            low_cpu_mem_usage=args.low_cpu_mem_usage,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = AutoModelForCausalLM.from_config(config)
+
+    model_name = model.config.model_type
+
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
     elif args.model_name_or_path:
-        if is_llama:
-            tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path)
+        if 'llama' in model_name:
+            from transformers import LlamaTokenizer
+            tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
         else :
             tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
     else:
@@ -532,23 +588,7 @@ def main():
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )

-    if args.model_name_or_path:
-        if is_t5:
-            model = T5ForConditionalGeneration.from_pretrained(
-                args.model_name_or_path,
-                config=config,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.model_name_or_path,
-                from_tf=bool(".ckpt" in args.model_name_or_path),
-                config=config,
-                low_cpu_mem_usage=args.low_cpu_mem_usage,
-            )
-
-    else:
-        logger.info("Training new model from scratch")
-        model = AutoModelForCausalLM.from_config(config)
+

     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
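[Note on the trace_model() helper added above: the non-llama branch wraps the model in Net so that torch.jit.trace sees a fixed positional signature of (input_ids, past_key_values, attention_mask), and feeds an empty per-layer KV cache as the example input. Below is a minimal, self-contained sketch of that same pattern; the "gpt2" checkpoint, the strict=False flag, and the output path are illustrative assumptions, not part of this patch.]

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical small checkpoint for illustration; the patch targets larger LLMs.
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

class Net(torch.nn.Module):
    # Same idea as the Net wrapper in the patch: fix the positional argument
    # order (input_ids, past_key_values, attention_mask) for tracing.
    def __init__(self, ori_model):
        super().__init__()
        self.model = ori_model

    def forward(self, input_ids, pastkv, mask):
        return self.model(input_ids=input_ids, attention_mask=mask,
                          past_key_values=pastkv, return_dict=False)

cfg = model.config
d_k = cfg.hidden_size // cfg.num_attention_heads  # per-head dimension
# Empty KV cache: one (key, value) pair per layer with zero past tokens.
pastkv = tuple((torch.zeros(1, cfg.num_attention_heads, 0, d_k),
                torch.zeros(1, cfg.num_attention_heads, 0, d_k))
               for _ in range(cfg.num_hidden_layers))

ids = tokenizer("Hello world", return_tensors="pt").input_ids
mask = torch.ones_like(ids)
with torch.no_grad():
    traced = torch.jit.trace(Net(model), (ids, pastkv, mask), strict=False)
torch.jit.save(traced, "traced_example.pt")  # illustrative output path

[Tracing through the wrapper pins down the argument order, so the saved graph can later be called positionally without keyword arguments.]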
@@ -753,12 +793,12 @@ def group_texts(examples):
     pruning_start = num_iterations * args.num_train_epochs + 1
     pruning_end = pruning_start

-    if not args.auto_slim:
+    if not args.auto_config:
         pruning_configs=[
             {
                 "pruning_type": "retrain_free",
                 "pruning_scope": "global",
-                "op_names": ["wo"], #for t5
+                "op_names": ['.fc', '.mlp'],
                 "excluded_op_names": [".attn"],
                 "sparsity_decay_type": "exp",
                 "pattern": "channelx1",
@@ -767,21 +807,21 @@ def group_texts(examples):
             }
         ]
     else:
-        # auto slim config
+        # auto config
        pruning_configs=[]
-        auto_slim_configs = parse_auto_slim_config(
+        auto_configs = parse_auto_slim_config(
             model,
             ffn2_sparsity = args.target_sparsity,
             mha_sparsity = 0,
             pruning_scope = "global",
             pruning_type = "retrain_free",
         )
-        pruning_configs += auto_slim_configs
+        pruning_configs += auto_configs

     configs = WeightPruningConfig(
         pruning_configs,
         target_sparsity=args.target_sparsity,
-        # pattern=args.pruning_pattern,
+        pattern=args.pruning_pattern,
         pruning_frequency=frequency,
         start_step=pruning_start,
         end_step=pruning_end,
@@ -844,6 +884,7 @@ def group_texts(examples):
     dataset_eval = raw_datasets["validation"]
     dataset_eval = dataset_eval.shuffle(seed=42)
     evaluator = Evaluator(dataset_eval, tokenizer, model.device, batch_size=args.per_device_eval_batch_size)
+
     def eval_func(model):
         acc, avg_latency = evaluator.evaluate(model)
         return acc, avg_latency
@@ -873,26 +914,17 @@ def eval_func(model):
     logger.info(f"***** Running Evaluation after ffn auto_slim*****")
     accuracy, avg_latency = eval_func(model)
     logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}")
+
+    if args.output_dir is not None:
+        accelerator.wait_for_everyone()
+        traced_model = trace_model(model, tokenizer)
+        logger.info("Saving slimmed jit model")
+        torch.jit.save(traced_model, args.output_dir+"/slimed_jit_model.pt")
+
     if args.with_tracking:
         accelerator.end_training()

-    if args.output_dir is not None and args.auto_slim:
-        accelerator.wait_for_everyone()
-        # unwrapped_model = accelerator.unwrap_model(model)
-        # unwrapped_model.save_pretrained(
-        #     args.output_dir+"/slimed", is_main_process=accelerator.is_main_process, save_function=accelerator.save
-        # )
-        model.to('cpu')
-        torch.save(model, args.output_dir+"/slimed_model.pt")
-        if accelerator.is_main_process:
-            tokenizer.save_pretrained(args.output_dir)
-            if args.push_to_hub:
-                repo.push_to_hub(commit_message="End of auto slim", auto_lfs_prune=True)
-
-        # with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-        #     json.dump({"perplexity": perplexity}, f)
-
 if __name__ == "__main__":
     main()
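[Note on the saved artifact: once this patch lands, the TorchScript file can be reloaded with torch.jit.load and called without any of the Python classes above. A minimal sketch of reloading and running the traced non-llama graph follows; the file name mirrors the patch, while "output_dir", the layer/head shapes, and the token ids are illustrative assumptions that must match the model that was actually traced.]

import torch

# Reload the TorchScript artifact written by this patch; "output_dir" stands in
# for whatever was passed via --output_dir.
traced = torch.jit.load("output_dir/slimed_jit_model.pt")
traced.eval()

# The non-llama graph was traced as (input_ids, past_key_values, attention_mask)
# with an empty per-layer KV cache. The shapes below assume a 12-layer,
# 12-head model with 64-dim heads; adjust them to the traced model's config.
num_layers, num_heads, d_k = 12, 12, 64
pastkv = tuple((torch.zeros(1, num_heads, 0, d_k),
                torch.zeros(1, num_heads, 0, d_k))
               for _ in range(num_layers))
input_ids = torch.tensor([[15496, 995]])  # pre-tokenized example ids
mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = traced(input_ids, pastkv, mask)
logits = outputs[0]                 # (batch, seq_len, vocab_size)
next_token_id = int(logits[0, -1].argmax())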