From d29aa0ff3383e5c93740b82b8ddaab02d45d2e97 Mon Sep 17 00:00:00 2001 From: WeiweiZhang1 <109071285+WeiweiZhang1@users.noreply.github.com> Date: Wed, 12 Apr 2023 10:58:13 +0800 Subject: [PATCH] add block_mask & retrain_free features (#775) Signed-off-by: Zhang, Weiwei1 Co-authored-by: wenhuach21 <108330088+wenhuach21@users.noreply.github.com> --- .../pruning/eager/run_clm_no_trainer.py | 924 +++++++++++++ .../pruning/eager/run_llm_pruning.sh | 17 + .../language-modeling/pruning/eager/timers.py | 33 + .../pruning/eager/run_qa_no_trainer_block.py | 1181 +++++++++++++++++ .../compression/pruner/criteria.py | 87 +- .../compression/pruner/patterns.py | 232 ++-- .../compression/pruner/pruners.py | 333 ++++- neural_compressor/compression/pruner/utils.py | 30 +- test/pruning_2.x/test_pruning_block.py | 88 ++ test/pruning_2_plus.x/test_pruning_block.py | 75 ++ 10 files changed, 2878 insertions(+), 122 deletions(-) create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_llm_pruning.sh create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/timers.py create mode 100644 examples/pytorch/nlp/huggingface_models/question-answering/pruning/eager/run_qa_no_trainer_block.py create mode 100644 test/pruning_2.x/test_pruning_block.py create mode 100644 test/pruning_2_plus.x/test_pruning_block.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py new file mode 100644 index 00000000000..db74cfca5e8 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py @@ -0,0 +1,924 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
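# A condensed sketch of the retrain-free pruning flow this example wires up below (for orientation only).
# It reuses the neural_compressor.training calls already present in this script; `model`,
# `train_dataloader` and the config values are placeholders, not the exact objects built in main().
def _pruning_flow_sketch(model, train_dataloader):
    from neural_compressor.training import WeightPruningConfig, prepare_compression
    configs = WeightPruningConfig(
        [{"pruning_type": "retrain_free", "pruning_scope": "global",
          "pattern": "channelx1", "pruning_op_types": ["Linear"]}],
        target_sparsity=0.2,
        start_step=1,
        end_step=len(train_dataloader),
    )
    manager = prepare_compression(model=model, confs=configs)
    manager.callbacks.on_train_begin()
    for step, batch in enumerate(train_dataloader):
        manager.callbacks.on_step_begin(step)
        loss = model(return_dict=True, **batch).loss
        loss.backward()  # gradients feed the pruning criterion; no optimizer.step() in the retrain-free path
        manager.callbacks.on_before_optimizer_step()
        manager.callbacks.on_after_optimizer_step()
    manager.callbacks.on_train_end()
    return manager.model.model  # unwrap the pruned torch model, as main() does before evaluation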
+ +import argparse +import json +import logging +import math +import os +import sys +sys.path.insert(0, './neural-compressor') +sys.path.insert(0, './') + +import random +from itertools import chain +from pathlib import Path + +import datasets +import torch +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import Repository, create_repo +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from torch.nn.functional import pad + +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry +from transformers.utils.versions import require_version +from neural_compressor.training import prepare_compression +from neural_compressor.training import WeightPruningConfig +from timers import CPUTimer, GPUTimer +from neural_compressor.compression import model_slim +from neural_compressor.compression import parse_auto_slim_config +set_seed(42) + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.23.0.dev0") + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +class Evaluator: + def __init__(self, dataset, tokenizer, device, batch_size=16): + self.dataset = dataset + self.tokenizer = tokenizer + self.device = device + self.dataloader = INCDataloader(dataset, tokenizer, self.device, batch_size) + + @torch.no_grad() + def evaluate(self, model): + model.eval() + # The task is to predict the last word of the input. 
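        # self.dataloader (an INCDataloader, defined below) right-pads every sample to pad_len=196
        # (pad value 1), keeps the last real token as `label`, and records
        # label_index = -2 - <number of pad tokens>, i.e. the position of the token just before that label.
        # For example (illustrative numbers), a 100-token sample gets 96 pad tokens and label_index = -98,
        # so the logits gathered at `label_indices` below are exactly the ones that predict the held-out
        # last word.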
+ total, hit = 0, 0 + if torch.cuda.is_available(): + my_timer = GPUTimer(timelogs = []) + else: + my_timer = CPUTimer(timelogs = []) + warmup_steps = 10 + step = 0 + for input_ids, label, label_indices in tqdm(self.dataloader): + with torch.no_grad(): + # if step == 0: + # model = torch.jit.trace(model, input_ids) + step += 1 + # timing + if step > warmup_steps: my_timer.__enter__() + outputs = model(input_ids) + # outputs = model(input_ids, labels=label) + if step > warmup_steps: my_timer.__exit__() + last_token_logits = outputs[0][torch.arange(len(label_indices)), label_indices, :] + pred = last_token_logits.argmax(dim=-1) + total += label.size(0) + hit += (pred == label).sum().item() + if step % 100 == 0: + logger.info(f"eval step:{step} accuracy:{float(hit/total)}") + avg_latency = my_timer.get_avg_time() + del my_timer + accuracy = hit / total + return accuracy, avg_latency + + +class INCDataloader(): + def __init__(self, dataset, tokenizer, device, batch_size=1): + self.dataset = dataset + self.tokenizer = tokenizer + self.device = device + self.batch_size = batch_size + import math + self.length = math.ceil(len(dataset) / self.batch_size) + self.pad_len = 196 + + # tokenize the dataset + def tokenize_function(examples): + example = self.tokenizer(examples['text']) + return example + + self.dataset = self.dataset.map(tokenize_function, batched=True) + self.dataset.set_format(type='torch', columns=['input_ids']) + + def pad_input(self, input): + input_id = input['input_ids'].unsqueeze(0).to(self.device) + label = input_id[:, -1].to(self.device) + pad_len = self.pad_len - input_id.shape[1] + label_index = -2 - pad_len + input_id = pad(input_id, (0, pad_len), value=1) + + return (input_id, label, label_index) + + def __iter__(self): + input_ids = None + labels = None + label_indices = None + for idx, batch in enumerate(self.dataset): + input_id, label, label_index = self.pad_input(batch) + + if input_ids is None: + input_ids = input_id + labels = label + label_indices = [label_index] + else: + input_ids = torch.cat((input_ids, input_id), 0).to(self.device) + labels = torch.cat((labels, label), 0).to(self.device) + label_indices.append(label_index) + + if (idx + 1) % self.batch_size == 0: + yield (input_ids, labels, label_indices) + input_ids = None + labels = None + label_indices = None + if (idx + 1) % self.batch_size != 0: + yield (input_ids, labels, label_indices) + + def __len__(self): + return self.length + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 
+ ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." + ), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 
+ ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--low_cpu_mem_usage", + action="store_true", + help=( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "If passed, LLM loading time and RAM consumption will be benefited." + ), + ) + # pruning config + parser.add_argument( + "--cooldown_epochs", + type=int, default=0, + help="Cooling epochs after pruning." + ) + parser.add_argument( + "--do_prune", action="store_true", + help="Whether or not to prune the model" + ) + parser.add_argument( + "--pruning_pattern", + type=str, default="4x1", + help="pruning pattern type, we support NxM and N:M." + ) + parser.add_argument( + "--target_sparsity", + type=float, default=0.8, + help="Target sparsity of the model." + ) + parser.add_argument( + "--pruning_frequency", + type=int, default=-1, + help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." + ) + parser.add_argument( + "--auto_slim", action="store_true", + help="Whether or not to auto slim the model after pruning." + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["logging_dir"] = args.output_dir + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + # raw_datasets = load_dataset(args.dataset_name, keep_in_memory=True) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. 
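        # A short note on the percent-slice syntax used just below: `datasets` lets both splits be carved
        # out of the same file, e.g. with the default validation_split_percentage of 5:
        #   split="train[:5%]"   -> first 5% of the train file becomes the validation set
        #   split="train[5%:]"   -> the remaining 95% stays as the training set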
+ if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name, torchscript=True) + elif args.model_name_or_path: + # torchscript will force `return_dict=False` to avoid jit errors + config = AutoConfig.from_pretrained(args.model_name_or_path, torchscript=True) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + is_llama = bool("llama" in args.model_name_or_path) + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) + elif args.model_name_or_path: + if is_llama: + tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path) + else : + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + low_cpu_mem_usage=args.low_cpu_mem_usage, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config) + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + # return tokenizer(examples[text_column_name], max_length=512, truncation=True) #padding + return tokenizer(examples[text_column_name]) + + + with accelerator.main_process_first(): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" + " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" + " override this default with `--block_size xxx`." 
+ ) + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with accelerator.main_process_first(): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataset = eval_dataset.shuffle(seed=42) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Scheduler and math around the number of training steps. 
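    # Worked example for the bookkeeping below (illustrative numbers): with 10,000 train batches and
    # gradient_accumulation_steps=4, num_update_steps_per_epoch = ceil(10000 / 4) = 2500; if
    # --max_train_steps is not given, 3 epochs imply max_train_steps = 3 * 2500 = 7500, and the LR
    # scheduler is sized over 7500 * 4 = 30,000 micro-steps (warmup steps are scaled the same way).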
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + # eval_dataloader = eval_dataloader.shuffle(seed=42) + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("clm_no_trainer", experiment_config) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") + accelerator.load_state(args.resume_from_checkpoint) + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + resume_step -= starting_epoch * len(train_dataloader) + + # update the progress_bar if load from checkpoint + progress_bar.update(starting_epoch * num_update_steps_per_epoch) + completed_steps = starting_epoch * num_update_steps_per_epoch + + # Pruning preparation + num_iterations = num_update_steps_per_epoch + num_warm = args.num_warmup_steps + total_iterations = args.max_train_steps + frequency = int((total_iterations - num_warm + 1) / 40) if args.pruning_frequency == -1 \ + else args.pruning_frequency + pruning_start = max(num_warm, 1) + pruning_end = total_iterations + if not args.do_prune: + pruning_start = num_iterations * args.num_train_epochs + 1 + pruning_end = pruning_start + + if is_llama or not args.auto_slim: + pruning_configs=[ + { + "pruning_type": "retrain_free", + "pruning_scope": "global", + # "op_names": ["fc_out"], #for gptj + # "op_names": ["down_proj"], #for llama + "op_names": ["fc2"], #for opt + "excluded_op_names": [".attn"], + "sparsity_decay_type": "exp", + "pattern": "channelx1", + "pruning_op_types": ["Linear"], + "max_sparsity_ratio_per_op": 0.98, + } + ] + else: + # auto slim config + pruning_configs=[] + auto_slim_configs = parse_auto_slim_config( + model, + ffn2_sparsity = args.target_sparsity, + mha_sparsity = 0, + pruning_scope = "global", + pruning_type = "retrain_free", + ) + pruning_configs += auto_slim_configs + + configs = WeightPruningConfig( + pruning_configs, + target_sparsity=args.target_sparsity, + # pattern=args.pruning_pattern, + pruning_frequency=frequency, + start_step=pruning_start, + end_step=pruning_end, + ) + compression_manager = prepare_compression(model=model, confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model.model + + for epoch in range(starting_epoch, args.num_train_epochs): + # model.train() + model.eval() + if args.with_tracking: + total_loss = 0 + for step, batch in enumerate(train_dataloader): + # We need to skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == starting_epoch: + if resume_step is not None and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + completed_steps += 1 + continue + compression_manager.callbacks.on_step_begin(step) + with accelerator.accumulate(model): + outputs = model(return_dict=True, **batch) 
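                # Note on the callback ordering used here: on_step_begin ran before this forward pass, and
                # after accelerator.backward the two optimizer-step callbacks run around a deliberately
                # disabled optimizer.step()/lr_scheduler.step(). The backward pass therefore only supplies
                # gradients to the retrain-free pruning criterion; the block-mask/weight update itself is
                # expected to happen inside the compression callbacks (hedged reading of the flow).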
+ # outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + accelerator.backward(loss) + compression_manager.callbacks.on_before_optimizer_step() + # optimizer.step() + compression_manager.callbacks.on_after_optimizer_step() + # lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + if completed_steps >= args.max_train_steps: + break + + model.eval() + dataset_eval = raw_datasets["validation"] + evaluator = Evaluator(dataset_eval, tokenizer, model.device, batch_size=args.per_device_eval_batch_size) + def eval_func(model): + acc, avg_latency = evaluator.evaluate(model) + return acc, avg_latency + + # losses = [] + # total, hit = 0, 0 + # for step, batch in enumerate(eval_dataloader): + # with torch.no_grad(): + # # labels = batch['labels'] + # input_ids = batch['input_ids'] + # labels = input_ids[:, -1].to(input_ids.device) + # outputs = model(return_dict=True, **batch) + # # outputs = model(**batch) + # last_token_logits = outputs['logits'][:,-2,:] + # # last_token_logits = outputs['logits'] + # pred = last_token_logits.argmax(dim=-1) + # total += labels.size(0) + # hit += (pred == labels).sum().item() + + # loss = outputs.loss + # losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))) + + # acc = hit / total + # losses = torch.cat(losses) + # try: + # eval_loss = torch.mean(losses) + # perplexity = math.exp(eval_loss) + # except OverflowError: + # perplexity = float("inf") + + # logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss} accuracy:{acc}") + # if args.with_tracking: + # accelerator.log( + # { + # "perplexity": perplexity, + # "eval_loss": eval_loss, + # "train_loss": total_loss.item() / len(train_dataloader), + # "epoch": epoch, + # "step": completed_steps, + # }, + # step=completed_steps, + # ) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + compression_manager.callbacks.on_train_end() + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir+"_noslim", is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir+"_noslim") + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + if is_llama or not args.auto_slim: + # only eval + 
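        # This branch only measures the masked (still dense-shaped) model. The else-branch below
        # additionally calls model_slim(model, round_multiplier=32), which is expected to physically drop
        # the channels zeroed by channelx1 pruning so that avg_latency actually improves; round_multiplier
        # presumably keeps the retained channel count a multiple of 32 (an assumption about the API,
        # not verified here).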
logger.info(f"***** Running Evaluation *****") + acc, _ = eval_func(model) + logger.info(f"total_steps:{completed_steps} accuracy:{acc}") + else: + logger.info(f"***** Running Evaluation before ffn auto slim*****") + accuracy, avg_latency = eval_func(model) + logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}") + model = model_slim(model, round_multiplier=32) + + logger.info(f"***** Running Evaluation after ffn auto_slim*****") + accuracy, avg_latency = eval_func(model) + logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}") + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir+"_slimed", is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir+"_slimed") + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + # with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + # json.dump({"perplexity": perplexity}, f) + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_llm_pruning.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_llm_pruning.sh new file mode 100644 index 00000000000..25e42fac3dd --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_llm_pruning.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -x + python \ + examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py \ + --model_name_or_path /path/to/your/model \ + --dataset_name lambada \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 16 \ + --max_train_steps 3002 \ + --weight_decay 0 \ + --block_size 512 \ + --do_prune \ + --auto_slim \ + --output_dir sparse_clm_models/ \ + --target_sparsity 0.2 \ + --pruning_pattern channelx1 \ + --pruning_frequency 500 \ diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/timers.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/timers.py new file mode 100644 index 00000000000..ee4999c460a --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/timers.py @@ -0,0 +1,33 @@ +import time +import torch +class CPUTimer: + def __init__(self, timelogs): + self.timelogs = timelogs + + def __enter__(self): + self.start = time.time() + + def __exit__(self): + end = time.time() + self.timelogs.append((end - self.start) * 1000) # ms + + def get_avg_time(self): + return sum(self.timelogs) / len(self.timelogs) + +class GPUTimer: + def __init__(self, timelogs): + self.timelogs = timelogs + + def __enter__(self): + self.start_event = torch.cuda.Event(enable_timing=True) + self.end_event = torch.cuda.Event(enable_timing=True) + self.start_event.record() + + def __exit__(self): + self.end_event.record() + self.end_event.synchronize() + elapsed_time = self.start_event.elapsed_time(self.end_event) + self.timelogs.append(elapsed_time) + + def get_avg_time(self): + return sum(self.timelogs) / len(self.timelogs) \ No newline at end of file diff --git a/examples/pytorch/nlp/huggingface_models/question-answering/pruning/eager/run_qa_no_trainer_block.py b/examples/pytorch/nlp/huggingface_models/question-answering/pruning/eager/run_qa_no_trainer_block.py new file mode 100644 index 
00000000000..0e93aa7d1f3 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/question-answering/pruning/eager/run_qa_no_trainer_block.py @@ -0,0 +1,1181 @@ +# !/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning a 🤗 Transformers model for question answering using 🤗 Accelerate. +""" +# You can also adapt this script on your own question answering task. Pointers for this are left as comments. + +import argparse +import json +import logging +import math +import os +import sys + +sys.path.insert(0, './neural-compressor') +sys.path.insert(0, './') +from pathlib import Path + +import datasets +import numpy as np +import torch +from datasets import load_dataset, load_metric +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from huggingface_hub import Repository +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForQuestionAnswering, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version +from transformers.utils.versions import require_version +from utils_qa import postprocess_qa_predictions +from neural_compressor.training import prepare_compression +from neural_compressor.training import WeightPruningConfig + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.21.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") + +logger = get_logger(__name__) +# You should update this to your particular problem to have better documentation of `model_type` +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +# (['loss', 'start_logits', 'end_logits']) +# batch(['attention_mask', 'end_positions', 'input_ids', 'start_positions', 'token_type_ids'] +def get_loss_one_logit(student_logit, teacher_logit): + t = 2.0 + from torch.nn import functional as F + return F.kl_div( + input=F.log_softmax(student_logit / t, dim=-1), + target=F.softmax(teacher_logit / t, dim=-1), + reduction="batchmean" + ) * (t ** 2) + +def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"): + """ + Save results while prefixing metric names. + + Args: + results: (:obj:`dict`): + A dictionary of results. + output_dir: (:obj:`str`): + An output directory. + file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`): + An output file name. + metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`): + A metric name prefix. 
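    Example (illustrative; the metric values and output directory are made up):

        >>> save_prefixed_metrics({"exact_match": 81.2, "f1": 88.6}, "./out")

        writes ``./out/all_results.json`` containing
        ``{"eval_exact_match": 81.2, "eval_f1": 88.6}``.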
+ """ + # Prefix all keys with metric_key_prefix + '_' + for key in list(results.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + results[f"{metric_key_prefix}_{key}"] = results.pop(key) + + with open(os.path.join(output_dir, file_name), "w") as f: + json.dump(results, f, indent=4) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", + type=str, + default=None, + help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, default=4, + help="A csv or a json file containing the training data." + ) + + parser.add_argument( + "--do_predict", + action="store_true", + help="To do prediction on the question answering model" + ) + parser.add_argument( + "--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--test_file", + type=str, + default=None, + help="A csv or a json file containing the Prediction data." + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=384, + help=( + "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," + " sequences shorter will be padded if `--pad_to_max_lengh` is passed." + ), + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models." + ) + parser.add_argument( + "--teacher_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--distill_loss_weight", + type=float, + default=0.0, + help="distiller loss weight" + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.0, + help="Weight decay to use." + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=3, + help="Total number of training epochs to perform." 
+ ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--warm_epochs", + type=int, + default=0, + help="Number of epochs the network not be purned" + ) + parser.add_argument( + "--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Where to store the final model." + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training." + ) + parser.add_argument( + "--doc_stride", + type=int, + default=128, + help="When splitting up a long document into chunks how much stride to take between chunks.", + ) + parser.add_argument( + "--n_best_size", + type=int, + default=20, + help="The total number of n-best predictions to generate when looking for an answer.", + ) + parser.add_argument( + "--null_score_diff_threshold", + type=float, + default=0.0, + help=( + "The threshold used to select the null answer: if the best answer has a score that is less than " + "the score of the null answer minus this threshold, the null answer is selected for this example. " + "Only useful when `version_2_with_negative=True`." + ), + ) + parser.add_argument( + "--version_2_with_negative", + action="store_true", + help="If true, some of the examples do not have an answer.", + ) + parser.add_argument( + "--max_answer_length", + type=int, + default=30, + help=( + "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--max_eval_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ), + ) + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--max_predict_samples", + type=int, + default=None, + help="For debugging purposes or quicker training, truncate the number of prediction examples to this", + ) + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub." + ) + parser.add_argument( + "--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument( + "--hub_token", + type=str, + help="The token to use to push to the Model Hub." 
+ ) + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + + parser.add_argument( + "--cooldown_epochs", + type=int, default=0, + help="Cooling epochs after pruning." + ) + parser.add_argument( + "--do_prune", action="store_true", + help="Whether or not to prune the model" + ) + # parser.add_argument( + # "--keep_conf", action="store_true", + # help="Whether or not to keep the prune config infos" + # ) + parser.add_argument( + "--pruning_pattern", + type=str, default="4x1", + help="pruning pattern type, we support NxM and N:M." + ) + parser.add_argument( + "--target_sparsity", + type=float, default=0.8, + help="Target sparsity of the model." + ) + parser.add_argument( + "--pruning_frequency", + type=int, default=-1, + help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." + ) + + args = parser.parse_args() + + # Sanity checks + if ( + args.dataset_name is None + and args.train_file is None + and args.validation_file is None + and args.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation/test file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if args.test_file is not None: + extension = args.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + # send_example_telemetry("run_qa_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator = ( + Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() + ) + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + else: + data_files = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + if args.test_file is not None: + data_files["test"] = args.test_file + extension = args.train_file.split(".")[-1] + raw_datasets = load_dataset(extension, data_files=data_files, field="data") + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
+ ) + + if args.distill_loss_weight > 0: + teacher_path = args.teacher_model_name_or_path + if teacher_path is None: + teacher_path = args.model_name_or_path + teacher_model = AutoModelForQuestionAnswering.from_pretrained( + teacher_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + + if args.model_name_or_path: + model = AutoModelForQuestionAnswering.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForQuestionAnswering.from_config(config) + + # Preprocessing the datasets. + # Preprocessing is slighlty different for training and evaluation. + + column_names = raw_datasets["train"].column_names + + question_column_name = "question" if "question" in column_names else column_names[0] + context_column_name = "context" if "context" in column_names else column_names[1] + answer_column_name = "answers" if "answers" in column_names else column_names[2] + + # Padding side determines if we do (question|context) or (context|question). + pad_on_right = tokenizer.padding_side == "right" + + if args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + + max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + + # Training preprocessing + def prepare_train_features(examples): + # Some of the questions have lots of whitespace on the left, which is not useful and will make the + # truncation of the context fail (the tokenized question will take a lots of space). So we remove that + # left whitespace + examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_examples.pop("offset_mapping") + + # Let's label those examples! + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). 
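            # `sequence_ids(i)` labels each token with the sequence it came from: None for special tokens,
            # 0 for the first sequence and 1 for the second. With pad_on_right (question first, context
            # second) a feature looks roughly like
            #   [CLS] question [SEP] context [SEP]  ->  [None, 0, ..., 0, None, 1, ..., 1, None]
            # which is what the token_start_index / token_end_index walks below rely on.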
+ sequence_ids = tokenized_examples.sequence_ids(i) + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = examples[answer_column_name][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != (1 if pad_on_right else 0): + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if pad_on_right else 0): + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if args.max_train_samples is not None: + # We will select sample from whole data if agument is specified + train_dataset = train_dataset.select(range(args.max_train_samples)) + + # Create train feature from dataset + with accelerator.main_process_first(): + train_dataset = train_dataset.map( + prepare_train_features, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + if args.max_train_samples is not None: + # Number of samples might increase during Feature Creation, We select only specified max samples + train_dataset = train_dataset.select(range(args.max_train_samples)) + + # Validation preprocessing + def prepare_validation_features(examples): + # Some of the questions have lots of whitespace on the left, which is not useful and will make the + # truncation of the context fail (the tokenized question will take a lots of space). So we remove that + # left whitespace + examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. 
+ tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples + + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_examples = raw_datasets["validation"] + if args.max_eval_samples is not None: + # We will select sample from whole data + eval_examples = eval_examples.select(range(args.max_eval_samples)) + # Validation Feature Creation + with accelerator.main_process_first(): + eval_dataset = eval_examples.map( + prepare_validation_features, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + + if args.max_eval_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + eval_dataset = eval_dataset.select(range(args.max_eval_samples)) + + if args.do_predict: + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_examples = raw_datasets["test"] + if args.max_predict_samples is not None: + # We will select sample from whole data + predict_examples = predict_examples.select(range(args.max_predict_samples)) + # Predict Feature Creation + with accelerator.main_process_first(): + predict_dataset = predict_examples.map( + prepare_validation_features, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on prediction dataset", + ) + if args.max_predict_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + predict_dataset = predict_dataset.select(range(args.max_predict_samples)) + + # # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # 
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + if args.pad_to_max_length: + # If padding was already done ot max length, we use the default data collator that will just convert everything + # to tensors. + data_collator = default_data_collator + else: + # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of + # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple + # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) + + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size + ) + + eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"]) + eval_dataloader = DataLoader( + eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size + ) + + if args.do_predict: + predict_dataset_for_model = predict_dataset.remove_columns(["example_id", "offset_mapping"]) + predict_dataloader = DataLoader( + predict_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Post-processing: + def post_processing_function(examples, features, predictions, stage="eval"): + # Post-processing: we match the start logits and end logits to answers in the original context. + predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=args.version_2_with_negative, + n_best_size=args.n_best_size, + max_answer_length=args.max_answer_length, + null_score_diff_threshold=args.null_score_diff_threshold, + output_dir=args.output_dir, + prefix=stage, + ) + # Format the result to the format the metric expects. + if args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + metric = load_metric("squad_v2" if args.version_2_with_negative else "squad") + + # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor + def create_and_fill_np_array(start_or_end_logits, dataset, max_len): + """ + Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor + Args: + start_or_end_logits(:obj:`tensor`): + This is the output predictions of the model. We can only enter either start or end logits. + eval_dataset: Evaluation dataset + max_len(:obj:`int`): + The maximum length of the output tensor. ( See the model.eval() part for more details ) + """ + + step = 0 + # create a numpy array and fill it with -100. 
+ logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) + # Now since we have create an array now we will populate it with the outputs gathered using accelerator.gather + for i, output_logit in enumerate(start_or_end_logits): # populate columns + # We have to fill it such that we have to take the whole tensor and replace it on the newly created array + # And after every iteration we have to change the step + + batch_size = output_logit.shape[0] + cols = output_logit.shape[1] + + if step + batch_size < len(dataset): + logits_concat[step: step + batch_size, :cols] = output_logit + else: + logits_concat[step:, :cols] = output_logit[: len(dataset) - step] + + step += batch_size + + return logits_concat + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + no_decay_outputs = ["bias", "LayerNorm.weight", "qa_outputs"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + if args.do_prune: + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, betas=[0.9, 0.9]) + else: + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + if args.distill_loss_weight > 0: + teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + teacher_model.eval() + else: + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + if hasattr(args.checkpointing_steps, "isdigit"): + checkpointing_steps = args.checkpointing_steps + if args.checkpointing_steps.isdigit(): + checkpointing_steps = int(args.checkpointing_steps) + else: + checkpointing_steps = None + + # We need to initialize the trackers we use, and also store our configuration. + # We initialize the trackers only on main process because `accelerator.log` + # only logs on main process and we don't want empty logs/runs on other processes. 
+ if args.with_tracking: + if accelerator.is_main_process: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("qa_no_trainer", experiment_config) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") + accelerator.load_state(args.resume_from_checkpoint) + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + else: + resume_step = int(training_difference.replace("step_", "")) + starting_epoch = resume_step // len(train_dataloader) + resume_step -= starting_epoch * len(train_dataloader) + + # Pruning preparation + num_iterations = len(train_dataset) / total_batch_size + num_warm = int(args.warm_epochs * num_iterations) + args.num_warmup_steps + total_iterations = int(num_iterations * (args.num_train_epochs - args.cooldown_epochs)) + frequency = int((total_iterations - num_warm + 1) / 40) if args.pruning_frequency == -1 \ + else args.pruning_frequency + pruning_start = num_warm + pruning_end = total_iterations + if not args.do_prune: + pruning_start = num_iterations * args.num_train_epochs + 1 + pruning_end = pruning_start + pruning_configs=[ + { + "pruning_type": "block_mask", + "pruning_scope": "global", + "criterion_type": "snip_momentum_block", + "excluded_op_names": ["qa_outputs", "pooler", ".*embeddings*"], + "sparsity_decay_type": "exp", + "pruning_op_types": ["Linear"], + "max_sparsity_ratio_per_op": 0.98 + } + ] + configs = WeightPruningConfig( + pruning_configs, + target_sparsity=args.target_sparsity, + pattern=args.pruning_pattern, + pruning_frequency=frequency, + start_step=pruning_start, + end_step=pruning_end + ) + compression_manager = prepare_compression(model=model, confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if epoch >= args.warm_epochs: + if args.with_tracking: + total_loss = 0 + for step, batch in 
enumerate(train_dataloader): + # pruner.on_step_begin(local_step=step) + compression_manager.callbacks.on_step_begin(step) + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + if args.distill_loss_weight > 0: + distill_loss_weight = args.distill_loss_weight + with torch.no_grad(): + teacher_outputs = teacher_model(**batch) + loss = (distill_loss_weight) / 2 * get_loss_one_logit(outputs['start_logits'], + teacher_outputs['start_logits']) \ + + (distill_loss_weight) / 2 * get_loss_one_logit(outputs['end_logits'], + teacher_outputs['end_logits']) + + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + compression_manager.callbacks.on_before_optimizer_step() + optimizer.step() + compression_manager.callbacks.on_after_optimizer_step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if completed_steps >= args.max_train_steps: + break + else: + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if completed_steps >= args.max_train_steps: + break + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + # unwrapped_model = accelerator.unwrap_model(model) + # unwrapped_model.save_pretrained( + # args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + # ) + accelerator.save_state(args.output_dir) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + config.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + + # eval each epoch + logger.info(f"***** Running Evaluation*****") + all_start_logits = [] + all_end_logits = [] + + model.eval() + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch) + start_logits = outputs.start_logits + end_logits = outputs.end_logits + + if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered + start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100) + end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100) + + all_start_logits.append(accelerator.gather(start_logits).cpu().numpy()) + all_end_logits.append(accelerator.gather(end_logits).cpu().numpy()) + + max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor + + # concatenate the numpy array + start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len) + end_logits_concat = 
create_and_fill_np_array(all_end_logits, eval_dataset, max_len) + + # delete the list of numpy arrays + del all_start_logits + del all_end_logits + + outputs_numpy = (start_logits_concat, end_logits_concat) + prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy) + eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids) + logger.info(f"Evaluation metrics of epoch{epoch}: {eval_metric}") + + compression_manager.callbacks.on_train_end() + # Prediction + if args.do_predict: + logger.info("***** Running Prediction *****") + logger.info(f" Num examples = {len(predict_dataset)}") + logger.info(f" Batch size = {args.per_device_eval_batch_size}") + + all_start_logits = [] + all_end_logits = [] + + model.eval() + + for step, batch in enumerate(predict_dataloader): + with torch.no_grad(): + outputs = model(**batch) + start_logits = outputs.start_logits + end_logits = outputs.end_logits + + if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered + start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100) + end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100) + + all_start_logits.append(accelerator.gather(start_logits).cpu().numpy()) + all_end_logits.append(accelerator.gather(end_logits).cpu().numpy()) + + max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor + # concatenate the numpy array + start_logits_concat = create_and_fill_np_array(all_start_logits, predict_dataset, max_len) + end_logits_concat = create_and_fill_np_array(all_end_logits, predict_dataset, max_len) + + # delete the list of numpy arrays + del all_start_logits + del all_end_logits + + outputs_numpy = (start_logits_concat, end_logits_concat) + prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy) + predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids) + logger.info(f"Predict metrics: {predict_metric}") + + if args.with_tracking: + log = { + "squad_v2" if args.version_2_with_negative else "squad": eval_metric, + "train_loss": total_loss.item() / len(train_dataloader), + "epoch": epoch, + "step": completed_steps, + } + if args.do_predict: + log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = predict_metric + + accelerator.log(log, step=completed_steps) + + if args.output_dir is not None: + accelerator.wait_for_everyone() + # unwrapped_model = accelerator.unwrap_model(model) + # unwrapped_model.save_pretrained( + # args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + # ) + accelerator.save_state(args.output_dir) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + config.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + logger.info(json.dumps(eval_metric, indent=4)) + save_prefixed_metrics(eval_metric, args.output_dir) + + +if __name__ == "__main__": + main() + + + diff --git a/neural_compressor/compression/pruner/criteria.py b/neural_compressor/compression/pruner/criteria.py index cf353d1160f..4b69fe8bd38 100644 --- a/neural_compressor/compression/pruner/criteria.py +++ b/neural_compressor/compression/pruner/criteria.py @@ -45,7 +45,7 @@ class PruningCriterion: Args: config: A config dict object that includes information about pruner and pruning criterion. 
 modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
-
+
 Attributes:
 scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
 """
@@ -190,3 +190,88 @@ def on_before_optimizer_step(self):
 p = self.modules[key].weight
 self.scores[key] *= self.alpha
 self.scores[key] += self.beta * torch.abs(p * p.grad)
+
+
+@register_criterion('snip_momentum_block')
+class SnipMomentumBlockCriterion(PruningCriterion):
+ """Pruning criterion.
+
+ The snip_momentum_block criterion class is derived from PruningCriterion.
+ A momentum mechanism is used to calculate the SNIP score, which determines whether a block of weights is to be pruned.
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+ alpha: A parameter that determines how much of the snip score is preserved from the last pruning step.
+ beta: A parameter that determines how much of the snip score is updated at the current step.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+ """Initialize a block_mask pruning criterion."""
+ super(SnipMomentumBlockCriterion, self).__init__(modules, config)
+ assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
+ for key in self.modules.keys():
+ if not hasattr(self.modules[key], 'block_mask'):
+ continue # No corresponding block mask, skip.
+ mask = self.modules[key].block_mask
+ self.scores[key] = torch.zeros(mask.shape).to(mask.device)
+ self.alpha = 0.9
+ self.beta = 1.0
+
+ def on_before_optimizer_step(self):
+ """Calculate and store the pruning scores based on the snip_momentum_block criterion."""
+ with torch.no_grad():
+ for key in self.modules.keys():
+ if not hasattr(self.modules[key], 'block_mask'):
+ continue # No corresponding block mask, skip.
+ mask = self.modules[key].block_mask
+ self.scores[key] *= self.alpha
+ self.scores[key] += self.beta * torch.abs(mask.grad)
+
+
+@register_criterion('retrain_free')
+class RetrainFreeCriterion(PruningCriterion):
+ """Pruning criterion.
+
+ The retrain_free criterion class is derived from PruningCriterion.
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+ alpha: A parameter that determines how much of the snip score is preserved from the last pruning step.
+ beta: A parameter that determines how much of the snip score is updated at the current step.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+ """Initialize a retrain_free pruning criterion."""
+ super(RetrainFreeCriterion, self).__init__(modules, config)
+ assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
+ self.collected_grads = {}
+ for key in self.modules.keys():
+ for name, param in self.modules[key].named_parameters():
+ if 'block_mask' in name:
+ continue
+ param.requires_grad_(False) # only for retrain-free criterion
+
+ if not hasattr(self.modules[key], 'block_mask'):
+ continue # No corresponding block mask, skip.
+ mask = self.modules[key].block_mask + self.scores[key] = torch.zeros(mask.shape).to(mask.device) + self.collected_grads[key] = [] + + def on_before_optimizer_step(self): + """Calculate and store the pruning scores based on snip_momentum_block criterion.""" + with torch.no_grad(): + for key in self.modules.keys(): + if not hasattr(self.modules[key], 'block_mask'): + continue # No corresponding block mask, skip. + mask_grad = self.modules[key].block_mask.grad.clone() + self.collected_grads[key].append(mask_grad) + self.scores[key] += mask_grad.pow(2) + diff --git a/neural_compressor/compression/pruner/patterns.py b/neural_compressor/compression/pruner/patterns.py index 985a567b217..02d0602e84c 100644 --- a/neural_compressor/compression/pruner/patterns.py +++ b/neural_compressor/compression/pruner/patterns.py @@ -16,14 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import torch -from .utils import logger +from .utils import torch, logger from collections import namedtuple - PATTERNS = {} - def register_pattern(name): """Class decorator used to register a Pattern subclass to the registry. @@ -76,16 +73,16 @@ class BasePattern: """Pruning Pattern. It defines the basic pruning unit and how this unit will be pruned during pruning, e.g. 4x1, 2:4. - + Args: config: A config dict object that contains the pattern information. modules: Torch neural network modules to be pruned with the pattern. Attributes: - pattern: A config dict object that includes information of the pattern. + pattern: A config dict object that includes information of the pattern. is_global: A bool determining whether the pruning takes global pruning option. Global pruning means that pruning scores by a pruning criterion are evaluated in all layers. - Local pruning, by contrast, means that pruning scores by the pruning criterion are evaluated + Local pruning, by contrast, means that pruning scores by the pruning criterion are evaluated in every layer individually. keep_mask_layers:A dict that includes the layers whose mask will not be updated. invalid_layers: The layers whose shapes don't fit the pattern. @@ -107,12 +104,13 @@ def __init__(self, config, modules): self.max_sparsity_ratio_per_op = self.config['max_sparsity_ratio_per_op'] self.min_sparsity_ratio_per_op = self.config['min_sparsity_ratio_per_op'] self.target_sparsity_ratio = self.config['target_sparsity'] + self.block = bool('block' in self.config['pruning_type'] or 'free' in self.config['pruning_type']) # Not using deterministic_algorithms for all examples torch.use_deterministic_algorithms(False) def reduce_tensor(self, data, dim): """Reduce the data along the given dimension. - + Args: data: The input data. dim: The reduced axis. @@ -139,7 +137,7 @@ def get_masks(self, scores, target_sparsity_ratio, pre_masks): pre_masks: A dict{"layer_name": Tensor} that stores the masks generated at last pruning step. Returns: - A dict with the identical size as pre_masks and its 0/1 values are updated. + A dict with the identical size as pre_masks and its 0/1 values are updated. 1 means unpruned and 0 means pruned. """ if self.is_global: @@ -153,14 +151,14 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks): def get_masks_local(self, scores, target_sparsity_ratio, pre_masks): """Generate the weight masks for local pruning. - + Args: scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. target_sparsity_ratio: A float. 
After pruning, the sparsity of the modules will reach this value. pre_masks: A dict{"layer_name": Tensor}. The previous masks generated at the last pruning step. Returns: - A dict with the identical size as pre_masks and its 0/1 values are updated. + A dict with the identical size as pre_masks and its 0/1 values are updated. 1 means unpruned and 0 means pruned. """ masks = {} @@ -196,9 +194,9 @@ def get_single_mask_per_target_ratio(self, score, exact_sparsity_ratio): def get_block_size_dict(self, data): """Get pattern size for each module. - + This is mainly for per-channel pruning when each module has different pruning size. - + Args: data: the input data. @@ -222,12 +220,12 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False): for key in pre_masks.keys(): pre_mask = pre_masks[key] zero_cnt += torch.sum(pre_mask == 0.0).data.item() - total_cnt += pre_masks[key].numel() ##FIXME + total_cnt += pre_mask.numel() ##FIXME if return_dict: return {"sparsity_ratio": float(zero_cnt) / total_cnt, "zero_cnt": zero_cnt, "total_cnt": total_cnt} else: return float(zero_cnt) / total_cnt - + def get_pattern_lock_masks(self, modules): """Obtain masks from original weight map according the pattern and weights' zero positions. @@ -256,7 +254,7 @@ def get_reduced_masks_from_data(self, data, key): def update_residual_cnt(self, masks, target_sparsity_ratio): """Update the number of parameters yet to be pruned. - + Args: masks: the current pruning mask. target_sparsity_ratio: A float representing the final sparsity of the modules. @@ -270,12 +268,12 @@ def update_residual_cnt(self, masks, target_sparsity_ratio): if self.keep_mask_layers.get(key, False): zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"] to_prune_cnt -= zero_cnt - + return to_prune_cnt def get_sparsity_ratio_each_layer(self, masks): """Calculate the sparsity ratio of each layer. - + Args: masks: The current weight masks. @@ -289,7 +287,7 @@ def get_sparsity_ratio_each_layer(self, masks): for key in masks.keys(): if key in self.invalid_layers: continue - reduced_mask = self.get_reduced_masks_from_data(masks[key], key) + reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key) zero_cnt = (int(torch.sum(reduced_mask == 0.0).data.item())) total_cnt = int(reduced_mask.numel()) sparsity_ratio = float(zero_cnt) / total_cnt @@ -304,7 +302,7 @@ def adjust_ratio(self, masks: dict, layer_name: str, key_new_sparsity: SparsityI max_sparsity_ratio: float, min_sparsity_ratio: float, \ final_target_sparsity_ratio: float): """Adjust the sparsity of a layer based on threshold. - + Args: masks: The weight masks. layer_name: The layer to be examined. @@ -373,13 +371,13 @@ def adjust_ratio(self, masks: dict, layer_name: str, key_new_sparsity: SparsityI @register_pattern('NxM') class PatternNxM(BasePattern): """Pruning Pattern. - + A Pattern class derived from BasePattern. In this pattern, the weights in a NxM block will be pruned or kept during one pruning step. - + Args: config: A config dict object that contains the pattern information. - + Attributes: block_size: A list of two integers representing the height and width of the block. Please note that the vertical direction of a Linear layer's weight refers to the output channel. @@ -405,29 +403,29 @@ def __init__(self, config, modules): def get_block_size_dict(self): """Calulate the zero elements' ration in pre_masks. - + Args: data: Dict{"layer_name": Tensor} that stores weights or scores. - + Returns: A dict. 
Dict{"layer_name": [block_size_1, block_size_2]} containing block shapes of each layer. In channel-wise pruning different layers can have different pruning patterns. """ - data = self.modules + datas = self.modules block_sizes_dict = {} - if self.N == "channel" or self.M == "channel": - for key in data.keys(): - if isinstance(data[key], torch.nn.Module): - shape = data[key].weight.shape - else: - shape = data[key].shape - if self.N == "channel": - block_sizes_dict[key] = [shape[0], 1] - else: - block_sizes_dict[key] = [1, shape[1]] - return block_sizes_dict - for key in data.keys(): + for key in datas.keys(): block_sizes_dict[key] = self.block_size + if not (self.N == "channel" or self.M == "channel"): + continue + if isinstance(datas[key], torch.nn.Module): + shape = datas[key].weight.shape + else: + shape = datas[key].shape + if self.N == "channel": # support "channelxM" format + block_sizes_dict[key] = [shape[0], self.block_size[1]] + if self.M == "channel": + block_sizes_dict[key] = [self.block_size[0], shape[1]] + return block_sizes_dict def check_layer_validity(self): @@ -445,11 +443,11 @@ def check_layer_validity(self): def get_reduced_masks_from_data(self, data, key): """Obtain the unpruned weights and reshape according to the block_size. - + Args: data: Input. key: The layer name. - + Returns: The unpruned weights. """ @@ -478,7 +476,7 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False): for key in pre_masks.keys(): if key in self.invalid_layers: continue - reduced_mask = self.get_reduced_masks_from_data(pre_masks[key], key) + reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key) zero_cnt += (int(torch.sum(reduced_mask == 0.0).data.item())) total_cnt += int(reduced_mask.numel()) if total_cnt == 0: @@ -492,11 +490,11 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False): def get_sparsity_ratio_progressive(self, pre_masks, return_dict=False): """Calculate the sparsity ratio of each layer. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. return_dict: A bool determining whether to return more information like zero_cnt and total_cnt. - + Returns: A float representing the zero elements' ratio in pre_masks. """ @@ -515,7 +513,7 @@ def _reshape_orig_to_2dims(self, data): Args: data: Input. - + Returns: Reshaped data. """ @@ -531,7 +529,7 @@ def _reshape_2dims_to_orig(self, data, orig_shape): Args: data: Input. orig_shape: Target shape. - + Returns: Reshaped data. """ @@ -542,12 +540,12 @@ def _reshape_2dims_to_orig(self, data, orig_shape): return data def reshape_orig_to_pattern(self, data, key): - """Reshape the data(s1,s2) to [s1/N,N,s2,s2/M]. + """Reshape the data(s1,s2) to [s1/N,N,s2/M,M]. Args: data: The input. key: The layer name. - + Returns: Reshaped input tensor. """ @@ -566,7 +564,7 @@ def reshape_reduced_to_orig(self, data, key, orig_shape): data: Input. key: The layer name. orig_shape: The original shape of the layer. - + Returns: Data of its original shape. """ @@ -577,7 +575,7 @@ def reshape_reduced_to_orig(self, data, key, orig_shape): def reduce_scores(self, scores): """Recalculate the pruning scores after reducing the data. - + Args: scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. 
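Note on the block-granularity masks that the surrounding PatternNxM hunks manipulate: the standalone sketch below is illustrative only and not part of the patch (all shapes, names, and values are hypothetical). It shows how one score per NxM block is reduced from the weight, thresholded into a block mask, and expanded back to an elementwise mask with repeat_interleave, which is the same correspondence that reduce_scores, get_mask_per_threshold, and the new get_block_masks/mask_block_weights rely on. When self.block is True (the block_mask/retrain_free path), the mask is kept at the reduced block granularity instead of being expanded.

import torch

block_size = [2, 2]                       # hypothetical 2x2 pruning block
weight = torch.randn(4, 4)                # toy Linear weight (4 output x 4 input channels)

# One score per 2x2 block -> shape (2, 2), analogous to reduce_scores().
reduced_scores = weight.abs().reshape(2, block_size[0], 2, block_size[1]).sum(dim=(1, 3))

# Prune the 2 lowest-scoring blocks out of 4 (50% block sparsity).
threshold = torch.kthvalue(reduced_scores.flatten(), 2).values
block_mask = torch.where(reduced_scores <= threshold,
                         torch.zeros_like(reduced_scores),
                         torch.ones_like(reduced_scores))

# Expand the (2, 2) block mask to a (4, 4) elementwise mask, as
# get_mask_per_threshold() does when self.block is False.
elem_mask = block_mask.repeat_interleave(block_size[0], dim=0) \
                      .repeat_interleave(block_size[1], dim=-1)
pruned_weight = weight * elem_mask        # with self.block True, the (2, 2) mask itself is kept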
@@ -603,7 +601,8 @@ def get_mask_per_threshold(self, score, threshold, block_size): zero = torch.tensor([0.]).to(score.device) one = torch.tensor([1.]).to(score.device) mask = torch.where(score <= threshold, zero, one) - mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) + if not self.block: + mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) return mask def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, @@ -612,14 +611,14 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, Gather all layer's scores together and calculate a common threshold. This threshold will be applied to all layers. - + Args: scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. cur_target_sparsity_ratio: A float representing the model's sparsity after pruning. pre_masks: A dict{"layer_name": Tensor} that stores the masks generated at the last pruning step. max_sparsity_ratio_per_op: A float representing the maximum sparsity that one layer can reach. keep_pre_masks: A bool representing if the masks should remain unchanged. - + Returns: A dict with the identical size as pre_masks and its 0/1 values are updated. 1 means unpruned and 0 means pruned. @@ -627,19 +626,23 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, ##keep the masks if the layer exceed max sparsity ratio masks = pre_masks - k_blockwise = self.update_residual_cnt(masks, cur_target_sparsity_ratio) if k_blockwise <= 0: return masks - new_scores = self.reduce_scores(scores) - global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()]) + new_scores = scores if self.block else self.reduce_scores(scores) + not_exceed_layers = [] residual_k = k_blockwise - not_exceed_layers = [key for key in new_scores.keys()] if self.min_sparsity_ratio_per_op > 0: sparsity_infos_perlayer, _ = self.get_sparsity_ratio_each_layer(masks) while True: + new_not_exceed_layers = [key for key in new_scores.keys() if not self.keep_mask_layers.get(key, False)] + if not_exceed_layers == new_not_exceed_layers or len(new_not_exceed_layers) == 0: + break + not_exceed_layers = new_not_exceed_layers + global_scores = torch.cat([torch.flatten(new_scores[key]) for key in not_exceed_layers]) threshold, _ = torch.kthvalue(global_scores, residual_k) + for key in not_exceed_layers: block_size = self.block_size[key] score = new_scores[key] @@ -657,7 +660,8 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, # uptade status self.keep_mask_layers[key] = True masks[key] = self.get_single_mask_per_target_ratio(new_scores[key], adjust_ratio) - masks[key] = masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1) + if not self.block: + masks[key] = masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1) if keep_exact_sparsity_ratio: zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"] residual_k -= zero_cnt @@ -665,11 +669,6 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, masks[key] = mask if not keep_exact_sparsity_ratio: break - new_not_exceed_layers = [key for key in new_scores.keys() if not self.keep_mask_layers.get(key, False)] - if not_exceed_layers == new_not_exceed_layers or len(new_not_exceed_layers) == 0: - break - not_exceed_layers = new_not_exceed_layers - global_scores = torch.cat([torch.flatten(new_scores[key]) for key in not_exceed_layers]) for key in 
masks.keys(): if key in self.invalid_layers: @@ -685,10 +684,10 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, def get_pattern_lock_masks(self, modules): """Obtain masks from original weight map by masking the zero-valued weights. - + Args: modules: A dict{"layer_name": Tensor} that stores weights. - + Returns: A dict with the identical size as modules, containing pattern lock masks. """ @@ -704,11 +703,49 @@ def get_pattern_lock_masks(self, modules): mask = self.reshape_reduced_to_orig(reduced_mask, key, ori_shape) pattern_lock_masks[key] = mask return pattern_lock_masks + + def get_block_masks(self, modules): + """Register the block mask parameters and get the mask gradients. + + Args: + modules: A dict{"layer_name": Tensor} that stores weights. + + Returns: + A dict containing block masks. + """ + masks = {} + for key in modules.keys(): + if key in self.invalid_layers: + continue # No corresponding block mask, skip. + module = modules[key] + weight = module.weight + if type(module).__name__ not in ["Linear"]: + logger.warning(f"Currently only support Linear block mask pruning," \ + f"{type(module).__name__} won't be pruned.") + continue + block_mask = torch.nn.Parameter(self.get_reduced_masks_from_data(weight, key).to(dtype=weight.dtype)) + module.register_parameter("block_mask", block_mask) + masks[key] = modules[key].block_mask.data + return masks + + def mask_block_weights(self): + """Achieve weight pruning by multiplying the reshaped weights and block masks.""" + for key in self.modules.keys(): + if key in self.invalid_layers: + continue + module = self.modules[key] + block_size = self.block_size[key] + org_shape = module.weight.shape + mask = module.block_mask.data.repeat_interleave(\ + block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1).to(module.weight.device) + reshaped_weight = self._reshape_orig_to_2dims(module.weight.data) * mask + module.weight.data = self._reshape_2dims_to_orig(reshaped_weight, org_shape) + # ---------------progressive related-------------------- def count_new_masked_cnts(self, new_added_masks): """Count the number of elements to be masked. - + Args: new_added_masks: A dict {"layer_name": Tensor} that stores the added masks. @@ -729,7 +766,7 @@ def update_new_added_masks(self, pre_masks, cur_masks): Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. cur_masks: Dict{"layer_name": Tensor} that stores the current masks. - + Returns: A dict {"layer_name": Tensor} that stores the added masks. """ @@ -746,14 +783,14 @@ def update_new_added_masks(self, pre_masks, cur_masks): def update_progressive_masks(self, pre_masks, cur_masks, scores, progressive_step, progressive_configs): """Generate the progressive masks. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. cur_masks: Dict{"layer_name": Tensor} that stores the current masks. scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. progressive_step: An integer representing the number of current step in progressive pruning. progressive_configs: A dict that stores configurations of progressive pruning. - + Returns: A dict{"layer_name": Tensor} that stores the masks generated in progressive pruning. 
""" @@ -767,7 +804,7 @@ def update_progressive_masks(self, pre_masks, cur_masks, scores, progressive_ste def update_progressive_masks_linear(self, pre_masks, cur_masks, progressive_step, progressive_configs): """Generate the progressive masks along the block's larger dimension. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. cur_masks: Dict{"layer_name": Tensor} that stores the current masks. @@ -806,7 +843,7 @@ def update_progressive_masks_linear(self, pre_masks, cur_masks, progressive_step def update_progressive_masks_scores(self, pre_masks, cur_masks, scores, progressive_step, progressive_configs): """Generate the progressive masks based on scores. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. cur_masks: Dict{"layer_name": Tensor} that stores the current masks. @@ -853,14 +890,14 @@ def update_progressive_masks_scores(self, pre_masks, cur_masks, scores, progress def update_progressive_masks_local(self, pre_masks, cur_masks, scores, progressive_step, progressive_configs): """Generate progressive masks in a local pruning domain. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. cur_masks: Dict{"layer_name": Tensor} that stores the current masks. scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. progressive_step: An integer representing the number of current step in progressive pruning. progressive_configs: A dict that stores configurations of progressive pruning. - + Returns: A dict{"layer_name": Tensor} that stores the masks generated in progressive pruning. """ @@ -877,14 +914,14 @@ def update_progressive_masks_local(self, pre_masks, cur_masks, scores, progressi def update_progressive_masks_global(self, pre_masks, cur_masks, scores, progressive_step, progressive_configs): """Gather all layer's scores to obtain a threshold that would be applied to all layers. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. cur_masks: Dict{"layer_name": Tensor} that stores the current masks. scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. progressive_step: An integer representing the number of current step in progressive pruning. progressive_configs: A dict that stores configurations of progressive pruning. - + Returns: A dict{"layer_name": Tensor} that stores the masks generated in progressive pruning. """ @@ -926,14 +963,14 @@ def update_progressive_masks_global(self, pre_masks, cur_masks, scores, progress @register_pattern('N:M') class PatternNInM(BasePattern): """Pruning Pattern. - + A Pattern class derived from Pattern. In this pattern, N out of every M continuous weights will be pruned. For more info of this pattern, please refer to : https://github.com/intel/neural-compressor/blob/master/docs/sparsity.md - + Args: config: A config dict object that contains the pattern information. - + Attributes: N: The number of elements to be pruned in a weight sequence. M: The size of the weight sequence. @@ -949,7 +986,7 @@ def __init__(self, config, modules): def check_layer_validity(self, datas: dict, block_size: tuple): """Check if a layer is valid for this block_size. - + Args: datas: A dict object containing the weights for all layers. block_size: A tuple representing the size of the pattern block. 
@@ -965,7 +1002,7 @@ def check_layer_validity(self, datas: dict, block_size: tuple): def get_reduced_masks_from_data(self, data, key): """Obtain the unpruned weights and reshape according to the block_size. - + Args: data: Input. key: The layer name. @@ -985,12 +1022,12 @@ def get_reduced_masks_from_data(self, data, key): def get_least_ninm_mask_from_data(self, score): """Generate the least N scores in M. - + Args: score: the pruning scores of weights. Returns: - A dict with the identical size as pre_masks and its 0/1 values are updated. + A dict with the identical size as pre_masks and its 0/1 values are updated. 1 means unpruned and 0 means pruned. """ current_score = score @@ -1016,11 +1053,11 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False): """Please note that the zero cnt and total cnt are all block_wise for supporting channel-wise pruning. The return sparsity ratio is elementwised. - + Args: pre_masks: Dict{"layer_name": Tensor} that stores the masks generated after the last pruning step. return_dict: A bool determining whether to return more information like zero_cnt and total_cnt. - + Returns: An elementwise sparisty ratio. """ @@ -1046,7 +1083,7 @@ def _reshape_orig_to_2dims(self, data): Args: data: Input. - + Returns: Reshaped data. """ @@ -1060,7 +1097,7 @@ def _reshape_2dims_to_orig(self, data, orig_shape): Args: data: Input. - + Returns: Reshaped data. """ @@ -1071,11 +1108,11 @@ def _reshape_2dims_to_orig(self, data, orig_shape): def reshape_orig_to_pattern(self, data, key): """Reshape the data based on the pruning pattern. - + Args: data: Input. key: layer name. - + Returns: Reshaped data. """ @@ -1087,12 +1124,12 @@ def reshape_orig_to_pattern(self, data, key): def reshape_reduced_to_orig(self, data, key, orig_shape): """Reshape the reduced data to its original shape. - + Args: data: Input. key: The layer name. orig_shape: The original shape of the layer. - + Returns: Data of its original shape. """ @@ -1101,7 +1138,7 @@ def reshape_reduced_to_orig(self, data, key, orig_shape): def reduce_scores(self, scores): """Calculate the pruning scores after reducing the data and obtain the least N scores in M. - + Args: scores: Pruning scores of weights. @@ -1139,7 +1176,7 @@ def get_ele_mask_per_threshold(self, score, threshold, block_size, least_ninm_ma threshold: A float used to determine whether to prune a weight. block_size: A list of two integers representing the height and width of the block. least_m_in_m_masks: A tensor representing the least N scores in M. - + Returns: mask: The elementwise pruning mask. """ @@ -1155,18 +1192,18 @@ def get_ele_mask_per_threshold(self, score, threshold, block_size, least_ninm_ma def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, keep_exact_sparsity_ratio=True): """Generate masks for layers. - + Gather all layer's scores together and calculate a common threshold. This threshold will be applied for all layers. - + Args: scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. target_sparsity_ratio: A float representing the model's final sparsity. pre_masks: A dict{"layer_name": Tensor} representing the masks generated after the last pruning step. max_sparsity_ratio_per_op: A float representing the maximum sparsity that one layer can reach. - + Returns: - A dict with the identical size as pre_masks and its 0/1 values are updated. + A dict with the identical size as pre_masks and its 0/1 values are updated. 1 means unpruned and 0 means pruned. 
""" masks = pre_masks @@ -1231,10 +1268,10 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, def get_pattern_lock_masks(self, modules): """Obtain masks from original weight map, by masking where weights' are zero. - + Args: modules: A dict{"layer_name": Tensor} that stores weights. - + Returns: A dict with the identical size as modules, containing pattern lock masks. """ @@ -1251,3 +1288,4 @@ def get_pattern_lock_masks(self, modules): pattern_lock_masks[key] = mask return pattern_lock_masks + diff --git a/neural_compressor/compression/pruner/pruners.py b/neural_compressor/compression/pruner/pruners.py index 939006e7ffa..788e87cce6a 100644 --- a/neural_compressor/compression/pruner/pruners.py +++ b/neural_compressor/compression/pruner/pruners.py @@ -15,8 +15,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import copy -from .utils import torch +from .utils import torch, F +from functools import partial from .patterns import get_pattern from .schedulers import get_scheduler from .criteria import get_criterion, CRITERIA @@ -82,7 +84,10 @@ def get_pruner(config, modules): if name in CRITERIA: if config["progressive"] == False: config['criterion_type'] = name - name = "basic" ##return the basic pruner + if "block" in name or "free" in name: + assert ":" not in config["pattern"], f"{name} pruner type does not support {config['pattern']} pattern." + else : + name = "basic" ##return the basic pruner else: config['criterion_type'] = name name = "progressive" ## return the progressive pruner @@ -198,7 +203,7 @@ def on_epoch_begin(self, epoch): def mask_weights(self): """Apply masks to corresponding modules' weights. - + Weights are multipled with masks. This is the formal pruning process. """ with torch.no_grad(): @@ -208,8 +213,8 @@ def mask_weights(self): def mask_weights_general(self, input_masks): """Apply input masks to corresponding modules' weights. - - Weights are multipled with input_masks. + + Weights are multipled with input_masks. Args: input_masks: A dict {"module_name": Tensor} that stores the masks for modules' weights. @@ -244,7 +249,7 @@ def on_before_optimizer_step(self): def on_after_optimizer_step(self): """Implement after optimizer.step(). - + Prune the model after optimization. """ self.mask_weights() @@ -272,7 +277,7 @@ def check_is_pruned_step(self, step): Args: step: an integer representing the number of current step. - Returns: + Returns: A Boolean. """ if step < self.start_step or step > self.end_step: @@ -281,6 +286,30 @@ def check_is_pruned_step(self, step): return True return False + def rewrite_forward(self): + """Rewrite forward to implement block mask operation""" + def forward(self, input): + block_size = [self.weight.shape[0]//self.block_mask.shape[0], \ + self.weight.shape[1]//self.block_mask.shape[1]] + mask = self.block_mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(\ + block_size[1], dim=-1).to(self.weight.device) + return F.linear(input, self.weight*mask, self.bias) + + for key in self.modules.keys(): + if not hasattr(self.modules[key], 'block_mask'): + continue # No corresponding block mask, skip. 
+ module = self.modules[key] + module.forward = partial(forward, module) + + def recover_forward(self): + """Restore the forward format at the end of pruning""" + with torch.no_grad(): + for key in self.modules.keys(): + if not hasattr(self.modules[key], 'block_mask'): + continue # No corresponding block mask, skip. + module = self.modules[key] + module.forward = partial(torch.nn.Linear.forward, module) + @register_pruner("basic") class BasicPruner(BasePruner): @@ -407,19 +436,293 @@ def update_masks(self, local_step): def on_after_optimizer_step(self): """Implement after optimizer.step(). - + Prune the model after optimization. """ self.mask_weights() self.global_step += 1 +@register_pruner('block_mask') +class BlockMaskPruner(BasePruner): + """Pruning Pruner. + + The class which executes pruning process. + 1. Defines pruning functions called at step begin/end, before/after optimize and epoch begin/end. + 2. Defines the pruning criterion. + 3. Obtain block masks and its grads. + + Args: + modules: A dict {"module_name": Tensor} that stores the pruning modules' weights. + config: A config dict object that contains the pruner information. + + Attributes: + pattern: A Pattern object that defines pruning weights' arrangements within space. + criterion: A Criterion Object that defines which weights are to be pruned + scheduler: A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds. + reg: A Reg object that defines regulization terms. + """ + def __init__(self, config, modules): + """Initialize.""" + super(BlockMaskPruner, self).__init__(config, modules) + + def _init(self): + """Initialize.""" + self.pattern = get_pattern(self.config, self.modules) + self.masks = self.pattern.get_block_masks(self.modules) + self.rewrite_forward() + self.scheduler = get_scheduler(self.config) + self.criterion = get_criterion(self.config, self.modules) + self.reg = get_reg(self.config, self.modules, self.pattern) + + if "channel" not in self.pattern.pattern: + logger.info("Enabling channel-wise pattern would be a better choice.") + + # def on_step_begin(self, local_step): + # """Implement at the start of each step. + + # Update the masks at a given local_step. 
+ # """ + # self.update_masks(local_step) + + def update_masks(self, local_step): + """Update the masks at a given local step.""" + if self.global_step == self.start_step: + if self.config['lock_init_sparsity']: + self.init_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks) + self.current_sparsity_ratio = self.init_sparsity_ratio + + if not self.check_is_pruned_step(self.global_step): + return + + if self.current_sparsity_ratio > self.target_sparsity_ratio: + return + + self.criterion.on_step_begin() + current_target_sparsity_ratio = self.scheduler.update_sparsity_ratio(self.target_sparsity_ratio, + self.completed_pruned_cnt, + self.total_prune_cnt, self.masks, + self.init_sparsity_ratio) + logger.info(f"current target ratio is {current_target_sparsity_ratio}") + + self.completed_pruned_cnt += 1 + if self.criterion.scores == {}: + return + self.masks = self.pattern.get_masks(self.criterion.scores, current_target_sparsity_ratio, self.masks) + self.update_block_masks(self.masks) + self.mask_weights() + + self.current_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks) + logger.info(f"current sparsity ratio is {self.current_sparsity_ratio}") + + if (self.end_step-self.global_step) / self.pruning_frequency < 1: + self.recover_forward() + + def on_before_optimizer_step(self): + """Implement before optimizer.step().""" + self.reg.on_before_optimizer_step() + self.criterion.on_before_optimizer_step() + + def on_after_optimizer_step(self): + """Prune the model after optimization.""" + ##the order of the following four lines can't not be exchanged + if self.global_step >= self.start_step and self.global_step <= self.end_step: + self.reg.on_after_optimizer_step() + self.zero_mask_grad() + self.mask_weights() + self.global_step += 1 + + def mask_weights(self): + """Apply block masks to corresponding modules' weights. + + Weights are multipled with masks. This is the formal pruning process. + """ + with torch.no_grad(): + self.pattern.mask_block_weights() + + def update_block_masks(self, masks): + """Update the block mask parameters.""" + with torch.no_grad(): + for key in self.masks.keys(): + module = self.modules[key] + module.block_mask.data = masks[key].data + + def zero_mask_grad(self): + with torch.no_grad(): + for key in self.modules.keys(): + if not hasattr(self.modules[key], 'block_mask'): + continue # No corresponding block mask, skip. + mask = self.modules[key].block_mask + if mask.grad is not None: + if mask.grad.grad_fn is not None: + mask.grad.detach_() + else: + mask.grad.requires_grad_(False) + mask.grad.zero_() + + +@register_pruner('retrain_free') +class RetrainFreePruner(BasePruner): + """Pruning Pruner. + The retrain_free pruner_class is derived from BasePruner. + This pruner references the mask search and mask rearrangement strategies in fast retraining free. + RetrainFreePruner supports one-shot pruning (same effect as fast retraining free) and iterative pruning. + Please refer to A Fast Post-Training Pruning Framework for Transformers + (https://arxiv.org/abs/2204.09656) + + 1. Defines pruning functions called at step begin/end, before/after optimize and epoch begin/end. + 2. Defines the pruning criterion and fixed weight parameters. + 3. Obtain block masks and its grads. + 4. Rearrange block masks. + + Args: + modules: A dict {"module_name": Tensor} that stores the pruning modules' weights. + config: A config dict object that contains the pruner information. + + Attributes: + pattern: A Pattern object that defines pruning weights' arrangements within space. 
+ criterion: A Criterion Object that defines which weights are to be pruned + scheduler: A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds. + reg: A Reg object that defines regulization terms. + """ + def __init__(self, config, modules): + """Initialize.""" + super(RetrainFreePruner, self).__init__(config, modules) + + def _init(self): + """Initialize.""" + self.pattern = get_pattern(self.config, self.modules) + self.masks = self.pattern.get_block_masks(self.modules) + self.rewrite_forward() + self.scheduler = get_scheduler(self.config) + self.criterion = get_criterion(self.config, self.modules) + self.reg = get_reg(self.config, self.modules, self.pattern) + + logger.warning("Retrain-free pruner fixed the weights, please DO NOT turn on gradient update.") + assert "channel" in self.pattern.pattern, \ + "retrain-free pruner only supports large patterns like channel-wise pruning." + + # def on_step_begin(self, local_step): + # """Implement at the start of each step. + + # Update the masks at a given local_step. + # """ + # self.update_masks(local_step) + + def update_masks(self, local_step): + """Update the masks at a given local step.""" + if self.global_step == self.start_step: + if self.config['lock_init_sparsity']: + self.init_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks) + self.current_sparsity_ratio = self.init_sparsity_ratio + + if not self.check_is_pruned_step(self.global_step): + return + + if self.current_sparsity_ratio > self.target_sparsity_ratio: + return + + self.criterion.on_step_begin() + current_target_sparsity_ratio = self.scheduler.update_sparsity_ratio(self.target_sparsity_ratio, + self.completed_pruned_cnt, + self.total_prune_cnt, self.masks, + self.init_sparsity_ratio) + logger.info(f"current target ratio is {current_target_sparsity_ratio}") + + self.completed_pruned_cnt += 1 + if self.criterion.scores == {}: + return + self.masks = self.pattern.get_masks(self.criterion.scores, current_target_sparsity_ratio, self.masks) + self.rearrange_masks(self.masks) + self.update_block_masks(self.masks) + # support iterative rearrangement + if (self.end_step-self.global_step) / self.pruning_frequency < 1: + self.mask_weights() + logger.info(f"mask weights at end_step: {self.global_step}") + self.recover_forward() + + self.current_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks) + logger.info(f"current sparsity ratio is {self.current_sparsity_ratio}") + + def on_before_optimizer_step(self): + """Implement before optimizer.step().""" + self.reg.on_before_optimizer_step() + self.criterion.on_before_optimizer_step() + + def on_after_optimizer_step(self): + """Prune the model after optimization.""" + ##the order of the following four lines can't not be exchanged + if self.global_step >= self.start_step and self.global_step <= self.end_step: + self.reg.on_after_optimizer_step() + self.zero_mask_grad() + # self.mask_weights() #done on update_masks + self.global_step += 1 + + def mask_weights(self): + """Apply block masks to corresponding modules' weights. + + Weights are multipled with masks. This is the formal pruning process. 
+ """ + with torch.no_grad(): + self.pattern.mask_block_weights() + + def update_block_masks(self, masks): + """Update the block mask parameters.""" + with torch.no_grad(): + for key in self.masks.keys(): + module = self.modules[key] + module.block_mask.data = masks[key].data + + def rearrange_masks(self, masks): + """Rearrange the masks of each layer with constant sparsity.""" + with torch.no_grad(): + new_masks = {} + for key in masks.keys(): + block_mask = masks[key] + num_pruned = torch.sum(block_mask == 0.0).data.item() + grads = torch.stack(self.criterion.collected_grads[key], dim=0).squeeze() + if not num_pruned: + new_masks[key] = block_mask + continue + grads = grads.permute(1, 0).contiguous() + grads_sq = grads.pow(2).sum(dim=1) + _, indicies = grads_sq.sort(descending=False) + indicies = indicies.tolist() + masked_indicies = indicies[:num_pruned] + for index in indicies[num_pruned:]: + masked_indicies.append(index) + grad_vectors = grads[masked_indicies] + grad_sum = grad_vectors.sum(dim=0) + complement = grad_sum - grad_vectors + grad_sum_length = complement.pow(2).sum(dim=1) + removed = grad_sum_length.argmin() + del masked_indicies[removed] + + new_masks[key] = torch.ones(len(indicies)).to(block_mask.device) + new_masks[key][masked_indicies] = 0 + new_masks[key] = new_masks[key] * torch.ones_like(block_mask).to(block_mask.device) + self.masks = new_masks + + def zero_mask_grad(self): + with torch.no_grad(): + for key in self.modules.keys(): + if not hasattr(self.modules[key], 'block_mask'): + continue # No corresponding block mask, skip. + mask = self.modules[key].block_mask + if mask.grad is not None: + if mask.grad.grad_fn is not None: + mask.grad.detach_() + else: + mask.grad.requires_grad_(False) + mask.grad.zero_() + + @register_pruner('progressive') class ProgressivePruner(BasicPruner): """Pruning Pruner. A Pruner class derived from BasePruner. In this pruner, mask interpolation will be applied. - Mask interpolation is a fine-grained improvement for NxM structured pruning by adding interval + Mask interpolation is a fine-grained improvement for NxM structured pruning by adding interval masks between masks of two pruning steps. Args: @@ -496,7 +799,7 @@ def check_progressive_validity(self): if self.use_global: # when global progressive is applied, linear type is contradict. raise NotImplementedError("Global progressive pruning do not support linear pattern") - # When linear, progressive_step should not meet a indivisible + # When linear, progressive_step should not meet a indivisible for key in self.pattern.block_size.keys(): block_size = self.pattern.block_size[key] progressive_direction = max(block_size) @@ -515,11 +818,11 @@ def check_progressive_validity(self): def check_is_pruned_progressive_step(self, step): """Check if a progressive pruning process should be performed at the current step. - + Args: step: an integer representing the number of current step. - - Returns: + + Returns: A Boolean. """ # used in progressive pruning @@ -589,7 +892,7 @@ def update_masks_progressive(self, local_step): def on_step_begin(self, local_step): """Update the masks at a given local_step. - + Implement at the start of each step. 
""" if self.handled_global_step == self.global_step: @@ -624,3 +927,5 @@ def print_progressive_sparsity(self): """Output the progressive sparsity.""" cur_sp = self.pattern.get_sparsity_ratio_progressive(self.progressive_masks) logger.info("Step: {} -> Current progressive sparsity: {}".format(self.global_step, cur_sp)) + + diff --git a/neural_compressor/compression/pruner/utils.py b/neural_compressor/compression/pruner/utils.py index 6bd739dd345..19b569765bc 100644 --- a/neural_compressor/compression/pruner/utils.py +++ b/neural_compressor/compression/pruner/utils.py @@ -28,13 +28,17 @@ from neural_compressor.conf.config import Pruner LazyImport('torch.nn') torch = LazyImport('torch') + F = LazyImport('torch.nn.functional') + except: import torch + import torch.nn.functional as F from .dot_dict import DotDict ##TODO import logging logger = logging.getLogger(__name__) from .schema_check import PrunerV2 + class WeightPruningConfig: """Similiar to torch optimizer's interface.""" @@ -96,7 +100,11 @@ def get_sparsity_ratio(pruners, model): cnt += modules[key].weight.numel() pattern_sparsity_cnt += int(cnt * sparsity_ratio) for key in pruner.masks.keys(): - element_sparsity_cnt += torch.sum(pruner.masks[key] == 0).data.item() + block_num = 1 + if pruner.pattern.block: + block_size = pruner.pattern.block_size[key] + block_num = block_size[0] * block_size[1] + element_sparsity_cnt += torch.sum(pruner.masks[key] == 0).data.item() * block_num linear_conv_cnt = 0 param_cnt = 0 @@ -176,7 +184,7 @@ def check_config(prune_config): max_ratio = float(N) / M if prune_config['pruning_type']!="pattern_lock": assert prune_config['target_sparsity'] <= max_ratio, \ - "in N:M pattern, the max sparsity is N/M={}".format(max_ratio) + "in N:M pattern, the max sparsity is N/M={}".format(max_ratio) prune_config['max_sparsity_ratio_per_op'] = min(max_ratio, prune_config['max_sparsity_ratio_per_op']) if prune_config['reg_coeff'] != None: prune_config['reg_coeff'] = float(prune_config['reg_coeff']) @@ -311,7 +319,7 @@ def check_key_validity_prunerv2(template_config, usr_cfg_dict): for user_key, user_value in usr_cfg_dict.pruner_config.items(): if user_key not in template_config.keys(): logger.warning(f"{user_key} is not supported for config") - + # multi pruners if isinstance(user_config, list): for obj in user_config: @@ -319,7 +327,7 @@ def check_key_validity_prunerv2(template_config, usr_cfg_dict): check_key_validity_dict(template_config, obj) elif isinstance(obj, PrunerV2): check_key_validity_prunerv2(template_config, obj) - + # single pruner, weightconfig or yaml elif isinstance(user_config, dict): check_key_validity_dict(template_config, user_config) @@ -329,15 +337,15 @@ def check_key_validity_prunerv2(template_config, usr_cfg_dict): def process_and_check_config(val): """Process and check configurations. - - Args: + + Args: val: A dict that contains the layer-specific pruning configurations. """ default_global_config = {'target_sparsity': 0.9, 'pruning_type': 'snip_momentum', 'pattern': '4x1', 'op_names': [], 'excluded_op_names': [], 'start_step': 0, 'end_step': 0, 'pruning_scope': 'global', 'pruning_frequency': 1, 'min_sparsity_ratio_per_op': 0.0, 'max_sparsity_ratio_per_op': 0.98, - 'sparsity_decay_type': 'exp', + 'sparsity_decay_type': 'exp', "criterion_type": "snip_momentum", 'pruning_op_types': ['Conv', 'Linear'], } default_local_config = {'resume_from_pruned_checkpoint': False, 'reg_type': None, @@ -398,7 +406,7 @@ def process_config(config): def parse_to_prune(config, model): """Keep target pruned layers. 
- + Args: config: A string representing the path to the configuration file. model: The model to be pruned. @@ -431,7 +439,7 @@ def parse_to_prune(config, model): def generate_pruner_config(info): """Generate pruner config object from prune information. - + Args: info: A dotdict that saves prune information. @@ -446,6 +454,7 @@ def generate_pruner_config(info): update_frequency=info.pruning_frequency, ) + def parse_auto_slim_config(model, ffn2_sparsity = .0, mha_sparsity = .0, **kwargs): """Get model slim pruning configs.""" auto_slim_configs = [] @@ -495,4 +504,5 @@ def generate_mha_pruning_config(model, mha_sparsity, **kwargs): # append kwargs to generated config for item in mha_pruning_config: item.update(kwargs) - return mha_pruning_config \ No newline at end of file + return mha_pruning_config + diff --git a/test/pruning_2.x/test_pruning_block.py b/test/pruning_2.x/test_pruning_block.py new file mode 100644 index 00000000000..155aae99aad --- /dev/null +++ b/test/pruning_2.x/test_pruning_block.py @@ -0,0 +1,88 @@ +import unittest + +import torch +import torchvision +import torch.nn as nn +import sys +sys.path.insert(0, './') +from neural_compressor.data import Datasets +from neural_compressor.data.dataloaders.pytorch_dataloader import PyTorchDataLoader +from neural_compressor import WeightPruningConfig +from neural_compressor.training import prepare_compression + + +class TestPruning(unittest.TestCase): + # model = torchvision.models.resnet18() + model = torchvision.models.vit_b_16() + def test_pruning_basic(self): + local_configs = [ + { + "op_names": ['encoder_layer_1.mlp*'], + "target_sparsity": 0.6, + "pattern": 'channelx2', + "pruning_type": "block_mask", + "pruning_scope": "global", + "criterion_type": "snip_momentum_block", + "criterion_reduce_type": "mean", + "pruning_op_types": "Linear", + }, + { + "op_names": ['encoder_layer_2.mlp*'], + "target_sparsity": 0.5, + "pattern": '32x32', + "pruning_op_types": "Linear", + "pruning_type": "block_mask", + "pruning_scope": "local", + "criterion_type": "snip_momentum_block", + "criterion_reduce_type": "sum", + }, + { + "op_names": ['encoder_layer_3.mlp*'], + 'target_sparsity': 0.4, + 'pattern': 'channelx1', + "pruning_op_types": "Linear", + "pruning_type": "retrain_free", + "pruning_scope": "local", + "pruning_frequency": 2, + } + ] + config = WeightPruningConfig( + local_configs, + target_sparsity=0.8, + start_step=1, + end_step=10 + ) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001) + datasets = Datasets('pytorch') + dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True) + dummy_dataloader = PyTorchDataLoader(dummy_dataset) + + compression_manager = prepare_compression(model=self.model, confs=config) + compression_manager.callbacks.on_train_begin() + for epoch in range(2): + self.model.train() + compression_manager.callbacks.on_epoch_begin(epoch) + local_step = 0 + for image, target in dummy_dataloader: + compression_manager.callbacks.on_step_begin(local_step) + output = self.model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + compression_manager.callbacks.on_before_optimizer_step() + optimizer.step() + compression_manager.callbacks.on_after_optimizer_step() + compression_manager.callbacks.on_step_end() + local_step += 1 + + compression_manager.callbacks.on_epoch_end() + compression_manager.callbacks.on_train_end() + compression_manager.callbacks.on_before_eval() + compression_manager.callbacks.on_after_eval() 
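
Note on the `block_mask` pruning type exercised in the test above: it relies on the rewritten `Linear.forward` added earlier in this patch, which expands a coarse block-level mask to the full weight shape with `repeat_interleave` before calling `F.linear`. The snippet below is a minimal, self-contained sketch of that expansion only; it is not part of the patch, and the toy layer and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

# Toy linear layer: weight has shape (out_features, in_features) = (8, 16).
linear = torch.nn.Linear(16, 8)

# A 4x4 block mask: each entry governs a (8//4) x (16//4) = 2x4 tile of the weight.
block_mask = torch.ones(4, 4)
block_mask[0, 0] = 0.0  # prune one block

block_size = [linear.weight.shape[0] // block_mask.shape[0],
              linear.weight.shape[1] // block_mask.shape[1]]

# Expand the block mask to the full weight shape, mirroring the rewritten forward.
mask = block_mask.repeat_interleave(block_size[0], dim=0) \
                 .repeat_interleave(block_size[1], dim=-1)

x = torch.randn(2, 16)
out = F.linear(x, linear.weight * mask, linear.bias)  # masked forward pass
print(mask.shape, out.shape)  # torch.Size([8, 16]) torch.Size([2, 8])
```
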
+ + +if __name__ == "__main__": + unittest.main() + diff --git a/test/pruning_2_plus.x/test_pruning_block.py b/test/pruning_2_plus.x/test_pruning_block.py new file mode 100644 index 00000000000..222cadf0d98 --- /dev/null +++ b/test/pruning_2_plus.x/test_pruning_block.py @@ -0,0 +1,75 @@ +import unittest + +import torch +import torchvision +import torch.nn as nn +import sys +sys.path.insert(0, './') +from neural_compressor.data import Datasets +from neural_compressor.data.dataloaders.pytorch_dataloader import PyTorchDataLoader +from neural_compressor import WeightPruningConfig +from neural_compressor.training import prepare_pruning + + +class TestPruning(unittest.TestCase): + # model = torchvision.models.resnet18() + model = torchvision.models.vit_b_16() + def test_pruning_basic(self): + local_configs = [ + { + "op_names": ['encoder_layer_1.mlp*'], + "target_sparsity": 0.6, + "pattern": 'channelx2', + "pruning_type": "block_mask", + "pruning_scope": "global", + "criterion_type": "snip_momentum_block", + "pruning_op_types": "Linear", + }, + { + "op_names": ['encoder_layer_2.mlp*'], + "target_sparsity": 0.5, + "pattern": '32x32', + "pruning_op_types": "Linear", + "pruning_type": "block_mask", + "pruning_scope": "local", + "criterion_type": "snip_momentum_block", + }, + { + "op_names": ['encoder_layer_3.mlp*'], + 'target_sparsity': 0.4, + 'pattern': 'channelx1', + "pruning_op_types": "Linear", + "pruning_type": "retrain_free", + "pruning_scope": "local", + "pruning_frequency": 2, + } + ] + config = WeightPruningConfig( + local_configs, + target_sparsity=0.8, + start_step=1, + end_step=10 + ) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001) + model, optimizer = prepare_pruning(config, self.model, optimizer) + datasets = Datasets('pytorch') + dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True) + dummy_dataloader = PyTorchDataLoader(dummy_dataset) + + for epoch in range(2): + self.model.train() + local_step = 0 + for image, target in dummy_dataloader: + output = self.model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + local_step += 1 + + +if __name__ == "__main__": + unittest.main() +
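
For reference, the `retrain_free` pruner added in this patch rearranges channel-wise masks from accumulated block-mask gradients, following the fast post-training pruning framework cited in its docstring. The sketch below restates the greedy swap procedure from `rearrange_masks` as a standalone function, under the assumption that one layer's gradient history has shape (steps, channels); the function name and the toy data are illustrative, not part of the library API.

```python
import torch

def rearrange_channel_mask(grad_history: torch.Tensor, num_pruned: int) -> torch.Tensor:
    """Greedily rearrange a channel mask under a fixed sparsity budget.

    grad_history: (steps, channels) gradients collected for one layer's block mask.
    num_pruned:   number of channels that must remain pruned (mask == 0).
    Returns a 1-D {0, 1} mask over channels.
    """
    grads = grad_history.t().contiguous()        # (channels, steps)
    saliency = grads.pow(2).sum(dim=1)           # per-channel squared-gradient score
    order = saliency.sort(descending=False).indices.tolist()

    masked = order[:num_pruned]                  # start from the least salient channels
    for candidate in order[num_pruned:]:
        masked.append(candidate)                 # tentatively add a kept channel
        vectors = grads[masked]                  # (num_pruned + 1, steps)
        total = vectors.sum(dim=0)
        leave_one_out = total - vectors          # sum of the set excluding each member
        drop = leave_one_out.pow(2).sum(dim=1).argmin()
        del masked[int(drop)]                    # evict whichever member hurts least

    mask = torch.ones(grads.shape[0])
    mask[masked] = 0.0
    return mask

# Example: 8 channels, 4 collected gradient steps, keep 3 channels pruned.
torch.manual_seed(0)
mask = rearrange_channel_mask(torch.randn(4, 8), num_pruned=3)
print(mask)  # a float tensor over 8 channels with exactly three zeros
```
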