From 4ea5d7b98641c9266f22d284899ea80c2c6c4b9c Mon Sep 17 00:00:00 2001 From: "Lv, Kaokao" Date: Thu, 31 Aug 2023 10:58:47 +0800 Subject: [PATCH 1/7] support multi-cards magnitude pruning Signed-off-by: Lv, Kaokao --- .../pruning/eager/ds_config.json | 40 + .../language-modeling/pruning/eager/run.sh | 13 + .../eager/run_clm_no_trainer_deepspeed.py | 812 +++++++++++++++++ .../eager/run_clm_no_trainer_pruning.py | 839 ++++++++++++++++++ .../language-modeling/pruning/eager/run_ds.sh | 21 + 5 files changed, 1725 insertions(+) create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/ds_config.json create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run.sh create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_deepspeed.py create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_ds.sh diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/ds_config.json b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/ds_config.json new file mode 100644 index 00000000000..2d2ae38da78 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/ds_config.json @@ -0,0 +1,40 @@ +{ + "train_batch_size": 64, + "train_micro_batch_size_per_gpu": 8, + "gradient_accumulation_steps": 4, + "fp16": { + "enabled": true, + "min_loss_scale": 1, + "opt_level": "O2" + }, + "zero_optimization": { + "stage": 2, + "offload_param": { + "device": "cpu" + }, + "offload_optimizer": { + "device": "cpu" + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "torch_adam": true, + "adam_w_mode": true + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": 0.0, + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto", + "warmup_type": "cosine" + } + } +} diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run.sh new file mode 100644 index 00000000000..ae50b88f9f4 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run.sh @@ -0,0 +1,13 @@ +export CUDA_VISIBLE_DEVICES=0 +python run_clm_no_trainer_pruning.py \ + --dataset_name ./pile-10k \ + --model_name_or_path /models/opt-125m/ \ + --block_size 128 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 4 \ + --output_dir /tmp/test-clm \ + --do_prune \ + --num_train_epochs 10 \ + --target_sparsity 0.8 \ + --pruning_pattern "4x1" \ + --pruning_frequency 1000 diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_deepspeed.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_deepspeed.py new file mode 100644 index 00000000000..996a8163341 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_deepspeed.py @@ -0,0 +1,812 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. +import sys +sys.path.append("/data3/lkk/neural-compressor") +import argparse +import json +import logging +import math +import os +import random +from itertools import chain +from pathlib import Path + +import datasets +import torch +import transformers +from datasets import load_dataset +from huggingface_hub import Repository +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import get_full_repo_name +from transformers.utils.versions import require_version + +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import DummyOptim, DummyScheduler, set_seed +from neural_compressor.training import prepare_compression +from neural_compressor.training import WeightPruningConfig + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 
+ ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." + ), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 
+ ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + # New Code # + # Whether to load the best model at the end of training + parser.add_argument( + "--load_best_model", + action="store_true", + help="Whether to load the best model at the end of training", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--teacher_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False + ) + parser.add_argument( + "--distill_loss_weight", + type=float, + default=0.0, + help="distiller loss weight" + ) + + parser.add_argument( + "--cooldown_epochs", + type=int, default=0, + help="Cooling epochs after pruning." + ) + parser.add_argument( + "--do_prune", action="store_true", + help="Whether or not to prune the model" + ) + # parser.add_argument( + # "--keep_conf", action="store_true", + # help="Whether or not to keep the prune config infos" + # ) + parser.add_argument( + "--pruning_pattern", + type=str, default="4x1", + help="pruning pattern type, we support NxM and N:M." + ) + parser.add_argument( + "--target_sparsity", + type=float, default=0.8, + help="Target sparsity of the model." + ) + parser.add_argument( + "--pruning_frequency", + type=int, default=-1, + help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." + ) + parser.add_argument( + "--warm_epochs", + type=int, + default=0, + help="Number of epochs the network not be purned" + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 
+ + return args + + +# New Code # +def evaluate(args, model, eval_dataloader, accelerator, eval_dataset): + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch) + + loss = outputs.loss + losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) + + losses = torch.cat(losses) + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + return perplexity, eval_loss + + +def main(): + args = parse_args() + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + + # when using DeepSpeed, the `gradient_accumulation_steps` is properly set from the DeepSpeed plugin/config + # or from `accelerate launch` via `--gradient_accumulation_steps` else + # defaulting to the passed `args.gradient_accumulation_steps` + accelerator = ( + Accelerator( + log_with=args.report_to, + project_dir=args.output_dir, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + if args.with_tracking + else Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + ) + from accelerate.state import AcceleratorState + accelerator.print(f"{AcceleratorState()}") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
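+        # Note: if the loaded dataset (e.g. a local corpus such as ./pile-10k in run.sh) ships
+        # without a dedicated validation split, the first `validation_split_percentage` percent
+        # of the train split (5% by default) is carved out below for perplexity evaluation.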
+ raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config) + + model.resize_token_embeddings(len(tokenizer)) + + if args.distill_loss_weight > 0: + teacher_path = args.teacher_model_name_or_path + if teacher_path is None: + teacher_path = args.model_name_or_path + teacher_model = AutoModelForQuestionAnswering.from_pretrained( + teacher_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + + # Preprocessing the datasets. + # First we tokenize all the texts. 
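+    # The tokenized corpus is then packed into fixed-length chunks of `block_size` tokens by
+    # `group_texts` below; labels are a plain copy of input_ids, since causal-LM models in
+    # transformers shift the labels internally when computing the loss.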
+ column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with accelerator.main_process_first(): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with accelerator.main_process_first(): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. 
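+    # Parameters whose names contain any of the substrings below (biases and LayerNorm weights)
+    # fall into the second group and receive no weight decay.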
+ no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + # New Code # + # Creates Dummy Optimizer if `optimizer` was specified in the config file else creates Adam Optimizer + optimizer_cls = ( + torch.optim.AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) + overrode_max_train_steps = False + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # New Code # + # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + else: + lr_scheduler = DummyScheduler( + optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("clm_no_trainer", experiment_config) + + # Train! 
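+    # Effective batch size per optimizer step. For example, the per-device batch of 8 and
+    # 4 accumulation steps from ds_config.json on two GPUs give 8 * 4 * 2 = 64, matching the
+    # config's train_batch_size.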
+ total_batch_size = ( + args.per_device_train_batch_size * accelerator.num_processes * accelerator.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {accelerator.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + best_metric = None + best_metric_checkpoint = None + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + accelerator.load_state(args.resume_from_checkpoint) + accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") + path = os.path.basename(args.resume_from_checkpoint) + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + resume_step = int(training_difference.replace("step_", "")) + starting_epoch = resume_step // num_update_steps_per_epoch + resume_step -= starting_epoch * num_update_steps_per_epoch + completed_steps = resume_step + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + # Pruning preparation + num_iterations = len(train_dataset) / total_batch_size + num_warm = int(args.warm_epochs * num_iterations) + args.num_warmup_steps + total_iterations = int(num_iterations * (args.num_train_epochs - args.cooldown_epochs)) + frequency = int((total_iterations - num_warm + 1) / 40) if args.pruning_frequency == -1 \ + else args.pruning_frequency + pruning_start = num_warm + pruning_end = total_iterations + if not args.do_prune: + pruning_start = num_iterations * args.num_train_epochs + 1 + pruning_end = pruning_start + pruning_configs=[ + { + "pruning_type": "magnitude", + "pruning_scope": "global", + "sparsity_decay_type": "exp", + "excluded_op_names": ["pooler"], + "pruning_op_types": ["Linear"], + "max_sparsity_ratio_per_op": 0.98 + } + ] + configs = WeightPruningConfig( + pruning_configs, + target_sparsity=args.target_sparsity, + pattern=args.pruning_pattern, + pruning_frequency=frequency, + start_step=pruning_start, + end_step=pruning_end + ) + + compression_manager = prepare_compression(model=model, confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking: + total_loss = 0 + + # skip new `skip_first_batches` to skip the batches when resuming from ckpt + if args.resume_from_checkpoint: + train_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=resume_step) + for step, batch in enumerate(train_dataloader): + # In particular, DeepSpeed handles `gradient_accumulation` via `DeepSpeedEngine`. + # Below, we use `accelerator.accumulate` if the user + # wants to switch to other approaches such as plain DDP, PyTorch FSDP ... 
+ # This avoids having to change any code as things are all handled across different distributed setups. + with accelerator.accumulate(model): + compression_manager.callbacks.on_step_begin(step) + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + compression_manager.callbacks.on_before_optimizer_step() + optimizer.step() + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + compression_manager.callbacks.on_after_optimizer_step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + # We keep track of the loss at each epoch + if args.with_tracking: + step_loss = accelerator.reduce(loss.detach().clone()).item() + total_loss += step_loss + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + if completed_steps >= args.max_train_steps: + break + + perplexity, eval_loss = evaluate(args, model, eval_dataloader, accelerator, eval_dataset) + logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") + + if args.with_tracking: + accelerator.log( + { + "perplexity": perplexity, + "eval_loss": eval_loss, + "train_loss": total_loss / len(train_dataloader), + "epoch": epoch, + "step": completed_steps, + }, + step=completed_steps, + ) + + if isinstance(checkpointing_steps, str) and checkpointing_steps == "epoch": + accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}")) + + # New Code # + # Tracks the best checkpoint and best metric + if best_metric is None or best_metric > perplexity: + best_metric = perplexity + best_metric_checkpoint = os.path.join(args.output_dir, "best_checkpoint") + accelerator.save_state(best_metric_checkpoint) + accelerator.print(f"New best metric: {best_metric} at epoch {epoch}") + accelerator.print(f"best_metric_checkpoint: {best_metric_checkpoint}") + + compression_manager.callbacks.on_train_end() + + # New Code # + # Loads the best checkpoint after the training is finished + if args.load_best_model: + accelerator.load_state(best_metric_checkpoint) + + # New Code # + # Evaluates using the best checkpoint + perplexity, eval_loss = evaluate(args, model, eval_dataloader, accelerator, eval_dataset) + logger.info(f"Best model metrics: perplexity: {perplexity} eval_loss: {eval_loss}") + if perplexity != best_metric: + raise AssertionError( + f"Best metric {best_metric} does not match the metric {perplexity} of the loaded best model." 
+ ) + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump({"perplexity": perplexity}, f) + + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py new file mode 100644 index 00000000000..3f1b021b61e --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py @@ -0,0 +1,839 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import sys +sys.path.append("/data3/lkk/neural-compressor") +import argparse +import json +import logging +import math +import os +import random +from itertools import chain +from pathlib import Path + +import datasets +import torch +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import Repository, create_repo +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version +from neural_compressor.training import prepare_compression +from neural_compressor.training import WeightPruningConfig + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
+# check_min_version("4.33.0.dev0") + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data." + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." + ), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--trust_remote_code", + type=bool, + default=False, + help=( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" + "should only be set to `True` for repositories you trust and in which you have read the code, as it will" + "execute code present on the Hub on your local machine." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--low_cpu_mem_usage", + action="store_true", + help=( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "If passed, LLM loading time and RAM consumption will be benefited." + ), + ) + parser.add_argument( + "--teacher_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False + ) + parser.add_argument( + "--distill_loss_weight", + type=float, + default=0.0, + help="distiller loss weight" + ) + + parser.add_argument( + "--cooldown_epochs", + type=int, default=0, + help="Cooling epochs after pruning." 
+ ) + parser.add_argument( + "--do_prune", action="store_true", + help="Whether or not to prune the model" + ) + # parser.add_argument( + # "--keep_conf", action="store_true", + # help="Whether or not to keep the prune config infos" + # ) + parser.add_argument( + "--pruning_pattern", + type=str, default="4x1", + help="pruning pattern type, we support NxM and N:M." + ) + parser.add_argument( + "--target_sparsity", + type=float, default=0.8, + help="Target sparsity of the model." + ) + parser.add_argument( + "--pruning_frequency", + type=int, default=-1, + help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." + ) + parser.add_argument( + "--warm_epochs", + type=int, + default=0, + help="Number of epochs the network not be purned" + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
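+    # accelerate's set_seed seeds Python, NumPy and PyTorch (including CUDA) in each process,
+    # so data shuffling and the resulting pruning trajectory are reproducible across runs.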
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + # Retrieve of infer repo_name + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + # Create repo and retrieve repo_id + repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id + # Clone repo locally + repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ if args.config_name: + config = AutoConfig.from_pretrained( + args.config_name, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained( + args.model_name_or_path, + trust_remote_code=args.trust_remote_code, + ) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code + ) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + low_cpu_mem_usage=args.low_cpu_mem_usage, + trust_remote_code=args.trust_remote_code, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config, trust_remote_code=args.trust_remote_code) + + if args.distill_loss_weight > 0: + teacher_path = args.teacher_model_name_or_path + if teacher_path is None: + teacher_path = args.model_name_or_path + teacher_model = AutoModelForQuestionAnswering.from_pretrained( + teacher_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with accelerator.main_process_first(): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" + " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" + " override this default with `--block_size xxx`." + ) + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. 
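+        # For example, with block_size=128 a batch whose concatenated length is 1,000 tokens
+        # is truncated to 7 * 128 = 896 tokens and split into 7 chunks; the remainder is dropped.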
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with accelerator.main_process_first(): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + if args.do_prune: + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, betas=[0.9, 0.9]) + else: + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. 
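+    # When distillation is enabled (--distill_loss_weight > 0), the teacher model is prepared
+    # together with the student so both are placed on the same devices; the teacher stays in
+    # eval() mode and is only queried under torch.no_grad() inside the training loop.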
+ if args.distill_loss_weight > 0: + teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + teacher_model.eval() + else: + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("clm_no_trainer", experiment_config) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + checkpoint_path = args.resume_from_checkpoint + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + checkpoint_path = path + path = os.path.basename(checkpoint_path) + + accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") + accelerator.load_state(path) + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + resume_step -= starting_epoch * len(train_dataloader) + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + # Pruning preparation + num_iterations = len(train_dataset) / total_batch_size + num_warm = int(args.warm_epochs * num_iterations) + args.num_warmup_steps + total_iterations = int(num_iterations * (args.num_train_epochs - args.cooldown_epochs)) + frequency = int((total_iterations - num_warm + 1) / 40) if args.pruning_frequency == -1 \ + else args.pruning_frequency + pruning_start = num_warm + pruning_end = total_iterations + if not args.do_prune: + pruning_start = num_iterations * args.num_train_epochs + 1 + pruning_end = pruning_start + pruning_configs=[ + { + "pruning_type": "magnitude", + "pruning_scope": "global", + "sparsity_decay_type": "exp", + "excluded_op_names": ["pooler"], + "pruning_op_types": ["Linear"], + "max_sparsity_ratio_per_op": 0.98 + } + ] + + configs = WeightPruningConfig( + pruning_configs, + target_sparsity=args.target_sparsity, + pattern=args.pruning_pattern, + pruning_frequency=frequency, + start_step=pruning_start, + end_step=pruning_end + ) + + compression_manager = prepare_compression(model=model, confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking: + total_loss = 0 + if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): + with accelerator.accumulate(model): + compression_manager.callbacks.on_step_begin(step) + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + + if args.distill_loss_weight > 0: + distill_loss_weight = args.distill_loss_weight + with torch.no_grad(): + 
teacher_outputs = teacher_model(**batch) + loss = (distill_loss_weight) / 2 * get_loss_one_logit(outputs['start_logits'], + teacher_outputs['start_logits']) \ + + (distill_loss_weight) / 2 * get_loss_one_logit(outputs['end_logits'], + teacher_outputs['end_logits']) + + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + compression_manager.callbacks.on_before_optimizer_step() + optimizer.step() + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + compression_manager.callbacks.on_after_optimizer_step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps }" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch) + + loss = outputs.loss + losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) + + losses = torch.cat(losses) + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") + + if args.with_tracking: + accelerator.log( + { + "perplexity": perplexity, + "eval_loss": eval_loss, + "train_loss": total_loss.item() / len(train_dataloader), + "epoch": epoch, + "step": completed_steps, + }, + step=completed_steps, + ) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + compression_manager.callbacks.on_train_end() + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + print(type(model)) + unwrapped_model = accelerator.unwrap_model(model) + print(type(unwrapped_model)) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump({"perplexity": perplexity}, f) + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_ds.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_ds.sh new file mode 100644 index 00000000000..a8843077e76 --- /dev/null +++ 
b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_ds.sh @@ -0,0 +1,21 @@ +export CUDA_VISIBLE_DEVICES=5,6 +# python run_clm_no_trainer_pruning.py \ +# python -m torch.distributed.launch \ +# --nproc_per_node=1 \ +# run_clm_no_trainer_pruning.py \ + +accelerate launch --deepspeed_config_file ds_config.json --mixed_precision fp16 \ + run_clm_no_trainer_deepspeed.py \ + --dataset_name ./pile-10k \ + --model_name_or_path /models/opt-125m/ \ + --block_size 128 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 4 \ + --output_dir /tmp/test-clm \ + --num_train_epochs 1 \ + --max_train_steps 100 \ + --do_prune \ + --num_train_epochs 10 \ + --target_sparsity 0.8 \ + --pruning_pattern "4x1" \ + --pruning_frequency 1000 From 8494a20f00fe213116b31198ef221b5fa8c256a5 Mon Sep 17 00:00:00 2001 From: Xinyu Ye Date: Thu, 31 Aug 2023 14:02:10 +0800 Subject: [PATCH 2/7] added multi-node and single node cpu training support for LLM pruning. Signed-off-by: Xinyu Ye --- .../pruning/eager/run_clm_no_trainer_pruning.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py index 3f1b021b61e..745e9e5d0ff 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py @@ -290,6 +290,11 @@ def parse_args(): default=0, help="Number of epochs the network not be purned" ) + parser.add_argument( + "--cpu", + action="store_true", + help="Whether use cpu for training." + ) args = parser.parse_args() @@ -326,7 +331,7 @@ def main(): accelerator_log_kwargs["log_with"] = args.report_to accelerator_log_kwargs["project_dir"] = args.output_dir - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + accelerator = Accelerator(cpu=args.cpu, gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) # Make one log on every process with the configuration for debugging. logging.basicConfig( From 14e9440b6abec0e0c8b4c89bbbaf16bfc456f7ca Mon Sep 17 00:00:00 2001 From: "Lv, Kaokao" Date: Thu, 31 Aug 2023 19:38:46 +0800 Subject: [PATCH 3/7] change code directory and add doc. 
Signed-off-by: Lv, Kaokao --- .../pruning/magnitude/README.md | 90 +++++++++++++++++++ .../config/zero_stage2_config.json} | 0 .../pruning/magnitude/requirements.txt | 8 ++ .../pruning/{eager => magnitude}/run.sh | 0 .../run_clm_no_trainer.py} | 0 .../run_clm_no_trainer_deepspeed.py | 0 .../pruning/{eager => magnitude}/run_ds.sh | 0 7 files changed, 98 insertions(+) create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{eager/ds_config.json => magnitude/config/zero_stage2_config.json} (100%) create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/requirements.txt rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{eager => magnitude}/run.sh (100%) rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{eager/run_clm_no_trainer_pruning.py => magnitude/run_clm_no_trainer.py} (100%) rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{eager => magnitude}/run_clm_no_trainer_deepspeed.py (100%) rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{eager => magnitude}/run_ds.sh (100%) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md new file mode 100644 index 00000000000..ba5e5f270dd --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md @@ -0,0 +1,90 @@ +Step-by-Step +============ + +# single GPU + +``` +bash run.sh +``` + +# multi GPU + +we use `accelerate` and `deepspeed ZeRO Stage-2` to conduct weight magnitude pruning + +### Accelerate DeepSpeed Plugin + +On your machine(s) just run: +``` +accelerate config +``` + +and answer the questions asked. It will ask whether you want to use a config file for DeepSpeed to which you should answer no. Then answer the following questions to generate a basic DeepSpeed config. 
This will generate a config file that will be used automatically to properly set the default options when doing + +For instance, + +``` +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_config_file: config/zero_stage2_config.json + zero3_init_flag: true +distributed_type: DEEPSPEED +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: null +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 2 +use_cpu: false +``` +with the contents of `config/zero_stage2_config.json` being: + +``` +{ + "train_batch_size": 64, + "train_micro_batch_size_per_gpu": 8, + "gradient_accumulation_steps": 4, + "fp16": { + "enabled": true, + "min_loss_scale": 1, + "opt_level": "O2" + }, + "zero_optimization": { + "stage": 2, + "offload_param": { + "device": "cpu" + }, + "offload_optimizer": { + "device": "cpu" + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "torch_adam": true, + "adam_w_mode": true + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": 0.0, + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto", + "warmup_type": "cosine" + } + } +} +``` + +### pruning + +``` +bash run_ds.sh +``` diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/ds_config.json b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/config/zero_stage2_config.json similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/ds_config.json rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/config/zero_stage2_config.json diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/requirements.txt new file mode 100644 index 00000000000..8e3abe886e9 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/requirements.txt @@ -0,0 +1,8 @@ +accelerate +datasets +sentencepiece +transformers +torch +tqdm +optimum +einops diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run.sh rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_pruning.py rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_deepspeed.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer_deepspeed.py rename to 
examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_ds.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_ds.sh rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh From ced74e11a8b3270484f551fb2ce9502c46925f58 Mon Sep 17 00:00:00 2001 From: "Lv, Kaokao" Date: Fri, 1 Sep 2023 10:28:30 +0800 Subject: [PATCH 4/7] update code. Signed-off-by: Lv, Kaokao --- .../language-modeling/pruning/magnitude/run.sh | 7 ++++--- .../pruning/magnitude/run_clm_no_trainer.py | 10 ++++++++-- .../pruning/magnitude/run_clm_no_trainer_deepspeed.py | 1 - .../language-modeling/pruning/magnitude/run_ds.sh | 11 ++++------- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh index ae50b88f9f4..624df43899a 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh @@ -1,13 +1,14 @@ export CUDA_VISIBLE_DEVICES=0 python run_clm_no_trainer_pruning.py \ - --dataset_name ./pile-10k \ - --model_name_or_path /models/opt-125m/ \ + --dataset_name NeelNanda/pile-10k \ + --model_name_or_path facebook/opt-125m \ --block_size 128 \ --per_device_train_batch_size 8 \ --gradient_accumulation_steps 4 \ - --output_dir /tmp/test-clm \ + --output_dir ./test-clm \ --do_prune \ --num_train_epochs 10 \ + --pruning_type "magnitude" \ --target_sparsity 0.8 \ --pruning_pattern "4x1" \ --pruning_frequency 1000 diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py index 745e9e5d0ff..0d357fe6a86 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py @@ -23,7 +23,6 @@ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. import sys -sys.path.append("/data3/lkk/neural-compressor") import argparse import json import logging @@ -284,6 +283,13 @@ def parse_args(): type=int, default=-1, help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." 
) + parser.add_argument( + "--pruning_type", + type=str, + default="magnitude", + help="pruning criteria to use.", + choices=["magnitude", "snip", "snip_momentum"], + ) parser.add_argument( "--warm_epochs", type=int, @@ -698,7 +704,7 @@ def group_texts(examples): pruning_end = pruning_start pruning_configs=[ { - "pruning_type": "magnitude", + "pruning_type": args.pruning_type, "pruning_scope": "global", "sparsity_decay_type": "exp", "excluded_op_names": ["pooler"], diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py index 996a8163341..cae7b946983 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py @@ -22,7 +22,6 @@ """ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. import sys -sys.path.append("/data3/lkk/neural-compressor") import argparse import json import logging diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh index a8843077e76..7c0bd0eca58 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh @@ -1,21 +1,18 @@ export CUDA_VISIBLE_DEVICES=5,6 -# python run_clm_no_trainer_pruning.py \ -# python -m torch.distributed.launch \ -# --nproc_per_node=1 \ -# run_clm_no_trainer_pruning.py \ accelerate launch --deepspeed_config_file ds_config.json --mixed_precision fp16 \ run_clm_no_trainer_deepspeed.py \ - --dataset_name ./pile-10k \ - --model_name_or_path /models/opt-125m/ \ + --dataset_name NeelNanda/pile-10k \ + --model_name_or_path facebook/opt-125m \ --block_size 128 \ --per_device_train_batch_size 8 \ --gradient_accumulation_steps 4 \ - --output_dir /tmp/test-clm \ + --output_dir ./test-clm \ --num_train_epochs 1 \ --max_train_steps 100 \ --do_prune \ --num_train_epochs 10 \ + --pruning_type "magnitude" \ --target_sparsity 0.8 \ --pruning_pattern "4x1" \ --pruning_frequency 1000 From b7f47dc6d55120649cec5f12e54573802ca9b03a Mon Sep 17 00:00:00 2001 From: "Lv, Kaokao" Date: Fri, 1 Sep 2023 10:35:50 +0800 Subject: [PATCH 5/7] add pruning type args. Signed-off-by: Lv, Kaokao --- .../pruning/magnitude/run_clm_no_trainer_deepspeed.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py index cae7b946983..c8c86fdd971 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py @@ -267,6 +267,13 @@ def parse_args(): type=int, default=-1, help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." 
) + parser.add_argument( + "--pruning_type", + type=str, + default="magnitude", + help="pruning criteria to use.", + choices=["magnitude", "snip", "snip_momentum"], + ) parser.add_argument( "--warm_epochs", type=int, @@ -680,7 +687,7 @@ def group_texts(examples): pruning_end = pruning_start pruning_configs=[ { - "pruning_type": "magnitude", + "pruning_type": args.pruning_type, "pruning_scope": "global", "sparsity_decay_type": "exp", "excluded_op_names": ["pooler"], From bfe17ab555467c35e8e4a1a085359338e762c264 Mon Sep 17 00:00:00 2001 From: "Lv, Kaokao" Date: Fri, 1 Sep 2023 11:11:14 +0800 Subject: [PATCH 6/7] update code. Signed-off-by: Lv, Kaokao --- .../pruning/magnitude/README.md | 21 +++- .../pruning/magnitude/run.sh | 101 ++++++++++++++--- .../pruning/magnitude/run_ds.sh | 106 +++++++++++++++--- 3 files changed, 194 insertions(+), 34 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md index ba5e5f270dd..b0d049f9547 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md @@ -4,7 +4,15 @@ Step-by-Step # single GPU ``` -bash run.sh +export CUDA_VISIBLE_DEVICES=0 +bash run.sh \ + --model_name_or_path=facebook/opt-125m \ + --dataset_name=NeelNanda/pile-10k \ + --block_size=128 \ + --output_dir=./test-clm \ + --pruning_type=magnitude \ + --pruning_pattern=4x1 \ + --pruning_frequency=1000 ``` # multi GPU @@ -86,5 +94,14 @@ with the contents of `config/zero_stage2_config.json` being: ### pruning ``` -bash run_ds.sh +# 2 gpu cards example +export CUDA_VISIBLE_DEVICES=0,1 +bash run_ds.sh \ + --model_name_or_path=facebook/opt-125m \ + --dataset_name=NeelNanda/pile-10k \ + --block_size=128 \ + --output_dir=./test-clm \ + --pruning_type=magnitude \ + --pruning_pattern=4x1 \ + --pruning_frequency=1000 ``` diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh index 624df43899a..7b70df66f57 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh @@ -1,14 +1,87 @@ -export CUDA_VISIBLE_DEVICES=0 -python run_clm_no_trainer_pruning.py \ - --dataset_name NeelNanda/pile-10k \ - --model_name_or_path facebook/opt-125m \ - --block_size 128 \ - --per_device_train_batch_size 8 \ - --gradient_accumulation_steps 4 \ - --output_dir ./test-clm \ - --do_prune \ - --num_train_epochs 10 \ - --pruning_type "magnitude" \ - --target_sparsity 0.8 \ - --pruning_pattern "4x1" \ - --pruning_frequency 1000 +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_tuning + +} + +# init params +function init_params { + dataset_name="NeelNanda/pile-10k" + model_name_or_path="facebook/opt-125m" + output_dir="./test-clm" + per_device_train_batch_size=8 + block_size=128 + gradient_accumulation_steps=4 + num_train_epochs=3 + target_sparsity=0.8 + pruning_type="magnitude" + pruning_pattern="4x1" + pruning_frequency=1000 + for var in "$@" + do + case $var in + --dataset_name=*) + dataset_name=$(echo $var |cut -f2 -d=) + ;; + --model_name_or_path=*) + model_name_or_path=$(echo $var |cut -f2 -d=) + ;; + --output_dir=*) + output_dir=$(echo $var |cut -f2 -d=) + ;; + --per_device_train_batch_size=*) + 
per_device_train_batch_size=$(echo $var |cut -f2 -d=) + ;; + --block_size=*) + block_size=$(echo $var |cut -f2 -d=) + ;; + --gradient_accumulation_steps=*) + gradient_accumulation_steps=$(echo $var |cut -f2 -d=) + ;; + --num_train_epochs=*) + num_train_epochs=$(echo $var |cut -f2 -d=) + ;; + --target_sparsity=*) + target_sparsity=$(echo $var |cut -f2 -d=) + ;; + --pruning_type=*) + pruning_type=$(echo $var |cut -f2 -d=) + ;; + --pruning_pattern=*) + pruning_pattern=$(echo $var |cut -f2 -d=) + ;; + --pruning_frequency=*) + pruning_frequency=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + ;; + esac + done + +} + +# run_tuning +function run_tuning { + python run_clm_no_trainer_pruning.py \ + --dataset_name $dataset_name \ + --model_name_or_path $model_name_or_path \ + --block_size $block_size \ + --per_device_train_batch_size $per_device_train_batch_size \ + --gradient_accumulation_steps $gradient_accumulation_steps \ + --output_dir $output_dir \ + --do_prune \ + --pruning_pattern $pruning_type \ + --num_train_epochs $num_train_epochs \ + --target_sparsity $target_sparsity \ + --pruning_pattern $pruning_pattern \ + --pruning_frequency $pruning_frequency + +} + +main "$@" diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh index 7c0bd0eca58..1d646f19804 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh @@ -1,18 +1,88 @@ -export CUDA_VISIBLE_DEVICES=5,6 - -accelerate launch --deepspeed_config_file ds_config.json --mixed_precision fp16 \ - run_clm_no_trainer_deepspeed.py \ - --dataset_name NeelNanda/pile-10k \ - --model_name_or_path facebook/opt-125m \ - --block_size 128 \ - --per_device_train_batch_size 8 \ - --gradient_accumulation_steps 4 \ - --output_dir ./test-clm \ - --num_train_epochs 1 \ - --max_train_steps 100 \ - --do_prune \ - --num_train_epochs 10 \ - --pruning_type "magnitude" \ - --target_sparsity 0.8 \ - --pruning_pattern "4x1" \ - --pruning_frequency 1000 +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_tuning + +} + +# init params +function init_params { + dataset_name="NeelNanda/pile-10k" + model_name_or_path="facebook/opt-125m" + output_dir="./test-clm" + per_device_train_batch_size=8 + block_size=128 + gradient_accumulation_steps=4 + num_train_epochs=3 + target_sparsity=0.8 + pruning_type="magnitude" + pruning_pattern="4x1" + pruning_frequency=1000 + for var in "$@" + do + case $var in + --dataset_name=*) + dataset_name=$(echo $var |cut -f2 -d=) + ;; + --model_name_or_path=*) + model_name_or_path=$(echo $var |cut -f2 -d=) + ;; + --output_dir=*) + output_dir=$(echo $var |cut -f2 -d=) + ;; + --per_device_train_batch_size=*) + per_device_train_batch_size=$(echo $var |cut -f2 -d=) + ;; + --block_size=*) + block_size=$(echo $var |cut -f2 -d=) + ;; + --gradient_accumulation_steps=*) + gradient_accumulation_steps=$(echo $var |cut -f2 -d=) + ;; + --num_train_epochs=*) + num_train_epochs=$(echo $var |cut -f2 -d=) + ;; + --target_sparsity=*) + target_sparsity=$(echo $var |cut -f2 -d=) + ;; + --pruning_type=*) + pruning_type=$(echo $var |cut -f2 -d=) + ;; + --pruning_pattern=*) + pruning_pattern=$(echo $var |cut -f2 -d=) + ;; + --pruning_frequency=*) + pruning_frequency=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + 
;; + esac + done + +} + +# run_tuning +function run_tuning { + accelerate launch --deepspeed_config_file config/ds_config.json --mixed_precision fp16 \ + run_clm_no_trainer_deepspeed.py \ + --dataset_name $dataset_name \ + --model_name_or_path $model_name_or_path \ + --block_size $block_size \ + --per_device_train_batch_size $per_device_train_batch_size \ + --gradient_accumulation_steps $gradient_accumulation_steps \ + --output_dir $output_dir \ + --do_prune \ + --num_train_epochs $num_train_epochs \ + --target_sparsity $target_sparsity \ + --pruning_pattern $pruning_pattern \ + --pruning_frequency $pruning_frequency + +} + +main "$@" + From 85cd6c551bbf2aa40fe81d9f4066da3bde10edac Mon Sep 17 00:00:00 2001 From: "Lv, Kaokao" Date: Fri, 1 Sep 2023 11:25:01 +0800 Subject: [PATCH 7/7] update script. Signed-off-by: Lv, Kaokao --- .../language-modeling/pruning/magnitude/run.sh | 6 +++--- .../language-modeling/pruning/magnitude/run_ds.sh | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh index 7b70df66f57..c1eb360d758 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh @@ -4,7 +4,7 @@ set -x function main { init_params "$@" - run_tuning + run_pruning } @@ -67,8 +67,8 @@ function init_params { } # run_tuning -function run_tuning { - python run_clm_no_trainer_pruning.py \ +function run_pruning { + python run_clm_no_trainer.py \ --dataset_name $dataset_name \ --model_name_or_path $model_name_or_path \ --block_size $block_size \ diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh index 1d646f19804..b8bad7b8bb2 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh @@ -4,7 +4,7 @@ set -x function main { init_params "$@" - run_tuning + run_pruning } @@ -67,7 +67,7 @@ function init_params { } # run_tuning -function run_tuning { +function run_pruning { accelerate launch --deepspeed_config_file config/ds_config.json --mixed_precision fp16 \ run_clm_no_trainer_deepspeed.py \ --dataset_name $dataset_name \