From 6613cfa9c7b8a06b3b85f35e2cf3ba2663766fd3 Mon Sep 17 00:00:00 2001 From: n1ck-guo <110074967+n1ck-guo@users.noreply.github.com> Date: Thu, 27 Jul 2023 13:11:11 +0800 Subject: [PATCH] add Hyper-parameter Optimization algorithm (#786) * add hpo * add ut&example, update code * fix Signed-off-by: Guo, Heng * add requirement Signed-off-by: Guo, Heng * pylint * add readme * requir * modify readme Signed-off-by: Guo, Heng * modify api * logger * modify example and ut * add hpo api&config Signed-off-by: Guo, Heng * fix * modify api Signed-off-by: Guo, Heng * modify readme Signed-off-by: Guo, Heng * spell Signed-off-by: Guo, Heng * sync readme Signed-off-by: Guo, Heng * modify readme Signed-off-by: Guo, Heng --------- Signed-off-by: Guo, Heng --- .../scripts/codeScan/pylint/pylint.sh | 3 +- .../scripts/codeScan/pyspelling/inc_dict.txt | 4 + docs/source/pruning.md | 9 +- .../text-classification/pruning/hpo/README.md | 29 + .../pruning/hpo/requirements.txt | 12 + .../pruning/hpo/run_glue_no_trainer.py | 670 ++++++++++++++++++ neural_compressor/compression/hpo/__init__.py | 22 + .../compression/hpo/sa_optimizer.py | 123 ++++ .../compression/hpo/search_algorithms.py | 364 ++++++++++ .../compression/hpo/search_space.py | 176 +++++ .../compression/pruner/README.md | 14 +- neural_compressor/config.py | 26 + test/hpo/test_hpo.py | 78 ++ test/requirements.txt | 1 + 14 files changed, 1525 insertions(+), 6 deletions(-) create mode 100644 examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/README.md create mode 100644 examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/requirements.txt create mode 100644 examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py create mode 100644 neural_compressor/compression/hpo/__init__.py create mode 100644 neural_compressor/compression/hpo/sa_optimizer.py create mode 100644 neural_compressor/compression/hpo/search_algorithms.py create mode 100644 neural_compressor/compression/hpo/search_space.py create mode 100644 test/hpo/test_hpo.py diff --git a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh index 66272c2982a..6167b8db6bf 100644 --- a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh +++ b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh @@ -31,7 +31,8 @@ pip install torch==1.12.0 \ onnxruntime_extensions \ tf_slim \ transformers \ - flask==2.1.3 + flask==2.1.3 \ + xgboost if [ "${scan_module}" = "neural_solution" ]; then cd /neural-compressor diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt index d817b141737..75996d6a598 100644 --- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt +++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt @@ -2698,3 +2698,7 @@ Vanhoucke ONNXCommunityMeetup luYBWA pQ +xgb +xgboost +hpo +HPO diff --git a/docs/source/pruning.md b/docs/source/pruning.md index 21e5b34663a..aa9cc4cdf2f 100644 --- a/docs/source/pruning.md +++ b/docs/source/pruning.md @@ -52,8 +52,9 @@ Pruning 4. [Sparse Model Deployment](#sparse-model-deployment) +5. [Pruning With HPO](#pruning-with-hyperparameter-optimization) -5. [Reference](#reference) +6. [Reference](#reference) ## Introduction @@ -104,7 +105,7 @@ Pruning patterns defines the rules of pruned weights' arrangements in space. 
Int
-- Multi-head Attention Pruning (Work in progress)
+- Multi-head Attention Pruning

  The multi-head attention mechanism boosts transformer models' capability of contextual information analysis. However, different heads contribute unevenly to the final output, and in most situations a number of heads can be removed without causing an accuracy drop. Head pruning can be applied to a wide range of models including BERT, GPT, and other large language models. **We do not support it in pruning yet, but an experimental feature is provided in Model Auto Slim.** Please refer to [multi-head attention auto slim examples](https://github.com/intel/neural-compressor/blob/master/examples/pytorch/nlp/huggingface_models/question-answering/model_slim)

@@ -386,6 +387,10 @@ Please refer to [pruning examples](../../examples/README.md#Pruning-1) for more
 Particular hardware/software like [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers) is required to obtain inference-speed and footprint optimization for most sparse models. However, using [model slim](#click) for some special structures can obtain significant inference speed improvements and footprint reduction without post-pruning deployment. In other words, you can achieve model acceleration directly within your training framework (PyTorch, etc.)
 
+## Pruning with Hyperparameter Optimization
+Intel® Neural Compressor currently supports grid, random, Bayesian optimization, and XGBoost search algorithms for pruning with HPO.
+For more details, please refer to the [HPO document](../../neural_compressor/compression/hpo/README.md).
+
 ## Reference
 
 [1] Namhoon Lee, Thalaiyasingam Ajanthan, and Philip Torr. SNIP: Single-shot network pruning based on connection sensitivity. In International Conference on Learning Representations, 2019.
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/README.md b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/README.md
new file mode 100644
index 00000000000..3721beacfe8
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/README.md
@@ -0,0 +1,29 @@
+Step-by-Step
+============
+
+This document presents step-by-step instructions for pruning Hugging Face models with the HPO feature of Intel® Neural Compressor.
+
+# Prerequisite
+## 1. Environment
+Python 3.6 or a higher version is recommended.
+The dependent packages are listed in `requirements.txt`; install them as follows:
+```shell
+cd examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/
+pip install -r requirements.txt
+```
+## 2. Prepare Dataset
+
+The dataset will be downloaded automatically from the datasets Hub.
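+A minimal sketch of what happens under the hood (illustrative only; `run_glue_no_trainer.py` performs this step automatically from `--task_name`):
+```python
+from datasets import load_dataset
+
+# Downloads and caches the GLUE/MRPC split from the Hub on first use.
+raw_datasets = load_dataset("glue", "mrpc")
+print(raw_datasets["train"][0])  # inspect one training example
+```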
+See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/loading_datasets.html) + +# Run +To get tuned model and its accuracy: +```shell +python run_glue_no_trainer.py \ + --model_name_or_path M-FAC/bert-mini-finetuned-mrpc \ + --task_name mrpc \ + --per_device_eval_batch_size 18 \ + --per_device_train_batch_size 18 \ + --do_prune + +``` \ No newline at end of file diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/requirements.txt b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/requirements.txt new file mode 100644 index 00000000000..af176cfbaec --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/requirements.txt @@ -0,0 +1,12 @@ +accelerate +datasets +sentencepiece +scipy +scikit-learn +protobuf +torch +evaluate +transformers +tqdm +xgboost + diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py new file mode 100644 index 00000000000..41fd826ae16 --- /dev/null +++ b/examples/pytorch/nlp/huggingface_models/text-classification/pruning/hpo/run_glue_no_trainer.py @@ -0,0 +1,670 @@ +# coding=utf-8 +# Copyright (c) 2023 Intel Corporation +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Finetuning a 🤗 Transformers model for sequence classification on GLUE.""" +import argparse +import logging +import math +import os +import random +from pathlib import Path +import sys + +sys.path.insert(0, './') +import datasets +from datasets import load_dataset, load_metric +from torch.utils.data import DataLoader +import torch +from tqdm.auto import tqdm + +import transformers +from accelerate import Accelerator +from huggingface_hub import Repository +from transformers import ( + AdamW, + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + PretrainedConfig, + SchedulerType, + default_data_collator, + get_scheduler, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils.versions import require_version +from neural_compressor.training import WeightPruningConfig + +logger = logging.getLogger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") + parser.add_argument( + "--task_name", + type=str, + default=None, + help="The name of the glue task to train on.", + choices=list(task_to_keys.keys()), + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--max_length", + type=int, + default=128, + help=( + "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," + " sequences shorter will be padded if `--pad_to_max_lengh` is passed." + ), + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.",
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+        required=True,
+    )
+    parser.add_argument(
+        "--teacher_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained teacher model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--use_slow_tokenizer",
+        action="store_true",
+        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
+    )
+    parser.add_argument(
+        "--per_device_train_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the training dataloader.",
+    )
+    parser.add_argument(
+        "--per_device_eval_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the evaluation dataloader.",
+    )
+    parser.add_argument(
+        "--distill_loss_weight",
+        type=float,
+        default=0.0,
+        help="Distillation loss weight.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--lr_scheduler_type",
+        type=SchedulerType,
+        default="linear",
+        help="The scheduler type to use.",
+        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+    )
+    parser.add_argument(
+        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument("--cooldown_epochs", type=int, default=0, help="Number of cool-down epochs after pruning.")
+
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument(
+        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
+    )
+    parser.add_argument("--sparsity_warm_epochs", type=int, default=0,
+                        help="Number of warm-up epochs during which the network is not pruned.")
+    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--do_prune", action="store_true", help="Whether or not to prune the model.")
+
+    parser.add_argument(
+        "--pruning_pattern",
+        type=str, default="4x1",
+        help="Pruning pattern type; NxM and N:M patterns are supported."
+    )
+    parser.add_argument(
+        "--target_sparsity",
+        type=float, default=0.8,
+        help="Target sparsity of the model."
+    )
+    parser.add_argument(
+        "--pruning_frequency",
+        type=int, default=-1,
+        help="Pruning step frequency for iterative pruning; defaults to 1/40 of the total pruning steps."
+    )
+    parser.add_argument(
+        "--pruning_type",
+        type=str, default="snip_momentum",
+        help="Pruning type, which determines how the weights of a neural network are scored and pruned."
+ ) + args = parser.parse_args() + + # Sanity checks + if args.task_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a task name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def get_loss_one_logit(student_logit, teacher_logit): + t = 2.0 + from torch.nn import functional as F + return F.kl_div( + input=F.log_softmax(student_logit / t, dim=-1), + target=F.softmax(teacher_logit / t, dim=-1), + reduction="batchmean" + ) * (t ** 2) + + +def main(args): + # args = parse_args() + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator() + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.task_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset("glue", args.task_name) + else: + # Loading the dataset from local csv or json file. 
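+        # Map each split name to its user-provided file; `load_dataset` infers
+        # the dataset builder ("csv" or "json") from the file extension below.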
+        data_files = {}
+        if args.train_file is not None:
+            data_files["train"] = args.train_file
+        if args.validation_file is not None:
+            data_files["validation"] = args.validation_file
+        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
+        raw_datasets = load_dataset(extension, data_files=data_files)
+    # See more about loading any type of standard or custom dataset at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Labels
+    if args.task_name is not None:
+        is_regression = args.task_name == "stsb"
+        if not is_regression:
+            label_list = raw_datasets["train"].features["label"].names
+            num_labels = len(label_list)
+        else:
+            num_labels = 1
+    else:
+        # Trying to have good defaults here, don't hesitate to tweak to your needs.
+        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
+        if is_regression:
+            num_labels = 1
+        else:
+            # A useful fast method:
+            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
+            label_list = raw_datasets["train"].unique("label")
+            label_list.sort()  # Let's sort it for determinism
+            num_labels = len(label_list)
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        args.model_name_or_path,
+        from_tf=bool(".ckpt" in args.model_name_or_path),
+        config=config,
+    )
+    if args.distill_loss_weight > 0:
+        teacher_path = args.teacher_model_name_or_path
+        if teacher_path is None:
+            teacher_path = args.model_name_or_path
+        teacher_model = AutoModelForSequenceClassification.from_pretrained(
+            teacher_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+        )
+
+    # Preprocessing the datasets
+    if args.task_name is not None:
+        sentence1_key, sentence2_key = task_to_keys[args.task_name]
+    else:
+        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
+        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
+        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
+            sentence1_key, sentence2_key = "sentence1", "sentence2"
+        else:
+            if len(non_label_column_names) >= 2:
+                sentence1_key, sentence2_key = non_label_column_names[:2]
+            else:
+                sentence1_key, sentence2_key = non_label_column_names[0], None
+
+    # Some models have set the order of the labels to use, so let's make sure we do use it.
+    label_to_id = None
+    if (
+        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
+        and args.task_name is not None
+        and not is_regression
+    ):
+        # Some have all caps in their config, some don't.
+        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
+        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+            logger.info(
+                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
+                "Using it!"
+ ) + label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} + else: + logger.warning( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + elif args.task_name is None: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if label_to_id is not None: + model.config.label2id = label_to_id + model.config.id2label = {id: label for label, id in config.label2id.items()} + elif args.task_name is not None and not is_regression: + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {id: label for label, id in config.label2id.items()} + + padding = "max_length" if args.pad_to_max_length else False + + def preprocess_function(examples): + # Tokenize the texts + texts = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) + + if "label" in examples: + if label_to_id is not None: + # Map labels to IDs (not necessary for GLUE tasks) + result["labels"] = [label_to_id[l] for l in examples["label"]] + else: + # In all cases, rename the column to labels because the model will expect that. + result["labels"] = examples["label"] + return result + + with accelerator.main_process_first(): + processed_datasets = raw_datasets.map( + preprocess_function, + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on dataset", + ) + + train_dataset = processed_datasets["train"] + eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + if args.pad_to_max_length: + # If padding was already done ot max length, we use the default data collator that will just convert everything + # to tensors. + data_collator = default_data_collator + else: + # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of + # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple + # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) + + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. 
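+    # Note: parameter names containing "classifier" match neither group below,
+    # so the classifier weight is never handed to the optimizer and keeps its
+    # pretrained value; the classifier bias still lands in the no-decay group.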
+ no_decay = ["bias", "LayerNorm.weight"] + no_decay_classifier = ["bias", "LayerNorm.weight", "classifier"] + + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay_classifier)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + if args.do_prune: + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, betas=[0.9, 0.9]) ##changed + else: + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Prepare everything with our `accelerator`. + if args.distill_loss_weight > 0: + teacher_model, model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( + teacher_model, model, optimizer, train_dataloader, eval_dataloader + ) + teacher_model.eval() + else: + model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader + ) + + # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be + # shorter in multiprocess) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Get the metric function + if args.task_name is not None: + metric = load_metric("glue", args.task_name) + else: + metric = load_metric("accuracy") + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + + # Pruning preparation + num_iterations = len(train_dataset) / total_batch_size + num_warm = int(args.sparsity_warm_epochs * num_iterations) + total_iterations = int(num_iterations * (args.num_train_epochs - args.cooldown_epochs)) + frequency = int((total_iterations - num_warm + 1) / 40) if args.pruning_frequency == -1 \ + else args.pruning_frequency + pruning_start = num_warm + pruning_end = total_iterations + if not args.do_prune: + pruning_start = num_iterations * args.num_train_epochs + 1 + pruning_end = pruning_start + pruning_configs = [ + { + "pruning_type": "snip_momentum", + "pruning_scope": "global", + "sparsity_decay_type": "exp", + "excluded_op_names": ["pooler"], + "pruning_op_types": ["Linear"], + "max_sparsity_ratio_per_op": 0.98 + } + ] + configs = WeightPruningConfig( + pruning_configs, + target_sparsity=args.target_sparsity, + pattern=args.pruning_pattern, + pruning_frequency=frequency, + start_step=pruning_start, + end_step=pruning_end, + pruning_type=args.pruning_type, + ) + # pruner = Pruning(config) + # pruner.model = model + # pruner.on_train_begin() + from neural_compressor.experimental.compression import prepare_pruning + prepare_pruning(configs, model, optimizer) + + + for epoch in range(args.num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + # pruner.on_step_begin(local_step=step) + + outputs = model(**batch, output_hidden_states=True) + loss = outputs.loss + loss = loss / args.gradient_accumulation_steps + if args.distill_loss_weight > 0.0: + distill_loss_weight = args.distill_loss_weight + with torch.no_grad(): + teacher_outputs = teacher_model(**batch, output_hidden_states=True) + ##please refer to Knowledge Distillation with the Reused Teacher Classifier https://arxiv.org/abs/2203.14001 + MSELoss = torch.nn.MSELoss().cuda() + loss = distill_loss_weight * MSELoss(outputs['hidden_states'][-1], + teacher_outputs['hidden_states'][-1]) ##variant 3 + + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + # pruner.on_before_optimizer_step() + + optimizer.step() + # pruner.on_after_optimizer_step() + + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if completed_steps >= args.max_train_steps: + break + + model.eval() + # torch.save(model,"model.pt") + # torch.save(optimizer, "optimizer.pt") + for step, batch in enumerate(eval_dataloader): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() + metric.add_batch( + predictions=accelerator.gather(predictions), + references=accelerator.gather(batch["labels"]), + ) + + eval_metric = metric.compute() + logger.info(f"epoch {epoch}: {eval_metric}") + ##pruner.on_after_eval() + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + # unwrapped_model = accelerator.unwrap_model(model) + # unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) + accelerator.save_state(args.output_dir) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + config.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + if args.output_dir is not None: + accelerator.wait_for_everyone() + # unwrapped_model = 
accelerator.unwrap_model(model)
+            file = os.path.join(args.output_dir, f"epoch{epoch}")
+            # unwrapped_model.save_pretrained(file)
+            # unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            accelerator.save_state(file)
+            # if accelerator.is_main_process:
+            #     tokenizer.save_pretrained(args.output_dir)
+            #     if args.push_to_hub:
+            #         repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+
+
+    if args.output_dir is not None:
+
+        accelerator.wait_for_everyone()
+        # unwrapped_model = accelerator.unwrap_model(model)
+        # unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        accelerator.save_state(args.output_dir)
+        if accelerator.is_main_process:
+            tokenizer.save_pretrained(args.output_dir)
+            config.save_pretrained(args.output_dir)
+            if args.push_to_hub:
+                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+
+    if args.task_name == "mnli":
+        # Final evaluation on mismatched validation set
+        eval_dataset = processed_datasets["validation_mismatched"]
+        eval_dataloader = DataLoader(
+            eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
+        )
+        eval_dataloader = accelerator.prepare(eval_dataloader)
+
+        model.eval()
+        for step, batch in enumerate(eval_dataloader):
+            outputs = model(**batch)
+            predictions = outputs.logits.argmax(dim=-1)
+            metric.add_batch(
+                predictions=accelerator.gather(predictions),
+                references=accelerator.gather(batch["labels"]),
+            )
+
+        eval_metric = metric.compute()
+        logger.info(f"mnli-mm: {eval_metric}")
+
+    # Return the final evaluation metric so callers (e.g. the HPO loop below)
+    # can score this trial.
+    return eval_metric
+
+
+if __name__ == "__main__":
+    # main()
+    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
+    import time
+    from neural_compressor.compression.hpo import prepare_hpo, SearchSpace
+    from neural_compressor.config import HPOConfig
+
+    args = parse_args()
+
+    search_space = {
+        'learning_rate': SearchSpace((0.0001, 0.001)),
+        'num_train_epochs': SearchSpace(bound=(20, 100), interval=1),
+        'weight_decay': SearchSpace((0.0001, 0.001)),
+        'cooldown_epochs': SearchSpace(bound=(0, 10), interval=1),
+        'sparsity_warm_epochs': SearchSpace(bound=(0, 5), interval=1),
+        'per_device_train_batch_size': SearchSpace((5, 20), 1)
+    }
+    config = HPOConfig(search_space=search_space,
+                       searcher='xgb',
+                       higher_is_better=True,
+                       min_train_samples=3)
+    searcher = prepare_hpo(config)
+    for iter in range(10):
+        print(f'search iter {iter}')
+        st = time.time()
+        params = searcher.suggest()
+        for k, v in params.items():
+            # Searched values come back as floats; cast the integer-valued
+            # hyperparameters before handing them to the training loop.
+            if k not in ['learning_rate', 'weight_decay']:
+                v = int(v)
+            setattr(args, k, v)
+        metric = main(args)
+        print(metric)
+        searcher.get_feedback(metric['accuracy'])
+        acc = metric['accuracy']
+        f1 = metric['f1']
+        rt = time.time() - st
+        tmp_str = f'{iter + 10}\t{params}\t{acc}\t{f1}\t{rt}\n'
+        print(tmp_str)
diff --git a/neural_compressor/compression/hpo/__init__.py b/neural_compressor/compression/hpo/__init__.py
new file mode 100644
index 00000000000..54dfe3cd14c
--- /dev/null
+++ b/neural_compressor/compression/hpo/__init__.py
@@ -0,0 +1,22 @@
+"""Hyper-parameter Optimization."""
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .search_space import * +from .sa_optimizer import * +from .search_algorithms import * + diff --git a/neural_compressor/compression/hpo/sa_optimizer.py b/neural_compressor/compression/hpo/sa_optimizer.py new file mode 100644 index 00000000000..ba0c8b7e680 --- /dev/null +++ b/neural_compressor/compression/hpo/sa_optimizer.py @@ -0,0 +1,123 @@ +"""Simulated Annealing Optimizer""" +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import math +from random import random + +import numpy as np + +try: + from neural_compressor.utils import logger +except: + import logging + logger = logging.getLogger("sa_optimizer") + + +class SimulatedAnnealingOptimizer(object): + def __init__( + self, + generate_func=None, + T0=100, + Tf=0.01, + higher_is_better=True, + alpha=None, + iter=500, + early_stop=50, + log_interval=50 + ): + """Initialize.""" + self.generate_func = generate_func + self.T0 = T0 + self.Tf = Tf + self.T = self.T0 + self.higher_is_better = higher_is_better + self.alpha = alpha + self.iter = iter + self.early_stop = early_stop + self.log_interval = log_interval + self.best = (float('-inf'), None) if self.higher_is_better else (float('inf'), None) + self.history = {'T': [], 'F': []} + + def _metrospolis(self, f, f_new): + if (not self.higher_is_better and f_new <= f) or (self.higher_is_better and f_new >= f): + return 1 + else: + if self.higher_is_better: + p = math.exp((f_new - f) / self.T) + else: + p = math.exp((f - f_new) / self.T) + if random() < p: + return 1 + else: + return 0 + + def _generate_new_points(self, points): + new_points = np.array(points) + new_points += self.T * (np.random.random(new_points.shape) - np.random.random(new_points.shape)) + return new_points + + def gen_next_params(self, func, points): + """Get the next parameter.""" + count = 0 + last_modify = 0 + self.T = self.T0 + self.best = (float('-inf'), None) if self.higher_is_better else (float('inf'), None) + scores = func(points) + + self.history = {'T': [], 'F': [], 'P': []} + st = time.time() + + while self.T > self.Tf: + # generate new points + if self.generate_func: + new_points = self.generate_func(points) + else: + new_points = self._generate_new_points(points) + new_scores = func(new_points) + for i, s in enumerate(new_scores): + if self._metrospolis(scores[i], s): + points[i] = new_points[i] + scores[i] = s + if (not self.higher_is_better and scores[i] < self.best[0]) \ + or (self.higher_is_better and scores[i] > self.best[0]): + last_modify = count + self.best = (scores[i], [float(v) for v in points[i]]) + + 
self.history['T'].append(self.T)
+            # Track the best score and point of this batch for the trajectory log.
+            if self.higher_is_better:
+                self.history['F'].append(max(scores))
+                self.history['P'].append(points[np.argmax(scores)])
+            else:
+                self.history['F'].append(min(scores))
+                self.history['P'].append(points[np.argmin(scores)])
+
+            if self.alpha:
+                self.T *= self.alpha
+            else:
+                self.T -= (self.T0 - self.Tf) / (self.iter + 1)
+            count += 1
+
+            if self.log_interval and count % self.log_interval == 0:
+                elapse = time.time() - st
+                logger.debug(f'SA iter: {count}\tlast_update: {last_modify}\t'
+                             f'best score: {self.best[0]}\tpoint: {self.best[1]}\t'
+                             f'temp: {self.T}\telapsed: {elapse}')
+
+            if count - last_modify > self.early_stop:
+                break
+        return self.best[1]
diff --git a/neural_compressor/compression/hpo/search_algorithms.py b/neural_compressor/compression/hpo/search_algorithms.py
new file mode 100644
index 00000000000..bcdbfed10bd
--- /dev/null
+++ b/neural_compressor/compression/hpo/search_algorithms.py
@@ -0,0 +1,364 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import xgboost as xgb
+
+from neural_compressor.strategy.bayesian import BayesianOptimization
+from ...config import HPOConfig
+
+from .search_space import BaseSearchSpace, DiscreteSearchSpace, ContinuousSearchSpace
+from .sa_optimizer import SimulatedAnnealingOptimizer
+
+try:
+    from neural_compressor.utils import logger
+except ImportError:
+    import logging
+    logger = logging.getLogger(__name__)
+
+
+SEARCHERS = {}
+
+
+def prepare_hpo(config):
+    assert isinstance(config, HPOConfig), f'config should be {HPOConfig.__name__}'
+    assert config.searcher in SEARCHERS.keys(), \
+        f"currently only these search algorithms are supported: {list(SEARCHERS.keys())}"
+    if config.searcher == 'xgb':
+        return SEARCHERS[config.searcher](config.search_space,
+                                          higher_is_better=config.higher_is_better,
+                                          loss_type=config.loss_type,
+                                          min_train_samples=config.min_train_samples,
+                                          seed=config.seed)
+    else:
+        return SEARCHERS[config.searcher](config.search_space)
+
+
+def register_searcher(name):
+    """Class decorator to register a Searcher subclass to the registry.
+
+    Decorator function used before a Searcher subclass.
+    Make sure that the Searcher class decorated by this function can be registered in SEARCHERS.
+
+    Args:
+        cls (class): The subclass to register.
+        name: A string. Defines the searcher type.
+
+    Returns:
+        cls: The class being registered.
+    """
+    def register(searcher):
+        SEARCHERS[name] = searcher
+        return searcher
+    return register
+
+
+class Searcher(object):
+    """Base class for defining the common methods of different search algorithms.
+
+    Args:
+        search_space (dict): A dictionary for defining the search space.
+    """
+    def __init__(self, search_space):
+        assert isinstance(search_space, dict) and search_space, \
+            "Expect search_space to be a non-empty dict."
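+        # Keys are kept in sorted order so that parameter vectors produced by
+        # the searchers always line up with `search_space_keys`.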
+        self.search_space = search_space
+        self.search_space_keys = sorted(search_space.keys())
+        self.search_space_pool = self._create_search_space_pool()
+        self.best = None
+        for k in self.search_space_keys:
+            assert isinstance(self.search_space[k], (list, tuple, BaseSearchSpace)), \
+                "Value of key '{}' must be a list, tuple, " \
+                "ContinuousSearchSpace or DiscreteSearchSpace to specify choices".format(k)
+
+    def _create_search_space_pool(self):
+        """Build the search space pool."""
+        search_space_pool = []
+        for key in self.search_space_keys:
+            if isinstance(self.search_space[key], (list, tuple)):
+                space = DiscreteSearchSpace(value=self.search_space[key])
+            else:
+                space = self.search_space[key]
+            search_space_pool.append(space)
+        return search_space_pool
+
+    def suggest(self):
+        """Suggest the model hyperparameter."""
+        raise NotImplementedError(
+            'Depends on specific search algorithm.')  # pragma: no cover
+
+    def get_feedback(self, metric):
+        """Get metric feedback for the search algorithm."""
+        pass
+
+    def params_vec2params_dict(self, para_vec):
+        """Convert the parameters vector to a parameters dictionary.
+
+        Parameters vector and parameters dictionary are two representations of the same model hyperparameters.
+
+        Returns:
+            Parameters dictionary defining the model hyperparameters.
+        """
+        assert len(para_vec) == len(self.search_space_keys), \
+            "Length of para_vec and search_space_keys should be the same."
+        return {k: para_vec[i] for i, k in enumerate(self.search_space_keys)}
+
+
+@register_searcher("grid")
+class GridSearcher(Searcher):
+    """Grid search.
+
+    Search the whole search space exhaustively.
+
+    Args:
+        search_space (dict): A dictionary for defining the search space.
+    """
+    def __init__(self, search_space):
+        """Initialize the attributes."""
+        super().__init__(search_space)
+
+        for space in self.search_space_pool:
+            if space.type == 'continuous':
+                raise TypeError(
+                    "GridSearcher does not support continuous search spaces, please use another algorithm."
+                )
+
+        self.idx = [0] * len(self.search_space_pool)
+
+    def _add_idx(self, idx=0):
+        # Increment the index vector like an odometer: try to advance the last
+        # dimension first; when it is exhausted, carry into the earlier
+        # dimension and reset the faster-changing ones. Returns False once the
+        # whole space is exhausted.
+        def _add():
+            if self.idx[idx] + 1 >= self.search_space_pool[idx].total_num:
+                return False
+            else:
+                self.idx[idx] += 1
+                # Reset all deeper (faster-changing) dimensions on carry.
+                for i in range(idx + 1, len(self.idx)):
+                    self.idx[i] = 0
+                return True
+
+        if idx + 1 == len(self.idx):
+            return _add()
+        if self._add_idx(idx + 1):
+            return True
+        else:
+            return _add()
+
+    def suggest(self):
+        """Suggest the model hyperparameter.
+
+        Returns:
+            The model hyperparameter.
+        """
+        param = []
+        for i in range(len(self.idx)):
+            param.append(self.search_space_pool[i].get_value(self.idx[i]))
+        if not self._add_idx():
+            logger.warning('Ran out of the search space pool, restarting from the beginning...')
+            self.idx = [0] * len(self.search_space_pool)
+        return self.params_vec2params_dict(param)
+
+
+@register_searcher("random")
+class RandomSearcher(Searcher):
+    """Random search.
+
+    Search the whole search space randomly.
+
+    Args:
+        search_space (dict): A dictionary for defining the search space.
+    """
+    def __init__(self, search_space):
+        """Initialize the attributes."""
+        super().__init__(search_space)
+
+    def suggest(self):
+        """Suggest the model hyperparameter.
+
+        Returns:
+            The model hyperparameter.
+        """
+        param = [s.get_value() for s in self.search_space_pool]
+        return self.params_vec2params_dict(param)
+
+
+@register_searcher("bo")
+class BayesianOptimizationSearcher(Searcher):
+    """Bayesian Optimization.
+
+    Search the search space with Bayesian Optimization.
+
+    Args:
+        search_space (dict): A dictionary for defining the search space.
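+        seed (int): Random seed used to initialize the Bayesian optimizer. Default is 42.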
+ """ + def __init__(self, search_space, seed=42): + """Initialize the attributes.""" + super().__init__(search_space) + idx_search_space = {} + for key, space in zip(self.search_space_keys, self.search_space_pool): + if isinstance(space, ContinuousSearchSpace): + idx_search_space[key] = tuple(space.bound) + else: + idx_search_space[key] = (0, space.total_num - 1) + self.bo_agent = BayesianOptimization(idx_search_space, + random_seed=seed) + + def suggest(self): + """Suggest the model hyperparameter. + + Returns: + The model hyperparameter. + """ + param_indices = self.bo_agent.gen_next_params() + self.last_param_indices = param_indices + return self.params_vec2params_dict( + self.indices2params_vec(param_indices)) + + def get_feedback(self, metric): + """Get metric feedback and register this metric.""" + assert self.last_param_indices is not None, "Need run suggest first " + \ + "to get parameters and the input metric is corresponding to this parameters." + try: + self.bo_agent._space.register(self.last_param_indices, metric) + except KeyError: # pragma: no cover + logger.debug("Find registered params, skip it.") + pass + if self.best is None or self.best[1] < metric: + param = self.params_vec2params_dict( + self.indices2params_vec(self.last_param_indices)) + self.best = (param, metric) + self.last_param_indices = None + + def feedback(self, param, metric): + if self.best is None or self.best[1] < metric: + self.best = (param, metric) + self.bo_agent._space.register(param, metric) + + def indices2params_vec(self, indices): + """Convert indices to parameters vector.""" + res = [] + for key, ind in indices.items(): + # keep ind within the index range of self.search_space[key] + space = self.search_space_pool[self.search_space_keys.index(key)] + if isinstance(space, ContinuousSearchSpace): + res.append(ind) + else: + ind = int(min(max(round(ind), 0), space.total_num - 1)) + res.append(space.get_value(ind)) + return res + + +@register_searcher("xgb") +class XgbSearcher(Searcher): + """XGBoost searcher. + + Search the search space with XGBoost model. + + Args: + search_space (dict): A dictionary for defining the search space. + """ + def __init__(self, + search_space, + higher_is_better=True, + loss_type='reg', + min_train_samples=10, + seed=42): + """Initialize the attributes.""" + super().__init__(search_space) + + self.seed = seed + self.loss_type = loss_type + self.higher_is_better = higher_is_better + self.min_train_samples = min_train_samples + self.log = {} + + self.last_params = None + self._x = [] + self._y = [] + if loss_type == "reg": + self.model = xgb.XGBRegressor(max_depth=3, + n_estimators=100, + gamma=0.0001, + min_child_weight=1, + subsample=1.0, + eta=0.3, + reg_lambda=1.00, + reg_alpha=0, + objective='reg:squarederror') + elif loss_type == "rank": + self.model = xgb.XGBRanker(max_depth=3, + n_estimators=100, + gamma=0.0001, + min_child_weight=1, + subsample=1.0, + eta=0.3, + reg_lambda=1.00, + reg_alpha=0, + objective='rank:pairwise') + else: # pragma: no cover + raise RuntimeError( + "Invalid loss type: {}, only surport reg and rank".format( + loss_type)) + self.optimizer = SimulatedAnnealingOptimizer( + generate_func=self._generate_new_points, + T0=100, + Tf=0, + alpha=0.9, + higher_is_better=self.higher_is_better) + + def _generate_new_points(self, points): + new_points = [] + for _ in range(len(points)): + new_points.append([s.get_value() for s in self.search_space_pool]) + return new_points + + def suggest(self): + """Suggest the model hyperparameter. 
+
+        Returns:
+            The model hyperparameter.
+        """
+        if len(self._y) < self.min_train_samples:
+            params = [s.get_value() for s in self.search_space_pool]
+        else:
+            x_train, y_train = np.array(self._x), np.array(self._y)
+
+            self.model.fit(x_train, y_train)
+            params = self.optimizer.gen_next_params(self.model.predict,
+                                                    self._x)
+
+        self.last_params = params
+        return self.params_vec2params_dict(params)
+
+    def get_feedback(self, metric):
+        """Get metric feedback and register this metric."""
+        assert self.last_params is not None, "Need to run suggest() first; " + \
+            "the input metric corresponds to the parameters it returned."
+        # Track the best trial according to the search direction.
+        if self.best is None or \
+                (self.higher_is_better and metric > self.best[1]) or \
+                (not self.higher_is_better and metric < self.best[1]):
+            self.best = (self.params_vec2params_dict(self.last_params), metric)
+        self._x.append(self.last_params)
+        self._y.append(metric)
+        params_key = '_'.join([str(x) for x in self.last_params])
+        self.log[params_key] = metric
+        self.last_params = None
+
+    def feedback(self, param, metric):
+        param_list = []
+        for k in self.search_space_keys:
+            param_list.append(param[k])
+        if self.best is None or \
+                (self.higher_is_better and metric > self.best[1]) or \
+                (not self.higher_is_better and metric < self.best[1]):
+            self.best = (param, metric)
+        self._x.append(param_list)
+        self._y.append(metric)
+        # Key the log by the parameter values (in search-space-key order),
+        # matching the format used in get_feedback above.
+        params_key = '_'.join([str(x) for x in param_list])
+        self.log[params_key] = metric
diff --git a/neural_compressor/compression/hpo/search_space.py b/neural_compressor/compression/hpo/search_space.py
new file mode 100644
index 00000000000..bab25a886e5
--- /dev/null
+++ b/neural_compressor/compression/hpo/search_space.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+SEARCHSPACE = {}
+
+
+class SearchSpace:
+    """Entry point for building a search space.
+    Uses the factory pattern to return the concrete search space class.
+
+    Args:
+        bound: A tuple or list giving the lower and upper limits of the search space.
+        interval: Only for a discrete search space. The step between candidate values.
+        value: Only for a discrete search space. A list holding all candidate values.
+        type: 'discrete' or 'continuous'.
+
+    Example:
+        from neural_compressor.compression.hpo import SearchSpace
+        search_space = {
+            'learning_rate': SearchSpace((0.0001, 0.001)),
+            'num_train_epochs': SearchSpace(bound=(20, 100), interval=1),
+            'weight_decay': SearchSpace((0.0001, 0.001), type='continuous')
+        }
+    """
+    def __new__(
+            cls,
+            bound=None,
+            interval=None,
+            value=None,
+            type=None
+    ):
+        if type is None:
+            if interval is not None or value is not None:
+                type = "discrete"
+            else:
+                type = "continuous"
+        assert type in SEARCHSPACE.keys(), f"only support {list(SEARCHSPACE.keys())}"
+        return SEARCHSPACE[type](bound, interval, value, type)
+
+
+def register_searchspace(name):
+    """Class decorator to register a SearchSpace subclass to the registry.
+
+    Args:
+        cls (class): The subclass to register.
+        name: A string. Defines the search space type.
+
+    Returns:
+        cls: The class being registered.
+ """ + def register(search_space): + SEARCHSPACE[name] = search_space + return search_space + + return register + + +class BaseSearchSpace(object): + """Base class for Search Space.""" + def __init__( + self, + bound=None, + interval=None, + value=None, + type=None + ): + """Initialize.""" + if bound: + if not isinstance(bound, (list, tuple)): # pragma: no cover + raise TypeError("bound sould be list or tuple, not {}".format(type(bound))) + if len(bound) != 2: # pragma: no cover + raise ValueError("bound sould only contain two elements, [start, end)") + if bound[1] <= bound[0]: # pragma: no cover + raise ValueError("empty range for [{}, {})".format(bound[0], bound[1])) + assert value or bound, "must set value or bound to initialize the search space" + + self.bound = bound + self.interval = interval + self.value = value + self.type = type + if type == 'discrete': + if value: + self.total_num = len(value) + else: + self.total_num = int((bound[1] - bound[0]) / interval) + else: + self.total_num = float("inf") + + def get_value(self): + """get one value from the search space.""" + pass + + +@register_searchspace("discrete") +class DiscreteSearchSpace(BaseSearchSpace): + """Discrete Search Space.""" + def __init__(self, bound=None, interval=None, value=None, type=None): + if bound and interval is None: + if isinstance(bound[0], int) and isinstance(bound[1], int): + interval = 1 + else: + interval = 0.01 + super().__init__(bound=bound, + interval=interval, + value=value, + type='discrete') + + def get_random_value(self): + """Get a random value from search space.""" + idx = random.randint(0, self.total_num - 1) + return self.get_nth_value(idx) + + def get_nth_value(self, idx): + """Get the number n value from search space.""" + if self.bound: + return round(self.bound[0] + idx * self.interval, 10) + else: + return self.value[idx] + + def get_all(self): + """Get all values from search space.""" + return [self.get_nth_value(i) for i in range(self.total_num)] + + def get_value(self, idx=None): + """Get number n value from search space if idx is given. Otherwise, get a random value.""" + if idx is not None: + if not isinstance(idx, int): + raise TypeError("The type of idx should be int, not {}".format(type(idx))) + if idx < 0: + return self.get_all() + value = self.get_nth_value(idx) + else: + value = self.get_random_value() + return value + + def index(self, value): + """Return the index of the value.""" + if self.value: + return self.value.index(value) + else: + return int((value - self.bound[0]) / self.interval) + + +@register_searchspace("continuous") +class ContinuousSearchSpace(BaseSearchSpace): + """Continuous Search Space.""" + def __init__(self, bound, interval=None, value=None, type=None): + super().__init__(bound, interval, value, "continuous") + + def get_value(self): + """Get one value from the search space.""" + if self.bound[1] > 1: + int_num = random.randrange(int(self.bound[0]), int(self.bound[1]) + 1) + else: + int_num = 0 + while True: + value = random.random() * self.bound[1] + value = int_num + value + if value > self.bound[0] and value < self.bound[1]: + break + return value diff --git a/neural_compressor/compression/pruner/README.md b/neural_compressor/compression/pruner/README.md index 2f3a5c4478e..3561ea7d396 100644 --- a/neural_compressor/compression/pruner/README.md +++ b/neural_compressor/compression/pruner/README.md @@ -52,8 +52,9 @@ Pruning 4. [Sparse Model Deployment](#sparse-model-deployment) +5. [Pruning With HPO](#pruning-with-hyperparameter-optimization) -5. 
[Reference](#reference) +6. [Reference](#reference) ## Introduction @@ -106,7 +107,10 @@ Pruning patterns defines the rules of pruned weights' arrangements in space. Int - Multi-head Attention Pruning - Multi-head attention mechanism boosts transformer models' capability of contextual information analysis. However, different heads' contribution to the final output varies. In most situation, a number of heads can be removed without causing accuracy drop. Head pruning can be applied in a wide range of scenes including BERT, GPT as well as other large language models. **We have currently support multi-head attention pruning in both pruning and auto slim, which means pruning a set of head first and removing these sparse weights to obtain a lighter model.**. Please refer to [multi-head attention pruning and auto slim examples](https://github.com/intel/neural-compressor/blob/master/examples/pytorch/nlp/huggingface_models/question-answering/model_slim) + Multi-head attention mechanism boosts transformer models' capability of contextual information analysis. However, different heads' contribution to the final output varies. In most situation, a number of heads can be removed without causing accuracy drop. Head pruning can be applied in a wide range of scenes including BERT, GPT as well as other large language models. **We haven't support it in pruning, but we have provided experimental feature in Model Auto Slim**. Please refer to [multi-head attention auto slim examples](https://github.com/intel/neural-compressor/blob/master/examples/pytorch/nlp/huggingface_models/question-answering/model_slim) + + + ### Pruning Criteria @@ -156,7 +160,7 @@ Pruning type defines how the masks are generated and applied to a neural network   (a) refers to the traditional structured iterative pruning;
@@ -383,6 +387,10 @@ Please refer to [pruning examples](../../../examples/README.md#Pruning-1) for mo
 Particular hardware/software like [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers) is required to obtain inference-speed and footprint optimization for most sparse models. However, using [model slim](#click) for some special structures can obtain significant inference speed improvements and footprint reduction without post-pruning deployment. In other words, you can achieve model acceleration directly within your training framework (PyTorch, etc.)
 
+## Pruning with Hyperparameter Optimization
+Intel® Neural Compressor currently supports grid, random, Bayesian optimization, and XGBoost search algorithms for pruning with HPO.
+For more details, please refer to the [HPO document](../../neural_compressor/compression/hpo/README.md).
+
 ## Reference
 
 [1] Namhoon Lee, Thalaiyasingam Ajanthan, and Philip Torr. SNIP: Single-shot network pruning based on connection sensitivity. In International Conference on Learning Representations, 2019.
diff --git a/neural_compressor/config.py b/neural_compressor/config.py
index 40dcd40f4da..cc8932e782c 100644
--- a/neural_compressor/config.py
+++ b/neural_compressor/config.py
@@ -1494,6 +1494,32 @@ def weight_compression(self, weight_compression):
         self._weight_compression = weight_compression
 
 
+class HPOConfig:
+    """Config class for hyperparameter optimization.
+
+    Args:
+        search_space (dict): A dictionary for defining the search space.
+        searcher (str): The name of the search algorithm; currently supported: 'grid', 'random', 'bo' and 'xgb'.
+        higher_is_better (bool, optional): Whether a higher metric value is better.
+        loss_type (str, optional): Objective type for the XGBoost searcher, 'reg' or 'rank'.
+        min_train_samples (int, optional): The minimum number of samples needed before the search model is trained.
+        seed (int, optional): Random seed.
+
+    """
+    def __init__(self,
+                 search_space,
+                 searcher='xgb',
+                 higher_is_better=True,
+                 loss_type='reg',
+                 min_train_samples=10,
+                 seed=42):
+        """Init an HPOConfig object."""
+        self.search_space = search_space
+        self.searcher = searcher
+        self.higher_is_better = higher_is_better
+        self.loss_type = loss_type
+        self.min_train_samples = min_train_samples
+        self.seed = seed
+
+
 class KnowledgeDistillationLossConfig:
     """Config Class for Knowledge Distillation Loss.
diff --git a/test/hpo/test_hpo.py b/test/hpo/test_hpo.py new file mode 100644 index 00000000000..5e8d7fc4d70 --- /dev/null +++ b/test/hpo/test_hpo.py @@ -0,0 +1,78 @@ +import unittest +import numpy as np +import sys +sys.path.insert(0, './') +from neural_compressor.config import HPOConfig +from neural_compressor.compression.hpo import (GridSearcher, + DiscreteSearchSpace, + ContinuousSearchSpace, + SearchSpace, + prepare_hpo, + SimulatedAnnealingOptimizer) + + +class TestHPO(unittest.TestCase): + search_space = { + 'learning_rate': SearchSpace((0.0001, 0.001)), + 'num_train_epochs': SearchSpace(bound=(20, 100), interval=1), + 'weight_decay': SearchSpace((0.0001, 0.001)), + 'cooldown_epochs': SearchSpace(bound=(0, 10), interval=1), + 'sparsity_warm_epochs': SearchSpace(bound=(0, 5), interval=1), + 'per_device_train_batch_size': SearchSpace((5, 20), 1) + } + + def test_searcher(self): + hpo_config = HPOConfig({'num_train_epochs': self.search_space['num_train_epochs'], + 'cooldown_epochs': self.search_space['cooldown_epochs']}, searcher='grid') + searcher = GridSearcher({'num_train_epochs': self.search_space['num_train_epochs'], + 'cooldown_epochs': self.search_space['cooldown_epochs']}) + conf_searcher = prepare_hpo(hpo_config) + self.assertEqual(searcher.__class__, conf_searcher.__class__) + for _ in range(5): + self.assertEqual(searcher.suggest(), conf_searcher.suggest()) + hpo_config = HPOConfig(self.search_space, 'random') + searcher = prepare_hpo(hpo_config) + for _ in range(5): + searcher.suggest() + hpo_config = HPOConfig(self.search_space, 'bo') + searcher = prepare_hpo(hpo_config) + for _ in range(10): + searcher.suggest() + searcher.get_feedback(np.random.random()) + hpo_config = HPOConfig(self.search_space, 'xgb', higher_is_better=True, min_train_samples=3) + searcher = prepare_hpo(hpo_config) + for _ in range(5): + searcher.suggest() + searcher.get_feedback(np.random.random()) + for _ in range(5): + param = searcher.suggest() + searcher.feedback(param, np.random.random()) + + def test_search_space(self): + ds = DiscreteSearchSpace(bound=[0, 10]) + get_ds = SearchSpace(bound=[0, 10], interval=1) + self.assertEqual(ds.__class__, get_ds.__class__) + self.assertEqual(ds.index(1), ds.get_nth_value(1)) + ds = DiscreteSearchSpace(value=[1, 2, 3, 4]) + self.assertEqual(ds.get_all(), [1, 2, 3, 4]) + ds = DiscreteSearchSpace(bound=[0.01, 0.1]) + self.assertEqual(ds.interval, 0.01) + self.assertIn(ds.get_value(), ds.get_all()) + self.assertEqual(ds.get_value(2), ds.get_nth_value(2)) + + cs = ContinuousSearchSpace(bound=[0.01, 0.1]) + self.assertTrue(cs.get_value() >= 0.01) + self.assertTrue(cs.get_value() < 0.1) + + def test_sa(self): + def f(x): + return np.mean(np.log(x**2), axis=1) + points = np.random.randn(5, 6) + optimizer = SimulatedAnnealingOptimizer(T0=100, Tf=0, alpha=0.9, higher_is_better=True) + optimizer.gen_next_params(f, points) + optimizer = SimulatedAnnealingOptimizer(T0=1, Tf=0.01, alpha=None, higher_is_better=False) + optimizer.gen_next_params(f, points) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/requirements.txt b/test/requirements.txt index a8c3a2f39a2..6f7f6cfbf4f 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -18,3 +18,4 @@ onnxruntime-extensions; python_version < '3.10' dynast==1.3.0 intel-extension-for-pytorch tf2onnx +xgboost