From ae6651a87450ae890752ffabb0d44b04633835bf Mon Sep 17 00:00:00 2001
From: zehao-intel
Date: Mon, 24 Jun 2024 14:37:13 +0800
Subject: [PATCH] Update Example for PyTorch 3x Mixed Precision

Signed-off-by: zehao-intel
---
 ...T_MixPrecision.md => PT_MixedPrecision.md} |  30 +-
 examples/.config/model_params_pytorch_3x.json |   7 +
 .../pytorch/cv/mixed_precision/README.md      |  47 +++
 .../pytorch/cv/mixed_precision/main.py        | 367 ++++++++++++++++++
 .../cv/mixed_precision/requirements.txt       |   4 +
 .../cv/mixed_precision/run_autotune.sh        |  45 +++
 .../cv/mixed_precision/run_benchmark.sh       |  86 ++++
 neural_compressor/common/utils/constants.py   |   2 +-
 .../__init__.py                               |   4 +-
 .../half_precision_convert.py                 |  12 +-
 .../module_wrappers.py                        |   0
 neural_compressor/torch/quantization/__init__.py |   6 +-
 .../torch/quantization/algorithm_entry.py     |  16 +-
 .../torch/quantization/config.py              |  40 +-
 14 files changed, 615 insertions(+), 51 deletions(-)
 rename docs/3x/{PT_MixPrecision.md => PT_MixedPrecision.md} (52%)
 create mode 100644 examples/3.x_api/pytorch/cv/mixed_precision/README.md
 create mode 100644 examples/3.x_api/pytorch/cv/mixed_precision/main.py
 create mode 100644 examples/3.x_api/pytorch/cv/mixed_precision/requirements.txt
 create mode 100644 examples/3.x_api/pytorch/cv/mixed_precision/run_autotune.sh
 create mode 100644 examples/3.x_api/pytorch/cv/mixed_precision/run_benchmark.sh
 rename neural_compressor/torch/algorithms/{mix_precision => mixed_precision}/__init__.py (74%)
 rename neural_compressor/torch/algorithms/{mix_precision => mixed_precision}/half_precision_convert.py (86%)
 rename neural_compressor/torch/algorithms/{mix_precision => mixed_precision}/module_wrappers.py (100%)

diff --git a/docs/3x/PT_MixPrecision.md b/docs/3x/PT_MixedPrecision.md
similarity index 52%
rename from docs/3x/PT_MixPrecision.md
rename to docs/3x/PT_MixedPrecision.md
index c1cd198049b..a8b62d866a4 100644
--- a/docs/3x/PT_MixPrecision.md
+++ b/docs/3x/PT_MixedPrecision.md
@@ -8,13 +8,17 @@ PyTorch Mixed Precision
 
 ## Introduction
 
-The recent growth of Deep Learning has driven the development of more complex models that require significantly more compute and memory capabilities. Several low precision numeric formats have been proposed to address the problem. Google's [bfloat16](https://cloud.google.com/tpu/docs/bfloat16) and the [FP16: IEEE](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) half-precision format are two of the most widely used sixteen bit formats. [Mixed precision](https://arxiv.org/abs/1710.03740) training and inference using low precision formats have been developed to reduce compute and bandwidth requirements.
+The recent growth of Deep Learning has driven the development of more complex models that require significantly more compute and memory capabilities. Several low-precision numeric formats have been proposed to address the problem.
+Google's [bfloat16](https://cloud.google.com/tpu/docs/bfloat16) and the IEEE [FP16](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) half-precision format are two of the most widely used sixteen-bit formats. [Mixed precision](https://arxiv.org/abs/1710.03740) training and inference using low-precision formats have been developed to reduce compute and bandwidth requirements.
 
-The 3rd Gen Intel® Xeon® Scalable processor (codenamed Cooper Lake), featuring Intel® Deep Learning Boost, is the first general-purpose x86 CPU to support the bfloat16 format. 
Specifically, three new bfloat16 instructions are added as a part of the AVX512_BF16 extension within Intel Deep Learning Boost: VCVTNE2PS2BF16, VCVTNEPS2BF16, and VDPBF16PS. The first two instructions allow converting to and from bfloat16 data type, while the last one performs a dot product of bfloat16 pairs. Further details can be found in the [hardware numerics document](https://www.intel.com/content/www/us/en/developer/articles/technical/intel-deep-learning-boost-new-instruction-bfloat16.html) published by Intel.
+The 3rd Gen Intel® Xeon® Scalable processor (codenamed Cooper Lake), featuring Intel® Deep Learning Boost, is the first general-purpose x86 CPU to support the bfloat16 format. Specifically, three new bfloat16 instructions are added as a part of the AVX512_BF16 extension within Intel Deep Learning Boost: VCVTNE2PS2BF16, VCVTNEPS2BF16, and VDPBF16PS. The first two instructions allow converting to and from bfloat16 data type, while the last one performs a dot product of bfloat16 pairs.
+Further details can be found in the [Hardware Numerics Document](https://www.intel.com/content/www/us/en/developer/articles/technical/intel-deep-learning-boost-new-instruction-bfloat16.html) published by Intel.
 
-The 4th Gen Intel® Xeon® Scalable processor supports FP16 instruction set architecture (ISA) for Intel®
-Advanced Vector Extensions 512 (Intel® AVX-512). The new ISA supports a wide range of general-purpose numeric
-operations for 16-bit half-precision IEEE-754 floating-point and complements the existing 32-bit and 64-bit floating-point instructions already available in the Intel Xeon processor based products. Further details can be found in the [hardware numerics document](https://www.intel.com/content/www/us/en/content-details/669773/intel-avx-512-fp16-instruction-set-for-intel-xeon-processor-based-products-technology-guide.html) published by Intel.
+The 4th Gen Intel® Xeon® Scalable processor supports FP16 instruction set architecture (ISA) for Intel® Advanced Vector Extensions 512 (Intel® AVX-512). The new ISA supports a wide range of general-purpose numeric operations for 16-bit half-precision IEEE-754 floating-point and complements the existing 32-bit and 64-bit floating-point instructions already available in the Intel Xeon processor based products.
+Further details can be found in the [Intel AVX512 FP16 Guide](https://www.intel.com/content/www/us/en/content-details/669773/intel-avx-512-fp16-instruction-set-for-intel-xeon-processor-based-products-technology-guide.html) published by Intel.
+
+The latest Intel Xeon processors deliver the flexibility of Intel Advanced Matrix Extensions (Intel AMX), an accelerator that improves the performance of deep learning (DL) training and inference, making it ideal for workloads like NLP, recommender systems, and image recognition. Developers can code AI functionality to take advantage of the Intel AMX instruction set, and they can code non-AI functionality to use the processor instruction set architecture (ISA). Intel has integrated the Intel® oneAPI Deep Neural Network Library (oneDNN), its oneAPI DL engine, into PyTorch.
+Further details can be found in the [Intel AMX Document](https://www.intel.com/content/www/us/en/content-details/785250/accelerate-artificial-intelligence-ai-workloads-with-intel-advanced-matrix-extensions-intel-amx.html) published by Intel.
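+
+Before expecting speedups from bf16 or fp16, it is worth confirming that the target machine actually exposes these instruction sets. The snippet below is a minimal sketch (assuming Linux, where `/proc/cpuinfo` lists CPU feature flags); the flag names `avx512_bf16`, `avx512_fp16`, `amx_bf16`, and `amx_tile` are the kernel-reported names for the features described above.
+
+```python
+# Minimal sketch: report which low-precision ISA extensions the CPU exposes (Linux only).
+def low_precision_isa_support():
+    flags = set()
+    with open("/proc/cpuinfo") as f:
+        for line in f:
+            if line.startswith("flags"):
+                flags.update(line.split(":", 1)[1].split())
+                break
+    return {isa: isa in flags for isa in ("avx512_bf16", "avx512_fp16", "amx_bf16", "amx_tile")}
+
+
+print(low_precision_isa_support())
+```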
Architecture
@@ -58,6 +62,9 @@ operations for 16-bit half-precision IEEE-754 floating-point and complements the
 - PyTorch
   1. Hardware: CPU supports `avx512_fp16` instruction set.
   2. Software: torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).
+> Note: To run FP16 on Intel AMX, please set the environment variable `ONEDNN_MAX_CPU_ISA`:
+> ```export ONEDNN_MAX_CPU_ISA=AVX512_CORE_AMX_FP16```
+
 
 ### Accuracy-driven mixed precision
 
@@ -68,36 +75,37 @@ To be noticed, IPEX backend doesn't support accuracy-driven mixed precision.
 
 ## Get Started with autotune API
 
-To get a bf16/fp16 model, users can use the `autotune` interface with `MixPrecisionConfig` as follows.
+To get a bf16/fp16 model, users can use the `autotune` interface with `MixedPrecisionConfig` as follows.
 
 - BF16:
 
 ```python
-from neural_compressor.torch.quantization import MixPrecisionConfig, TuningConfig, autotune
+from neural_compressor.torch.quantization import MixedPrecisionConfig, TuningConfig, autotune
 
 
 def eval_acc_fn(model):
     ......
     return acc
 
 
 # modules might fall back to fp32 to get better accuracy
-custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["bf16", "fp32"])], max_trials=3)
+custom_tune_config = TuningConfig(config_set=[MixedPrecisionConfig(dtype=["bf16", "fp32"])], max_trials=3)
 best_model = autotune(model=build_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
 ```
 
 - FP16:
 
 ```python
-from neural_compressor.torch.quantization import MixPrecisionConfig, TuningConfig, autotune
+from neural_compressor.torch.quantization import MixedPrecisionConfig, TuningConfig, autotune
 
 
 def eval_acc_fn(model):
     ......
     return acc
 
 
 # modules might fall back to fp32 to get better accuracy
-custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["fp16", "fp32"])], max_trials=3)
+custom_tune_config = TuningConfig(config_set=[MixedPrecisionConfig(dtype=["fp16", "fp32"])], max_trials=3)
 best_model = autotune(model=build_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
 ```
 
 ## Examples
 
-Example will be added later.
+Users can also refer to the [examples](https://github.com/intel/neural-compressor/tree/master/examples/3.x_api/pytorch/cv/mixed_precision) for how to convert a model to mixed precision.
diff --git a/examples/.config/model_params_pytorch_3x.json b/examples/.config/model_params_pytorch_3x.json
index bbbab60bdbc..15cd28907fe 100644
--- a/examples/.config/model_params_pytorch_3x.json
+++ b/examples/.config/model_params_pytorch_3x.json
@@ -146,6 +146,13 @@
       "input_model": "",
       "main_script": "run_clm_no_trainer.py",
       "batch_size": 1
+    },
+    "resnet18_mixed_precision": {
+      "model_src_dir": "cv/mixed_precision",
+      "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
+      "input_model": "resnet18",
+      "main_script": "main.py",
+      "batch_size": 100
     }
   }
 }
diff --git a/examples/3.x_api/pytorch/cv/mixed_precision/README.md b/examples/3.x_api/pytorch/cv/mixed_precision/README.md
new file mode 100644
index 00000000000..597c9e2a3fb
--- /dev/null
+++ b/examples/3.x_api/pytorch/cv/mixed_precision/README.md
@@ -0,0 +1,47 @@
+Step-by-Step
+============
+
+This document provides step-by-step instructions for reproducing PyTorch ResNet18 mixed precision results with Intel® Neural Compressor.
+
+# Prerequisite
+
+### 1. Environment
+
+PyTorch 1.8 or a higher version is needed.
+
+```Shell
+cd examples/3.x_api/pytorch/cv/mixed_precision
+pip install -r requirements.txt
+```
+> Note: Validated PyTorch [Version](/docs/source/installation_guide.md#validated-software-environment).
+
+### 2. Prepare Dataset
+
+Download the [ImageNet](http://www.image-net.org/) raw images to a directory such as /path/to/imagenet. The directory should contain the following subfolders:
+
+```bash
+ls /path/to/imagenet
+train  val
+```
+
+# Run
+
+> Note: Any torchvision model name can be passed as long as it is included in `torchvision.models`; below are some examples.
+
+## MixedPrecision
+```Shell
+bash run_autotune.sh --input_model=resnet18 --dataset_location=/path/to/imagenet
+```
+
+## Benchmark
+```Shell
+# run optimized performance
+bash run_benchmark.sh --input_model=resnet18 --dataset_location=/path/to/imagenet --mode=performance --batch_size=100 --optimized=true --iters=500
+# run optimized accuracy
+bash run_benchmark.sh --input_model=resnet18 --dataset_location=/path/to/imagenet --mode=accuracy --batch_size=1 --optimized=true
+```
+
diff --git a/examples/3.x_api/pytorch/cv/mixed_precision/main.py b/examples/3.x_api/pytorch/cv/mixed_precision/main.py
new file mode 100644
index 00000000000..8ef798f9ac3
--- /dev/null
+++ b/examples/3.x_api/pytorch/cv/mixed_precision/main.py
@@ -0,0 +1,367 @@
+import argparse
+import os
+import random
+import shutil
+import time
+import warnings
+import sys
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import torchvision.models as models
+
+model_names = models.list_models(module=models)
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
+                    choices=model_names,
+                    help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet18)')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)',
+                    dest='weight_decay')
+parser.add_argument('-p', '--print-freq', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
+parser.add_argument('-t', '--tune', 
dest='tune', action='store_true', + help='tune best optimized model') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--ppn', default=1, type=int, + help='number of processes on each node of distributed training') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') +parser.add_argument('-i', "--iter", default=0, type=int, + help='For accuracy measurement only.') +parser.add_argument('-w', "--warmup_iter", default=5, type=int, + help='For benchmark measurement only.') +parser.add_argument('--performance', dest='performance', action='store_true', + help='run benchmark') +parser.add_argument('-r', "--accuracy", dest='accuracy', action='store_true', + help='For accuracy measurement only.') +parser.add_argument("--tuned_checkpoint", default='./saved_results', type=str, metavar='PATH', + help='path to checkpoint tuned by Neural Compressor (default: ./)') +parser.add_argument('--optimized', dest='optimized', action='store_true', + help='run benchmark') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if 'mobilenet_v2' in args.arch: + import torchvision.models.quantization as models + else: + import torchvision.models as models + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss() + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + 
transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + def eval_func(model): + accu = validate(val_loader, model, criterion, args) + return float(accu) + + if args.tune: + from neural_compressor.torch.quantization import MixedPrecisionConfig, TuningConfig, autotune + custom_tune_config = TuningConfig(config_set=[MixedPrecisionConfig(dtype=["fp16", "fp32"])]) + best_model = autotune(model=model, tune_config=custom_tune_config, eval_fn=eval_func) + torch.save(best_model, args.tuned_checkpoint) + return + + if args.performance or args.accuracy: + model.eval() + if args.optimized: + new_model = torch.load(args.tuned_checkpoint) + else: + new_model = model + if args.performance or args.accuracy: + validate(val_loader, new_model, criterion, args) + return + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, + top5, prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + input = input.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1[0], input.size(0)) + top5.update(acc5[0], input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.print(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + latency_list = [] + for i, (input, target) in enumerate(val_loader): + if i >= args.warmup_iter: + start = time.time() + if args.gpu is not None: + input = input.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + perf_start = time.time() + output = model(input) + perf_end = time.time() + latency_list.append(perf_end-perf_start) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, 
target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1[0], input.size(0)) + top5.update(acc5[0], input.size(0)) + + # measure elapsed time + if i >= args.warmup_iter: + batch_time.update(time.time() - start) + + if i % args.print_freq == 0: + progress.print(i) + + if args.iter > 0 and i >= (args.warmup_iter + args.iter - 1): + break + + if args.accuracy: + print('Batch size = %d' % args.batch_size) + print('Accuracy: {top1:.5f} Accuracy@5 {top5:.5f}' + .format(top1=(top1.avg / 100), top5=(top5.avg / 100))) + if args.performance: + latency = np.array(latency_list[args.warmup_iter:]).mean() / args.batch_size + print("Batch size = {}".format(args.batch_size)) + print("Latency: {:.3f} ms".format(latency * 1000)) + print("Throughput: {:.3f} images/sec".format(1. / latency)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, *meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def print(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/examples/3.x_api/pytorch/cv/mixed_precision/requirements.txt b/examples/3.x_api/pytorch/cv/mixed_precision/requirements.txt new file mode 100644 index 00000000000..46233c08f4a --- /dev/null +++ b/examples/3.x_api/pytorch/cv/mixed_precision/requirements.txt @@ -0,0 +1,4 @@ +neural-compressor +torch>=1.9.0 +torchvision>=0.10.0 +accelerate diff --git a/examples/3.x_api/pytorch/cv/mixed_precision/run_autotune.sh b/examples/3.x_api/pytorch/cv/mixed_precision/run_autotune.sh new file mode 100644 index 00000000000..770671db180 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/mixed_precision/run_autotune.sh @@ -0,0 +1,45 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_benchmark + +} + +# init 
params +function init_params { + iters=100 + tuned_checkpoint=saved_results + batch_size=30 + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + ;; + esac + done + +} + + +# run_benchmark +function run_benchmark { + extra_cmd="${dataset_location}" + python main.py \ + -a ${input_model}\ + -t\ + --pretrained\ + ${extra_cmd} +} + +main "$@" diff --git a/examples/3.x_api/pytorch/cv/mixed_precision/run_benchmark.sh b/examples/3.x_api/pytorch/cv/mixed_precision/run_benchmark.sh new file mode 100644 index 00000000000..e3e8fbd1f00 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/mixed_precision/run_benchmark.sh @@ -0,0 +1,86 @@ +#!/bin/bash +set -x + +function main { + + export ONEDNN_MAX_CPU_ISA=AVX512_CORE_AMX_FP16 + init_params "$@" + run_benchmark + +} + +# init params +function init_params { + iters=100 + tuned_checkpoint=saved_results + batch_size=30 + for var in "$@" + do + case $var in + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --mode=*) + mode=$(echo $var |cut -f2 -d=) + ;; + --batch_size=*) + batch_size=$(echo $var |cut -f2 -d=) + ;; + --iters=*) + iters=$(echo ${var} |cut -f2 -d=) + ;; + --optimized=*) + optimized=$(echo ${var} |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + ;; + esac + done + +} + + +# run_benchmark +function run_benchmark { + if [[ ${mode} == "accuracy" ]]; then + mode_cmd=" --accuracy" + elif [[ ${mode} == "performance" ]]; then + mode_cmd=" --iter ${iters} --performance " + else + echo "Error: No such mode: ${mode}" + exit 1 + fi + + if [[ ${optimized} == "true" ]]; then + extra_cmd="--optimized ${dataset_location}" + else + extra_cmd="${dataset_location}" + fi + if [[ ${mode} == "accuracy" ]]; then + python main.py \ + --pretrained \ + --tuned_checkpoint ${tuned_checkpoint} \ + -b ${batch_size} \ + -a ${input_model} \ + ${mode_cmd} \ + ${extra_cmd} + elif [[ ${mode} == "performance" ]]; then + incbench --num_c 4 --num_i 7 python main.py \ + --pretrained \ + --tuned_checkpoint ${tuned_checkpoint} \ + -b ${batch_size} \ + -a ${input_model} \ + ${mode_cmd} \ + ${extra_cmd} + else + echo "Error: No such mode: ${mode}" + exit 1 + fi +} + +main "$@" diff --git a/neural_compressor/common/utils/constants.py b/neural_compressor/common/utils/constants.py index 629a3f5743e..86a4c342f58 100644 --- a/neural_compressor/common/utils/constants.py +++ b/neural_compressor/common/utils/constants.py @@ -37,7 +37,7 @@ AUTOROUND = "autoround" FP8_QUANT = "fp8_quant" MX_QUANT = "mx_quant" -MIX_PRECISION = "mix_precision" +MIXED_PRECISION = "mixed_precision" # options import datetime diff --git a/neural_compressor/torch/algorithms/mix_precision/__init__.py b/neural_compressor/torch/algorithms/mixed_precision/__init__.py similarity index 74% rename from neural_compressor/torch/algorithms/mix_precision/__init__.py rename to neural_compressor/torch/algorithms/mixed_precision/__init__.py index 084e1c44e0f..cdd28b3260e 100644 --- a/neural_compressor/torch/algorithms/mix_precision/__init__.py +++ b/neural_compressor/torch/algorithms/mixed_precision/__init__.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from neural_compressor.torch.algorithms.mix_precision.half_precision_convert import HalfPrecisionConverter -from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper +from neural_compressor.torch.algorithms.mixed_precision.half_precision_convert import HalfPrecisionConverter +from neural_compressor.torch.algorithms.mixed_precision.module_wrappers import HalfPrecisionModuleWrapper diff --git a/neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py b/neural_compressor/torch/algorithms/mixed_precision/half_precision_convert.py similarity index 86% rename from neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py rename to neural_compressor/torch/algorithms/mixed_precision/half_precision_convert.py index 951eb4cb0b4..4a3c33d40f4 100644 --- a/neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py +++ b/neural_compressor/torch/algorithms/mixed_precision/half_precision_convert.py @@ -21,7 +21,7 @@ import torch from neural_compressor.common import logger -from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper +from neural_compressor.torch.algorithms.mixed_precision.module_wrappers import HalfPrecisionModuleWrapper from neural_compressor.torch.utils import get_accelerator @@ -37,7 +37,7 @@ def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs): """Initialize the Half-precision Converter with config. Args: - configs_mapping (Dict): config class for mix-precision. + configs_mapping (Dict): config class for mixed-precision. """ self.configs_mapping = configs_mapping self.device = get_accelerator().current_device_name() @@ -49,7 +49,7 @@ def convert(self, model: torch.nn.Module): model (torch.nn.Module): the input model. Returns: - mix_precision_model (torch.nn.Module): model with mix-precision. + mixed_precision_model (torch.nn.Module): model with mixed-precision. """ if len(self.configs_mapping) > 0: logger.info("Convert operators to half-precision") @@ -59,10 +59,10 @@ def convert(self, model: torch.nn.Module): elif next(model.parameters()).is_cpu: self.device = "cpu" - mix_precision_model = self._wrap_half_precision_model(model) - mix_precision_model.to(self.device) + mixed_precision_model = self._wrap_half_precision_model(model) + mixed_precision_model.to(self.device) - return mix_precision_model + return mixed_precision_model def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""): """Wrap and replace half-precision target modules. 
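The `module_wrappers.py` file in the next diff is renamed without content changes, so its implementation is not shown in this patch. For orientation, a half-precision module wrapper generally has the shape sketched below: keep the wrapped module's weights in the target half dtype, cast inputs down on entry, and cast outputs back to fp32 so downstream fp32 operators are unaffected. This is an illustrative sketch only, not the actual `HalfPrecisionModuleWrapper` implementation.

```python
import torch


class HalfPrecisionWrapperSketch(torch.nn.Module):
    """Illustrative sketch of a half-precision module wrapper."""

    def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.module = module.to(dtype)  # keep weights in the half-precision dtype
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in half precision, then cast back so fp32 neighbors still compose.
        return self.module(x.to(self.dtype)).float()
```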
diff --git a/neural_compressor/torch/algorithms/mix_precision/module_wrappers.py b/neural_compressor/torch/algorithms/mixed_precision/module_wrappers.py similarity index 100% rename from neural_compressor/torch/algorithms/mix_precision/module_wrappers.py rename to neural_compressor/torch/algorithms/mixed_precision/module_wrappers.py diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index 3bc12580848..4e70d82843d 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -36,9 +36,9 @@ get_default_fp8_config_set, MXQuantConfig, get_default_mx_config, - MixPrecisionConfig, - get_default_mix_precision_config, - get_default_mix_precision_config_set, + MixedPrecisionConfig, + get_default_mixed_precision_config, + get_default_mixed_precision_config_set, get_woq_tuning_config, DynamicQuantConfig, get_default_dynamic_config, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 733e4409b91..21e0bd65a4d 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -24,7 +24,7 @@ FP8_QUANT, GPTQ, HQQ, - MIX_PRECISION, + MIXED_PRECISION, MX_QUANT, RTN, SMOOTH_QUANT, @@ -38,7 +38,7 @@ FP8Config, GPTQConfig, HQQConfig, - MixPrecisionConfig, + MixedPrecisionConfig, MXQuantConfig, RTNConfig, SmoothQuantConfig, @@ -567,14 +567,14 @@ def mx_quant_entry( ###################### Mixed Precision Algo Entry ################################## -@register_algo(MIX_PRECISION) -def mix_precision_entry( - model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs +@register_algo(MIXED_PRECISION) +def mixed_precision_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixedPrecisionConfig], *args, **kwargs ) -> torch.nn.Module: # only support fp16 and bf16 now, more types might be added later - from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter + from neural_compressor.torch.algorithms.mixed_precision import HalfPrecisionConverter half_precision_converter = HalfPrecisionConverter(configs_mapping, *args, **kwargs) - mix_precision_model = half_precision_converter.convert(model) + mixed_precision_model = half_precision_converter.convert(model) - return mix_precision_model + return mixed_precision_model diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 27a056d3284..5ae6c9a5789 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -36,7 +36,7 @@ FP8_QUANT, GPTQ, HQQ, - MIX_PRECISION, + MIXED_PRECISION, MX_QUANT, OP_NAME_OR_MODULE_TYPE, RTN, @@ -1321,12 +1321,12 @@ def get_default_fp8_config_set() -> FP8Config: return FP8Config.get_config_set_for_tuning() -######################## MixPrecision Config ############################### -@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION) -class MixPrecisionConfig(BaseConfig): - """Config class for mix-precision.""" +######################## MixedPrecision Config ############################### +@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIXED_PRECISION) +class MixedPrecisionConfig(BaseConfig): + """Config class for mixed-precision.""" - name = MIX_PRECISION + name = MIXED_PRECISION supported_configs: List[OperatorConfig] = [] params_list = [ "dtype", @@ -1343,7 
+1343,7 @@ def __init__( dtype: Union[str, List[str]] = "fp16", white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): - """Init MixPrecision config. + """Init MixedPrecision config. Args: """ @@ -1354,16 +1354,16 @@ def __init__( @classmethod def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs = [] - mix_precision_config = MixPrecisionConfig( + mixed_precision_config = MixedPrecisionConfig( dtype=["fp16", "bf16", "fp32"], ) operators = cls.supported_half_precision_ops - supported_configs.append(OperatorConfig(config=mix_precision_config, operators=operators)) + supported_configs.append(OperatorConfig(config=mixed_precision_config, operators=operators)) cls.supported_configs = supported_configs @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: - white_list = tuple(MixPrecisionConfig.supported_half_precision_ops) + white_list = tuple(MixedPrecisionConfig.supported_half_precision_ops) filter_result = [] for op_name, module in model.named_modules(): if isinstance(module, white_list): @@ -1373,27 +1373,27 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: return filter_result @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "MixPrecisionConfig", List["MixPrecisionConfig"]]: + def get_config_set_for_tuning(cls) -> Union[None, "MixedPrecisionConfig", List["MixedPrecisionConfig"]]: # TODO fwk owner needs to update it. - return MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"]) + return MixedPrecisionConfig(dtype=["fp16", "bf16", "fp32"]) -def get_default_mix_precision_config() -> MixPrecisionConfig: - """Generate the default mix-precision config. +def get_default_mixed_precision_config() -> MixedPrecisionConfig: + """Generate the default mixed-precision config. Returns: - the default mix-precision config. + the default mixed-precision config. """ - return MixPrecisionConfig() + return MixedPrecisionConfig() -def get_default_mix_precision_config_set() -> MixPrecisionConfig: - """Generate the default mix-precision config set. +def get_default_mixed_precision_config_set() -> MixedPrecisionConfig: + """Generate the default mixed-precision config set. Returns: - the default mix-precision config. + the default mixed-precision config. """ - return MixPrecisionConfig.get_config_set_for_tuning() + return MixedPrecisionConfig.get_config_set_for_tuning() ##################### Algo Configs End ###################################
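
To tie the pieces together, the renamed API is exercised end to end as in the short sketch below, condensed from the example's `main.py` tuning path. The accuracy function is a stand-in for its `validate()`-based top-1 evaluation, and, as `run_benchmark.sh` does, `ONEDNN_MAX_CPU_ISA=AVX512_CORE_AMX_FP16` should be set in the environment when targeting FP16 on Intel AMX.

```python
import torch
import torchvision.models as models

from neural_compressor.torch.quantization import MixedPrecisionConfig, TuningConfig, autotune

model = models.resnet18(pretrained=True).eval()


def eval_acc_fn(model):
    # Stand-in for the example's validate()-based top-1 accuracy measurement.
    acc = ...
    return acc


# Modules may fall back to fp32 where needed to preserve accuracy.
custom_tune_config = TuningConfig(config_set=[MixedPrecisionConfig(dtype=["fp16", "fp32"])])
best_model = autotune(model=model, tune_config=custom_tune_config, eval_fn=eval_acc_fn)
torch.save(best_model, "./saved_results")
```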