-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
unify onnx examples prepare model scripts (#1187)
Signed-off-by: Sun, Xuehao <[email protected]>
- Loading branch information
Showing
143 changed files
with
4,156 additions
and
1,540 deletions.
There are no files selected for viewing
19 changes: 5 additions & 14 deletions
19
...s/onnxrt/body_analysis/onnx_model_zoo/arcface/quantization/ptq_static/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
...ples/onnxrt/body_analysis/onnx_model_zoo/arcface/quantization/ptq_static/prepare_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import argparse | ||
import os | ||
import sys | ||
from urllib import request | ||
|
||
import onnx | ||
from onnx import version_converter | ||
|
||
MODEL_URL = "https://github.com/onnx/models/raw/main/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx" | ||
MAX_TIMES_RETRY_DOWNLOAD = 5 | ||
|
||
|
||
def parse_arguments(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input_model", type=str, required=False, default='arcfaceresnet100-8.onnx') | ||
parser.add_argument("--output_model", type=str, required=True) | ||
return parser.parse_args() | ||
|
||
|
||
def progressbar(cur, total=100): | ||
percent = '{:.2%}'.format(cur / total) | ||
sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent)) | ||
sys.stdout.flush() | ||
|
||
|
||
def schedule(blocknum, blocksize, totalsize): | ||
if totalsize == 0: | ||
percent = 0 | ||
else: | ||
percent = min(1.0, blocknum * blocksize / totalsize) * 100 | ||
progressbar(percent) | ||
|
||
|
||
def download_model(url, model_name, retry_times=5): | ||
if os.path.isfile(model_name): | ||
print(f"{model_name} exists, skip download") | ||
return True | ||
|
||
print("download model...") | ||
retries = 0 | ||
while retries < retry_times: | ||
try: | ||
request.urlretrieve(url, model_name, schedule) | ||
break | ||
except KeyboardInterrupt: | ||
return False | ||
except: | ||
retries += 1 | ||
print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}") | ||
return retries < retry_times | ||
|
||
|
||
def export_model(input_model, output_model): | ||
# Convert opset version to 14 for more quantization capability. | ||
print("\nexport model...") | ||
model = onnx.load(input_model) | ||
model = version_converter.convert_version(model, 14) | ||
onnx.save_model(model, output_model) | ||
assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!" | ||
|
||
|
||
def prepare_model(input_model, output_model): | ||
# Download model from [ONNX Model Zoo](https://github.com/onnx/models). | ||
is_download_successful = download_model(MODEL_URL, input_model, MAX_TIMES_RETRY_DOWNLOAD) | ||
if is_download_successful: | ||
export_model(input_model, output_model) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_arguments() | ||
prepare_model(args.input_model, args.output_model) |
19 changes: 5 additions & 14 deletions
19
.../body_analysis/onnx_model_zoo/emotion_ferplus/quantization/ptq_static/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
...xrt/body_analysis/onnx_model_zoo/emotion_ferplus/quantization/ptq_static/prepare_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import argparse | ||
import os | ||
import sys | ||
from urllib import request | ||
|
||
import onnx | ||
from onnx import version_converter | ||
|
||
MODEL_URL = "https://github.com/onnx/models/raw/main/vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.onnx" | ||
MAX_TIMES_RETRY_DOWNLOAD = 5 | ||
|
||
|
||
def parse_arguments(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input_model", type=str, required=False, default='emotion-ferplus-8.onnx') | ||
parser.add_argument("--output_model", type=str, required=True) | ||
return parser.parse_args() | ||
|
||
|
||
def progressbar(cur, total=100): | ||
percent = '{:.2%}'.format(cur / total) | ||
sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent)) | ||
sys.stdout.flush() | ||
|
||
|
||
def schedule(blocknum, blocksize, totalsize): | ||
if totalsize == 0: | ||
percent = 0 | ||
else: | ||
percent = min(1.0, blocknum * blocksize / totalsize) * 100 | ||
progressbar(percent) | ||
|
||
|
||
def download_model(url, model_name, retry_times=5): | ||
if os.path.isfile(model_name): | ||
print(f"{model_name} exists, skip download") | ||
return True | ||
|
||
print("download model...") | ||
retries = 0 | ||
while retries < retry_times: | ||
try: | ||
request.urlretrieve(url, model_name, schedule) | ||
break | ||
except KeyboardInterrupt: | ||
return False | ||
except: | ||
retries += 1 | ||
print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}") | ||
return retries < retry_times | ||
|
||
|
||
def export_model(input_model, output_model): | ||
# Convert opset version to 14 for more quantization capability. | ||
print("\nexport model...") | ||
model = onnx.load(input_model) | ||
model = version_converter.convert_version(model, 14) | ||
onnx.save_model(model, output_model) | ||
assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!" | ||
|
||
|
||
def prepare_model(input_model, output_model): | ||
# Download model from [ONNX Model Zoo](https://github.com/onnx/models). | ||
is_download_successful = download_model(MODEL_URL, input_model, MAX_TIMES_RETRY_DOWNLOAD) | ||
if is_download_successful: | ||
export_model(input_model, output_model) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_arguments() | ||
prepare_model(args.input_model, args.output_model) |
19 changes: 5 additions & 14 deletions
19
...onnxrt/body_analysis/onnx_model_zoo/ultraface/quantization/ptq_static/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
...es/onnxrt/body_analysis/onnx_model_zoo/ultraface/quantization/ptq_static/prepare_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import argparse | ||
import os | ||
import sys | ||
from urllib import request | ||
|
||
import onnx | ||
from onnx import version_converter | ||
|
||
MODEL_URL = "https://github.com/onnx/models/raw/main/vision/body_analysis/ultraface/models/version-RFB-320.onnx" | ||
MAX_TIMES_RETRY_DOWNLOAD = 5 | ||
|
||
|
||
def parse_arguments(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input_model", type=str, required=False, default='version-RFB-640.onnx') | ||
parser.add_argument("--output_model", type=str, required=True) | ||
return parser.parse_args() | ||
|
||
|
||
def progressbar(cur, total=100): | ||
percent = '{:.2%}'.format(cur / total) | ||
sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent)) | ||
sys.stdout.flush() | ||
|
||
|
||
def schedule(blocknum, blocksize, totalsize): | ||
if totalsize == 0: | ||
percent = 0 | ||
else: | ||
percent = min(1.0, blocknum * blocksize / totalsize) * 100 | ||
progressbar(percent) | ||
|
||
|
||
def download_model(url, model_name, retry_times=5): | ||
if os.path.isfile(model_name): | ||
print(f"{model_name} exists, skip download") | ||
return True | ||
|
||
print("download model...") | ||
retries = 0 | ||
while retries < retry_times: | ||
try: | ||
request.urlretrieve(url, model_name, schedule) | ||
break | ||
except KeyboardInterrupt: | ||
return False | ||
except: | ||
retries += 1 | ||
print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}") | ||
return retries < retry_times | ||
|
||
|
||
def export_model(input_model, output_model): | ||
# Convert opset version to 14 for more quantization capability. | ||
print("\nexport model...") | ||
model = onnx.load(input_model) | ||
model = version_converter.convert_version(model, 14) | ||
onnx.save_model(model, output_model) | ||
assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!" | ||
|
||
|
||
def prepare_model(input_model, output_model): | ||
# Download model from [ONNX Model Zoo](https://github.com/onnx/models). | ||
is_download_successful = download_model(MODEL_URL, input_model, MAX_TIMES_RETRY_DOWNLOAD) | ||
if is_download_successful: | ||
export_model(input_model, output_model) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_arguments() | ||
prepare_model(args.input_model, args.output_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.