diff --git a/examples/README.md b/examples/README.md
index d0c3686b5d9..0584d215a31 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -6,6 +6,8 @@ Intel® Neural Compressor validated examples with multiple compression technique
 # Quick Get Started Notebook Examples
 * [Quick Get Started Notebook of Intel® Neural Compressor for ONNXRuntime](/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb)
+* [Quick Get Started Notebook of Intel® Neural Compressor for TensorFlow](/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb)
+
 # Helloworld Examples
 * [tf_example1](/examples/helloworld/tf_example1): quantize with built-in dataloader and metric.
diff --git a/examples/notebook/tensorflow/resnet/README.md b/examples/notebook/tensorflow/resnet/README.md
new file mode 100644
index 00000000000..93938c41b60
--- /dev/null
+++ b/examples/notebook/tensorflow/resnet/README.md
@@ -0,0 +1,3 @@
+# Notebook of using Intel Neural Compressor to do INT8 quantization on ResNet
+
+Follow the steps in `resnet_quantization.ipynb` to quantize ResNet step by step.
\ No newline at end of file
diff --git a/examples/notebook/tensorflow/resnet/resnet_benchmark.py b/examples/notebook/tensorflow/resnet/resnet_benchmark.py
new file mode 100644
index 00000000000..4f6c1ac648b
--- /dev/null
+++ b/examples/notebook/tensorflow/resnet/resnet_benchmark.py
@@ -0,0 +1,79 @@
+import argparse
+import math
+
+import datasets
+import numpy as np
+from datasets import load_dataset
+
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--input_model", type=str, required=True)
+args = parser.parse_args()
+
+# Loading the dataset in streaming mode returns an IterableDataset,
+# so records are fetched lazily instead of downloading the full split.
+eval_dataset = load_dataset('imagenet-1k', split='validation', streaming=True, use_auth_token=True)
+
+MAX_SAMPLE_LENGTH = 1000
+def sample_data(dataset, max_sample_length):
+    '''Materialize the first `max_sample_length` records into an in-memory Dataset.'''
+    data = {"image": [], "label": []}
+    for i, record in enumerate(dataset):
+        if i >= max_sample_length:
+            break
+        data["image"].append(record['image'])
+        data["label"].append(record['label'])
+    return datasets.Dataset.from_dict(data)
+
+sub_eval_dataset = sample_data(eval_dataset, MAX_SAMPLE_LENGTH)
+
+from neural_compressor.data.transforms.imagenet_transform import TensorflowResizeCropImagenetTransform
+height = width = 224
+transform = TensorflowResizeCropImagenetTransform(height, width)
+
+
+class CustomDataloader:
+    def __init__(self, dataset, batch_size=1):
+        '''dataset is an iterable dataset that is loaded record by record at runtime.'''
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.length = math.ceil(len(self.dataset) / self.batch_size)
+
+    def __iter__(self):
+        batch_inputs = []
+        labels = []
+        for idx, record in enumerate(self.dataset):
+            # record e.g.: {'image': <PIL.JpegImagePlugin.JpegImageFile>, 'label': 91}
+            img = record['image']
+            label = record['label']
+            # skip records that are not 3-channel RGB images
+            if len(np.array(img).shape) != 3 or np.array(img).shape[-1] != 3:
+                continue
+            img_resized = transform((img, label))  # transform expects an (img, label) pair
+            batch_inputs.append(np.array(img_resized[0]))
+            labels.append(label)
+            if (idx + 1) % self.batch_size == 0:
+                yield np.array(batch_inputs), np.array(labels)  # (bs, 224, 224, 3), (bs,)
+                batch_inputs = []
+                labels = []
+
+    def __len__(self):
+        return self.length
+
+from neural_compressor.benchmark import fit
+from neural_compressor.config import BenchmarkConfig
+
+conf = BenchmarkConfig(iteration=100,
+                       cores_per_instance=4,
+                       num_of_instance=1)
+bench_dataloader = CustomDataloader(dataset=sub_eval_dataset, batch_size=1)
+
+fit(args.input_model, conf, b_dataloader=bench_dataloader)
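+# Example usage, with the models produced by the accompanying notebook
+# present in the current directory:
+#   python resnet_benchmark.py --input_model resnet50_fp32_pretrained_model.pb
+#   python resnet_benchmark.py --input_model resnet50_int8.pb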
diff --git a/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb b/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb
new file mode 100644
index 00000000000..8f1c36ef3d8
--- /dev/null
+++ b/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb
@@ -0,0 +1,301 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Intel® Neural Compressor Sample for TensorFlow"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Introduction\n",
+    "\n",
+    "This notebook demonstrates how to use Intel® Neural Compressor to perform INT8 quantization on ResNet."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prepare Environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!conda install python==3.10 -y\n",
+    "!pip install neural-compressor\n",
+    "!wget -nc https://storage.googleapis.com/intel-optimized-tensorflow/models/v1_6/resnet50_fp32_pretrained_model.pb\n",
+    "!pip install tensorflow\n",
+    "!pip install datasets\n",
+    "!pip install git+https://github.com/huggingface/huggingface_hub"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tensorflow as tf\n",
+    "import numpy as np\n",
+    "import datasets"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create Dataloader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# log in to Hugging Face to download the imagenet-1k dataset\n",
+    "# replace the placeholder token below with your own; create one at https://huggingface.co/settings/tokens\n",
+    "# !huggingface-cli login --token <your_token>\n",
+    "!huggingface-cli login --token hf_xxxxxxxxxxxxxxxxxxxxxx"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "# loading the dataset in streaming mode returns an IterableDataset,\n",
+    "# so records are fetched lazily instead of downloading the full split\n",
+    "calib_dataset = load_dataset('imagenet-1k', split='train', streaming=True, use_auth_token=True)\n",
+    "eval_dataset = load_dataset('imagenet-1k', split='validation', streaming=True, use_auth_token=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for this demo we only use a subset of the dataset: the first 1000 samples of each split\n",
+    "MAX_SAMPLE_LENGTH = 1000\n",
+    "def sample_data(dataset, max_sample_length):\n",
+    "    data = {\"image\": [], \"label\": []}\n",
+    "    for i, record in enumerate(dataset):\n",
+    "        if i >= max_sample_length:\n",
+    "            break\n",
+    "        data[\"image\"].append(record['image'])\n",
+    "        data[\"label\"].append(record['label'])\n",
+    "    return datasets.Dataset.from_dict(data)\n",
+    "\n",
+    "sub_calib_dataset = sample_data(calib_dataset, MAX_SAMPLE_LENGTH)\n",
+    "sub_eval_dataset = sample_data(eval_dataset, MAX_SAMPLE_LENGTH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from neural_compressor.data.transforms.imagenet_transform import TensorflowResizeCropImagenetTransform\n",
+    "import math\n",
+    "\n",
+    "height = width = 224\n",
+    "transform = TensorflowResizeCropImagenetTransform(height, width)\n",
+    "\n",
+    "class CustomDataloader:\n",
+    "    def __init__(self, dataset, batch_size=1):\n",
+    "        '''dataset is an iterable dataset that is loaded record by record at runtime.'''\n",
+    "        self.dataset = dataset\n",
+    "        self.batch_size = batch_size\n",
+    "        self.length = math.ceil(len(self.dataset) / self.batch_size)\n",
+    "\n",
+    "    def __iter__(self):\n",
+    "        batch_inputs = []\n",
+    "        labels = []\n",
+    "        for idx, record in enumerate(self.dataset):\n",
+    "            # record e.g.: {'image': <PIL.JpegImagePlugin.JpegImageFile>, 'label': 91}\n",
+    "            img = record['image']\n",
+    "            label = record['label']\n",
+    "            # skip records that are not 3-channel RGB images\n",
+    "            if len(np.array(img).shape) != 3 or np.array(img).shape[-1] != 3:\n",
+    "                continue\n",
+    "            img_resized = transform((img, label))  # transform expects an (img, label) pair\n",
+    "            batch_inputs.append(np.array(img_resized[0]))\n",
+    "            labels.append(label)\n",
+    "            if (idx + 1) % self.batch_size == 0:\n",
+    "                yield np.array(batch_inputs), np.array(labels)  # (bs, 224, 224, 3), (bs,)\n",
+    "                batch_inputs = []\n",
+    "                labels = []\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return self.length"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "calib_dataloader = CustomDataloader(dataset=sub_calib_dataset, batch_size=32)\n",
+    "eval_dataloader = CustomDataloader(dataset=sub_eval_dataset, batch_size=32)"
+   ]
+  },
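+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an optional sanity check, we can pull one batch from the evaluation dataloader and confirm the shapes match what the comments above promise. This is a minimal sketch using only the objects defined above; it is not required for quantization."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# fetch a single batch to verify the dataloader output\n",
+    "batch_inputs, batch_labels = next(iter(eval_dataloader))\n",
+    "print(batch_inputs.shape, batch_labels.shape)  # expected: (32, 224, 224, 3) (32,)"
+   ]
+  },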
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Quantization\n",
+    "\n",
+    "Now we move on to the core quantization logic. `quantization.fit` is the main entry point for converting the base model into a quantized one. We pass the prepared calibration dataloader and an accuracy evaluation function to `quantization.fit`. After conversion, we obtain the quantized INT8 model and save it locally."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tqdm import tqdm\n",
+    "from neural_compressor import quantization\n",
+    "from neural_compressor.config import PostTrainingQuantConfig\n",
+    "\n",
+    "conf = PostTrainingQuantConfig(calibration_sampling_size=[50, 100], excluded_precisions=['bf16'])\n",
+    "\n",
+    "def eval_func(model):\n",
+    "    from neural_compressor.model import Model\n",
+    "    model = Model(model)\n",
+    "    total_cnt = 0\n",
+    "    total_hit = 0\n",
+    "    for batch_inputs, labels in tqdm(eval_dataloader):\n",
+    "        feed_dict = dict(zip(model.input_tensor, [batch_inputs]))\n",
+    "        preds = model.sess.run(model.output_tensor, feed_dict)\n",
+    "        ans = np.argmax(preds[0], axis=-1)\n",
+    "        # the pretrained model predicts 1001 classes (index 0 is background),\n",
+    "        # so shift the imagenet-1k labels by 1 before comparing\n",
+    "        labels += 1\n",
+    "        total_cnt += len(labels)\n",
+    "        total_hit += np.sum(ans == labels)\n",
+    "    return total_hit / total_cnt\n",
+    "\n",
+    "q_model = quantization.fit(\"./resnet50_fp32_pretrained_model.pb\", conf=conf, calib_dataloader=calib_dataloader, eval_func=eval_func)\n",
+    "q_model.save(\"resnet50_int8.pb\")"
+   ]
+  },
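+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, we can re-check the accuracy of the saved INT8 model with the same `eval_func`. This is a minimal sketch that relies on `eval_func` wrapping its argument with `neural_compressor.model.Model`, which also accepts a path to a frozen pb file (the same way `quantization.fit` accepted the FP32 pb path above)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# optional: re-evaluate the saved INT8 model on the 1000-sample subset\n",
+    "int8_acc = eval_func(\"resnet50_int8.pb\")\n",
+    "print(f\"INT8 accuracy: {int8_acc:.4f}\")"
+   ]
+  },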
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "import time\n", + "from neural_compressor import quantization\n", + "from neural_compressor.config import PostTrainingQuantConfig\n", + "\n", + "conf = PostTrainingQuantConfig(calibration_sampling_size=[50, 100], excluded_precisions = ['bf16'])\n", + "\n", + "def eval_func(model):\n", + " from neural_compressor.model import Model\n", + " model = Model(model)\n", + " ans = []\n", + " total_cnt = 0\n", + " total_hit = 0\n", + " latency_list = []\n", + " for idx, (batch_inputs, labels) in enumerate(tqdm(eval_dataloader)):\n", + " feed_dict = dict(zip(model.input_tensor, [batch_inputs]))\n", + " start = time.time()\n", + " preds = model.sess.run(model.output_tensor, feed_dict)\n", + " end = time.time()\n", + " latency_list.append(end-start)\n", + " ans = np.argmax(preds[0], axis=-1)\n", + " labels += 1 # label shift\n", + " total_cnt += len(labels)\n", + " total_hit += np.sum(ans == labels)\n", + " acc = total_hit / total_cnt\n", + " latency = np.array(latency_list).mean() / eval_dataloader.batch_size\n", + " return acc\n", + "\n", + "q_model = quantization.fit(\"./resnet50_fp32_pretrained_model.pb\", conf=conf, calib_dataloader=calib_dataloader, eval_func=eval_func)\n", + "q_model.save(\"resnet50_int8.pb\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmark\n", + "\n", + "Now we can see that we have two models under the current directory: the original fp32 model `resnet50_fp32_pretrained_model.pb` and the quantized int8 model `resnet50_int8.pb`, and then we are going to do performance comparisons between them.\n", + "\n", + "\n", + "To avoid the conflicts of jupyter notebook kernel to our benchmark process. We create a `resnet_quantization.py` and run it directly to do the benchmarks." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### FP32 benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python resnet_benchmark.py --input_model resnet50_fp32_pretrained_model.pb 2>&1|tee fp32_benchmark.log" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### INT8 benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python resnet_benchmark.py --input_model resnet50_int8.pb 2>&1|tee int8_benchmark.log" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, you will get the performance in the logs like following:\n", + "\n", + "* fp32_benchmark.log\n", + "\n", + "```\n", + "2023-08-28 22:46:39 [INFO] ********************************************\n", + "2023-08-28 22:46:39 [INFO] |****Multiple Instance Benchmark Summary*****|\n", + "2023-08-28 22:46:39 [INFO] +---------------------------------+----------+\n", + "2023-08-28 22:46:39 [INFO] | Items | Result |\n", + "2023-08-28 22:46:39 [INFO] +---------------------------------+----------+\n", + "2023-08-28 22:46:39 [INFO] | Latency average [second/sample] | 0.027209 |\n", + "2023-08-28 22:46:39 [INFO] | Throughput sum [samples/second] | 36.753 |\n", + "2023-08-28 22:46:39 [INFO] +---------------------------------+----------+\n", + "```\n", + "\n", + "* int8_benchmark.log\n", + "\n", + "```\n", + "2023-08-28 22:48:35 [INFO] ********************************************\n", + "2023-08-28 22:48:35 [INFO] |****Multiple Instance Benchmark Summary*****|\n", + "2023-08-28 22:48:35 [INFO] +---------------------------------+----------+\n", + "2023-08-28 22:48:35 [INFO] | Items | Result |\n", + "2023-08-28 22:48:35 [INFO] +---------------------------------+----------+\n", + "2023-08-28 22:48:35 [INFO] | Latency average [second/sample] | 0.006855 |\n", + "2023-08-28 22:48:35 [INFO] | Throughput sum [samples/second] | 145.874 |\n", + "2023-08-28 22:48:35 [INFO] +---------------------------------+----------+\n", + "```\n", + "\n", + "As shown in the logs, the int8/fp32 performance gain is about 145.87/36.75 = 3.97x" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spycsh-neuralchat", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}