Support >2GB ONNX models for fp16 conversion #167

Merged: 3 commits, Jan 11, 2021
2 changes: 1 addition & 1 deletion onnxconverter_common/__init__.py
@@ -26,5 +26,5 @@
 from .utils import * # noqa F403
 from .case_insensitive_dict import * # noqa F403
 from .metadata_props import add_metadata_props, set_denotation
-from .float16 import convert_tensor_float_to_float16, convert_float_to_float16
+from .float16 import convert_tensor_float_to_float16, convert_float_to_float16, convert_float_to_float16_model_path
 from .optimizer import optimize_onnx, optimize_onnx_graph, optimize_onnx_model
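
For reference, this export makes the new helper importable from the package root next to the existing converters. A minimal sketch, assuming an installed build of onnxconverter_common that contains this change:

    # Quick check (illustrative) that both converters are exposed at package level.
    from onnxconverter_common import convert_float_to_float16, convert_float_to_float16_model_path

    print(convert_float_to_float16.__name__, convert_float_to_float16_model_path.__name__)
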
42 changes: 40 additions & 2 deletions onnxconverter_common/float16.py
@@ -78,11 +78,14 @@ def make_value_info_from_tensor(tensor):
return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape)


-def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False):
+def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
+                             keep_io_types=False, disable_shape_infer=False):
'''
Convert tensor float type in the ONNX ModelProto input to tensor float16.

:param model: ONNX ModelProto object
:param disable_shape_infer: Type/shape information is needed for conversion to work.
Set to True only if the model already has type/shape information for all tensors.
:return: converted ONNX ModelProto object

Examples:
@@ -102,7 +105,7 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, k

'''
func_infer_shape = None
-if onnx.__version__ >= '1.2':
+if not disable_shape_infer and onnx.__version__ >= '1.2':
try:
from onnx.shape_inference import infer_shapes
func_infer_shape = infer_shapes
@@ -259,3 +262,38 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, k
node.output[i] = input_name
break
return model
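
The new disable_shape_infer flag lets callers skip the internal infer_shapes pass when the model already carries type/shape information for every tensor, as the added docstring notes. A minimal sketch of how the flag could be used, where 'shaped_model.onnx' is an assumed placeholder for a model that was already shape-inferred:

    import onnx
    from onnxconverter_common import convert_float_to_float16

    # The model is assumed to already contain value_info for all tensors,
    # so the converter's own shape-inference step can be skipped.
    model = onnx.load('shaped_model.onnx')
    fp16_model = convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
    onnx.save(fp16_model, 'shaped_model_fp16.onnx')
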


def convert_float_to_float16_model_path(model_path, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False):
    '''
    Convert tensor float type in the ONNX Model to tensor float16.
    *This is a workaround for the issue that infer_shapes cannot be used on models larger than 2GB.
    *The function itself can be applied to models of any size.
    :param model_path: ONNX Model path
    :return: converted ONNX ModelProto object
    Examples
    ::
        # Convert to ONNX ModelProto object and save model binary file:
        from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path
        new_onnx_model = convert_float_to_float16_model_path('model.onnx')
        onnx.save(new_onnx_model, 'new_model.onnx')
    '''

    disable_shape_infer = False
    if onnx.__version__ >= '1.7':
        try:
            # infer_shapes_path can be applied to all model sizes
            from onnx.shape_inference import infer_shapes_path
            import tempfile
            import os
            # shape_infer_model_path should be in the same folder of model_path
            with tempfile.NamedTemporaryFile(dir=os.path.dirname(model_path)) as tmpfile:
                shape_infer_model_path = tmpfile.name
                infer_shapes_path(model_path, shape_infer_model_path)
                model = onnx.load(shape_infer_model_path)
                disable_shape_infer = True
        finally:
            pass
    if not disable_shape_infer:
        model = onnx.load(model_path)
    return convert_float_to_float16(model, min_positive_val, max_finite_val, keep_io_types, disable_shape_infer)
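
Since infer_shapes serializes the whole ModelProto and therefore runs into the 2GB protobuf limit, the path-based variant runs infer_shapes_path against the file on disk and only then loads the shape-inferred copy. A minimal usage sketch, where 'big_model.onnx' and 'big_model_fp16.onnx' are assumed placeholders; for a >2GB model the weights would typically sit in external data files next to the .onnx file:

    import onnx
    from onnxconverter_common import convert_float_to_float16_model_path

    # Shape inference happens on disk via infer_shapes_path, so the 2GB
    # in-memory protobuf limit of infer_shapes is avoided.
    fp16_model = convert_float_to_float16_model_path('big_model.onnx', keep_io_types=True)

    # Plain onnx.save works while the converted model stays under 2GB;
    # larger outputs need ONNX's external-data saving support.
    onnx.save(fp16_model, 'big_model_fp16.onnx')
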