Skip to content

Commit

Permalink
Merge pull request #4 from lp6m/dev
Browse files Browse the repository at this point in the history
update README.md and quantize script
  • Loading branch information
lp6m authored Aug 19, 2021
2 parents 46bf245 + fc0bb96 commit b6741a9
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ docker run -it --gpus all -v `pwd`:/workspace yolov5s_anrdoid bash
* Converted TfLite Model.

## Performance
### Latency (inference)
### Latency
These results are measured on `Xiaomi Mi11`.
Please refer [`benchmark/README.md`](https://github.com/lp6m/yolov5s_android/tree/master/benchmark) about the detail of benchmark command.
The latency does not contain the pre/post processing time and data transfer time.
Expand Down
2 changes: 1 addition & 1 deletion convert_model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ netron yolov5s.onnx

In this model, the output layer IDs are `Conv_245,Conv_325,Conv_405`.
**We convert the ONNX model without detect head layers.**
### Why we exclude detect head layers?https://github.com/onnx/onnx-tensorflow
### Why we exclude detect head layers?
NNAPI does not support some layers included in detect head layers.
For example, The number of dimension supported by [ANEURALNETWORKS_MUL](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ab34ca99890c827b536ce66256a803d7a) operator for multiply layer is up to 4.
The input of multiply layer in detect head layers has 5 dimension, so NNAPI delegate cannot load the model.
Expand Down
56 changes: 56 additions & 0 deletions convert_model/quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import argparse
import sys
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

def quantize_model(INPUT_SIZE, pb_path, output_path, calib_num, tfds_root, download_flag):
raw_test_data = tfds.load(name='coco/2017',
with_info=False,
split='validation',
data_dir=tfds_root,
download=download_flag)
input_shapes = [(3, INPUT_SIZE, INPUT_SIZE)]
def representative_dataset_gen():
for i, data in enumerate(raw_test_data.take(calib_num)):
print('calibrating...', i)
image = data['image'].numpy()
images = []
for shape in input_shapes:
data = tf.image.resize(image, (shape[1], shape[2]))
tmp_image = data / 255.
tmp_image = tmp_image[np.newaxis,:,:,:]
images.append(tmp_image)
yield images

input_arrays = ['inputs']
output_arrays = ['Identity', 'Identity_1', 'Identity_2']
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(pb_path, input_arrays, output_arrays)
converter.experimental_new_quantizer = False
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.allow_custom_ops = False
converter.inference_input_type = tf.uint8
# To commonalize postprocess, output_type is float32
converter.inference_output_type = tf.float32
converter.representative_dataset = representative_dataset_gen
tflite_model = converter.convert()
with open(output_path, 'wb') as w:
w.write(tflite_model)
print('Quantization Completed!', output_path)

if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input_size', type=int, default=640)
parser.add_argument('--pb_path', default="/workspace/yolov5/tflite/model_float32.pb")
parser.add_argument('--output_path', default='/workspace/yolov5/tflite/model_quantized.tflite')
parser.add_argument('--calib_num', type=int, default=100, help='number of images for calibration.')
parser.add_argument('--tfds_root', default='/workspace/TFDS/')
parser.add_argument('--download_tfds', action='store_true', help='download tfds. it takes a lot of time.')
args = parser.parse_args()
quantize_model(args.input_size, args.pb_path, args.output_path, args.calib_num, args.tfds_root, args.download_tfds)


41 changes: 0 additions & 41 deletions convert_model/replace.json

This file was deleted.

0 comments on commit b6741a9

Please sign in to comment.