Adding jit trace to llm example (#1046)
Signed-off-by: Zhang, Weiwei1 <[email protected]>
WeiweiZhang1 authored Jul 3, 2023
1 parent 20f9704 commit 4d4c295
Showing 2 changed files with 88 additions and 54 deletions.
@@ -4,3 +4,5 @@ sentencepiece
transformers
torch
tqdm
cupy
optimum
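# optimum supplies NormalizedConfigManager, used by the new jit-trace helper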
@@ -20,7 +20,10 @@
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

from accelerate.utils import set_seed
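# seed immediately, before anything else can touch global RNG state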
set_seed(42)
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
import argparse
import json
import logging
@@ -29,16 +32,13 @@
import sys
sys.path.insert(0, './neural-compressor')
sys.path.insert(0, './')

import random
from itertools import chain
from pathlib import Path

import datasets
import torch
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
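# warn_only=True: ops that lack a deterministic implementation emit a
# warning instead of raising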
torch.use_deterministic_algorithms(True, warn_only=True)
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
@@ -55,7 +55,6 @@
    SchedulerType,
    default_data_collator,
    get_scheduler,
    T5ForConditionalGeneration
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
@@ -64,7 +63,7 @@
from timers import CPUTimer, GPUTimer
from neural_compressor.compression.pruner import model_slim
from neural_compressor.compression.pruner import parse_auto_slim_config
set_seed(42)


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.23.0.dev0")
@@ -97,8 +96,6 @@ def evaluate(self, model):
        step = 0
        for input_ids, label, label_indices in tqdm(self.dataloader):
            with torch.no_grad():
                # if step == 0:
                #     model = torch.jit.trace(model, input_ids)
                step += 1
                # timing
                if step > warmup_steps: my_timer.__enter__()
@@ -169,6 +166,49 @@ def __iter__(self):

    def __len__(self):
        return self.length


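# Thin wrapper giving torch.jit.trace a fixed positional signature
# (input_ids, past_key_values, attention_mask) and tuple outputs
# (return_dict=False), which tracing requires.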
class Net(torch.nn.Module):
    def __init__(self, ori_model):
        super().__init__()
        self.model = ori_model

    def forward(self, input_ids, pastkv, mask):
        return self.model(input_ids=input_ids, attention_mask=mask, past_key_values=pastkv, return_dict=False)

def trace_model(model, tokenizer):
    from optimum.utils import NormalizedConfigManager
    normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
    num_layers = normalized_config.num_layers
    num_attention_heads = normalized_config.num_attention_heads
    hidden_size = normalized_config.hidden_size
    d_k = hidden_size // num_attention_heads
    model_type = model.config.model_type
    model = model.cpu()
    model.eval()
    prompt = "Once upon a time, there existed a little girl, who liked to have adventures." + \
             " She wanted to go to places and meet new people, and have fun."
    init_input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
    traced_model = None
    if 'llama' in model_type:
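        # the bare model is traced here, so the example inputs bind
        # positionally to its forward(); shapes model one new token plus a
        # kv-cache pre-filled for the prompt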
        input_ids = init_input_ids.clone()
        attention_mask = torch.ones(len(input_ids) + 1)
        attention_mask[0] = 0
        input_ids = input_ids[0:1].unsqueeze(0)
        attention_mask = attention_mask.unsqueeze(0)
        # kv-cache dummies have shape [batch, num_heads, past_len, head_dim];
        # deriving them from the config covers both the 7B (32-head) and 13B
        # (40-head) variants, whereas the old `'llama_13b' in model_type`
        # check could never match since model_type is always just "llama"
        past_key_value = tuple([(torch.zeros([1, num_attention_heads, len(init_input_ids), d_k]),
                                 torch.zeros([1, num_attention_heads, len(init_input_ids), d_k]))
                                for _ in range(num_layers)])
        net = model
        traced_model = torch.jit.trace(net, (input_ids, attention_mask, past_key_value))
    else:
        input_ids = init_input_ids.clone().unsqueeze(0)
        # len() on the 2-D input_ids would return the batch dim (1), so build
        # the mask from the tensor's full shape instead
        attention_mask = torch.ones_like(input_ids)
        past_key_value = tuple([(torch.zeros([1, num_attention_heads, 0, d_k]),
                                 torch.zeros([1, num_attention_heads, 0, d_k]))
                                for _ in range(num_layers)])
        net = Net(model)
        traced_model = torch.jit.trace(net, (input_ids, past_key_value, attention_mask))
    return traced_model
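
For context, a minimal usage sketch (illustrative, not part of the commit) of
saving and reloading the traced module for inference:

    traced = trace_model(model, tokenizer)           # a torch.jit ScriptModule
    torch.jit.save(traced, "slimed_jit_model.pt")    # serialized TorchScript
    reloaded = torch.jit.load("slimed_jit_model.pt") # no Python class needed
    reloaded.eval()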


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
@@ -372,6 +412,10 @@ def parse_args():
"--auto_slim", action="store_true",
help="Whether or not to auto slim the model after pruning."
)
    parser.add_argument(
        "--auto_config", action="store_true",
        help="Whether to automatically generate pruning configs."
    )
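    # --auto_config delegates per-layer pruning configuration to
    # parse_auto_slim_config() in main(), instead of the hand-written list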
    parser.add_argument(
        "--max_length",
        type=int, default=2048,
@@ -428,8 +472,8 @@ def main():
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
    # if args.seed is not None: # Already set at the beginning of the file
    #     set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
@@ -517,13 +561,25 @@ def main():
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    is_llama = bool("llama" in args.model_name_or_path)
    is_t5 = bool("t5" in args.model_name_or_path)
    if args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config)

    model_name = model.config.model_type

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        if is_llama:
            tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path)
        if 'llama' in model_name:
            from transformers import LlamaTokenizer
            tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
        else:
            tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    else:
@@ -532,23 +588,7 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

if args.model_name_or_path:
if is_t5:
model = T5ForConditionalGeneration.from_pretrained(
args.model_name_or_path,
config=config,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)

else:
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config)


# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
@@ -753,12 +793,12 @@ def group_texts(examples):
    pruning_start = num_iterations * args.num_train_epochs + 1
    pruning_end = pruning_start

    if not args.auto_slim:
    if not args.auto_config:
        pruning_configs = [
            {
                "pruning_type": "retrain_free",
                "pruning_scope": "global",
                "op_names": ["wo"],  # for t5
                "op_names": ['.fc', '.mlp'],
                "excluded_op_names": [".attn"],
                "sparsity_decay_type": "exp",
                "pattern": "channelx1",
@@ -767,21 +807,21 @@
            }
        ]
    else:
        # auto slim config
        # auto config
        pruning_configs = []
        auto_slim_configs = parse_auto_slim_config(
        auto_configs = parse_auto_slim_config(
            model,
            ffn2_sparsity=args.target_sparsity,
            mha_sparsity=0,
            pruning_scope="global",
            pruning_type="retrain_free",
        )
        pruning_configs += auto_slim_configs
        pruning_configs += auto_configs
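    # either branch yields a list of per-layer config dicts, which
    # WeightPruningConfig below merges with the global settings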

    configs = WeightPruningConfig(
        pruning_configs,
        target_sparsity=args.target_sparsity,
        # pattern=args.pruning_pattern,
        pattern=args.pruning_pattern,
        pruning_frequency=frequency,
        start_step=pruning_start,
        end_step=pruning_end,
@@ -844,6 +884,7 @@ def group_texts(examples):
    dataset_eval = raw_datasets["validation"]
    dataset_eval = dataset_eval.shuffle(seed=42)
    evaluator = Evaluator(dataset_eval, tokenizer, model.device, batch_size=args.per_device_eval_batch_size)

    def eval_func(model):
        acc, avg_latency = evaluator.evaluate(model)
        return acc, avg_latency
@@ -873,26 +914,17 @@ def eval_func(model):
logger.info(f"***** Running Evaluation after ffn auto_slim*****")
accuracy, avg_latency = eval_func(model)
logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}")

if args.output_dir is not None:
accelerator.wait_for_everyone()
traced_model = trace_model(model, tokenizer)
logger.info(f"Save silmed jit model")
torch.jit.save(traced_model, args.output_dir+"/slimed_jit_model.pt")


if args.with_tracking:
accelerator.end_training()

if args.output_dir is not None and args.auto_slim:
accelerator.wait_for_everyone()
# unwrapped_model = accelerator.unwrap_model(model)
# unwrapped_model.save_pretrained(
# args.output_dir+"/slimed", is_main_process=accelerator.is_main_process, save_function=accelerator.save
# )
model.to('cpu')
torch.save(model, args.output_dir+"/slimed_model.pt")
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of auto slim", auto_lfs_prune=True)

# with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
# json.dump({"perplexity": perplexity}, f)


if __name__ == "__main__":
main()
