Skip to content

Commit

Permalink
Added script for automatic conversion to float16, excluding the minim… (
Browse files Browse the repository at this point in the history
#193)

* Added script for automatic conversion to float16, excluding the minimum number of nodes

* Renamed auto_float16 to auto_mixed_precision and added rtol/atol attributes

* Move code template to docstring
  • Loading branch information
TomWildenhain-Microsoft authored Jul 1, 2021
1 parent b716f75 commit b2215a0
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 0 deletions.
207 changes: 207 additions & 0 deletions onnxconverter_common/auto_mixed_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###########################################################################

"""
This tool converts converts a model to mixed precision (float32->float16) while excluding nodes as needed to maintain
a certain accuracy.
Example usage:
from onnxconverter_common import auto_mixed_precision
import onnx
model = onnx.load(model_path)
# Could also use rtol/atol attributes directly instead of this
def validate(res1, res2):
for r1, r2 in zip(res1, res2):
if not np.allclose(r1, r2, rtol=0.01, atol=0.001):
return False
return True
model_fp16 = auto_convert_mixed_precision(model, test_data, validate, keep_io_types=True)
onnx.save(model_fp16, "ouptut_path")
"""

import onnxruntime as ort
import onnx
import numpy as np
from onnxconverter_common import float16
from onnx import helper, mapping
import copy


def auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol=None, keep_io_types=False):
"""
Automatically converts a model to mixed precision, excluding the minimum number of nodes required to
ensure valudate_fn returns True and/or results are equal according to rtol/atol
"""
if rtol is None and atol is not None:
rtol = 1e-5

if atol is None and rtol is not None:
atol = 1e-8

def validate(res1, res2):
if validate_fn is not None and not validate_fn(res1, res2):
return False
if rtol is not None:
for r1, r2 in zip(res1, res2):
if not np.allclose(r1, r2, rtol, atol):
return False
return True

model0 = onnx.shape_inference.infer_shapes(model)
model0 = add_missing_dtypes_using_ort(model0, feed_dict)
res0 = get_tensor_values_using_ort(model0, feed_dict)
if not keep_io_types:
feed_dict = {k: v.astype(np.float16) if v.dtype == np.float32 else v for k, v in feed_dict.items()}
if not validate(res0, res0):
raise ValueError("validation failed for original fp32 model")
node_names = [n.name for n in model0.graph.node if n.op_type not in ["Loop", "If", "Scan"]]

def run_attempt(node_block_list, return_model=False):
print(node_block_list)
model = float16.convert_float_to_float16(copy.deepcopy(model0), node_block_list=node_block_list,
keep_io_types=keep_io_types, disable_shape_infer=True)
res1 = get_tensor_values_using_ort(model, feed_dict)
if return_model:
return validate(res0, res1), model
else:
valid = validate(res0, res1)
print(valid)
return valid

if not run_attempt(node_names):
raise ValueError("validation failed for model with all nodes in node_block_list")
print("Sanity checks passed. Starting autoconvert.")
segments = SegmentList(node_names)
i = 0
while segments.get_largest() is not None:
seg = segments.get_largest()
nodes_to_try = segments.get_nodes(seg)
i += 1
print("Running attempt %d excluding conversion of %s nodes" % (i, len(nodes_to_try)))
if run_attempt(nodes_to_try):
seg.good = True
print("Attempt succeeded.")
else:
print("Attempt failed.")
if seg.size == 1:
seg.bad = True
else:
seg.split()
print(segments)
print("Done:", segments.get_nodes())
valid, model = run_attempt(segments.get_nodes(), return_model=True)
if not valid:
raise ValueError("validation failed for final fp16 model")
print("Final model validated successfully.")
return model


def add_missing_dtypes_using_ort(model, feed_dict, outputs_per_iter=100):
outputs = [out for node in model.graph.node for out in node.output]
graph_io = [inp.name for inp in model.graph.input] + [out.name for out in model.graph.output]
value_info_names = [info.name for info in model.graph.value_info]
skip = set(graph_io + value_info_names)
outputs = [out for out in outputs if out not in skip]
print("Adding missing dtypes for %s outputs" % len(outputs))
out_to_dtype = {}
i = 0
while i < len(outputs):
outs = outputs[i:i + outputs_per_iter]
vals = get_tensor_values_using_ort(model, feed_dict, outs)
for out, val in zip(outs, vals):
out_to_dtype[out] = mapping.NP_TYPE_TO_TENSOR_TYPE[val.dtype]
i += outputs_per_iter
for out, dtype in out_to_dtype.items():
model.graph.value_info.append(helper.make_tensor_value_info(out, dtype, shape=None))
return model


def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_options=None):
if output_names is None:
sess = ort.InferenceSession(model.SerializeToString(), sess_options)
return sess.run(None, input_feed)
original_outputs = list(model.graph.output)
while len(model.graph.output) > 0:
model.graph.output.pop()
for n in output_names:
out = model.graph.output.add()
out.name = n
sess = ort.InferenceSession(model.SerializeToString(), sess_options)
try:
return sess.run(output_names, input_feed)
finally:
while len(model.graph.output) > 0:
model.graph.output.pop()
for orig_out in original_outputs:
out = model.graph.output.add()
out.CopyFrom(orig_out)


class SegmentList:
def __init__(self, node_names):
self.node_names = node_names
self.first = NodeSegment(len(node_names))

def get_largest(self, adjacent_to_good=False):
adjacent_to_good = False
largest = None
current = self.first
prev_good = False
while current is not None:
can_use = not current.good and not current.bad
if adjacent_to_good:
next_good = current.next is not None and current.next.good
can_use = can_use and (prev_good or next_good)
if can_use and (largest is None or current.size > largest.size):
largest = current
prev_good = current.good
current = current.next
return largest

def get_nodes(self, node_segment=None):
i = 0
current = self.first
nodes = []
while current is not None:
if current is not node_segment and not current.good:
nodes.extend(self.node_names[i:i + current.size])
i += current.size
current = current.next
return nodes

def __repr__(self):
res = []
current = self.first
while current is not None:
res.append(current)
current = current.next
return repr(res)


class NodeSegment:
def __init__(self, size):
self.size = size
self.next = None
self.good = False
self.bad = False

def split(self):
new_size = self.size // 2
new_segment = NodeSegment(self.size - new_size)
new_segment.next = self.next
self.next = new_segment
self.size = new_size

def __repr__(self):
if self.good:
return "*" + str(self.size) + "*"
if self.bad:
return "(" + str(self.size) + ")"
return str(self.size)
63 changes: 63 additions & 0 deletions tests/test_auto_mixed_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import unittest
import numpy as np
import onnxruntime as _ort
import onnx
import copy
from onnxconverter_common.onnx_fx import Graph, OnnxOperatorBuilderX
from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
from onnxconverter_common.onnx_ex import get_maximum_opset_supported
from onnxconverter_common.auto_mixed_precision import auto_convert_mixed_precision


def _ort_inference(mdl, inputs):
sess = _ort.InferenceSession(mdl.SerializeToString())
return sess.run(None, inputs)


Graph.inference_runtime = _ort_inference
Graph.opset = 9
onnx_function = Graph.trace

@unittest.skipIf(get_maximum_opset_supported() < 9, "tests designed for ONNX opset 9 and greater")
@unittest.skipIf(not hasattr(onnx, "shape_inference"), "shape inference is required")
class AutoFloat16Test(unittest.TestCase):
def test_auto_mixed_precision(self):
@onnx_function(outputs=['z'],
input_types=(_Ty.F([1, 1, 6, 1])),
output_types=[_Ty.f])
def transpose_n_matmul(x):
ox = x.ox # type: OnnxOperatorBuilderX
wm = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).astype(np.float32).reshape([2, 6])
b = ox.constant(value=wm)
a = ox.transpose(x, perm=[0, 1, 3, 2])
c = ox.transpose(b, perm=[1, 0])
m = ox.matmul([a, c])
m_large = ox.mul([m, ox.constant(value=np.array(100, np.float32))])
m_xlarge = ox.mul([m_large, ox.constant(value=np.array(10, np.float32))])
mr = ox.reshape([m_xlarge], desired_shape=[2])
mr = ox.reshape([mr], desired_shape=[2])
m_normal = ox.div([mr, ox.constant(value=np.array(999, np.float32))])
return m_normal

m1 = np.array([[2, 3], [4, 5], [6, 7]]).astype(np.float32).reshape([1, 1, 6, 1])
expected = transpose_n_matmul(m1)
model = transpose_n_matmul.to_model()

def validate_fn(res, fp16res):
return np.allclose(res[0], fp16res[0], rtol=0.01)

f16model = auto_convert_mixed_precision(copy.deepcopy(model), {'x': m1}, validate_fn, keep_io_types=True)

actual = _ort_inference(f16model, {'x': m1})
self.assertTrue(np.allclose(expected, actual, rtol=0.01))

f16model2 = auto_convert_mixed_precision(copy.deepcopy(model), {'x': m1}, rtol=0.01, keep_io_types=False)

actual = _ort_inference(f16model2, {'x': m1.astype(np.float16)})
self.assertTrue(np.allclose(expected, actual, rtol=0.01))


if __name__ == '__main__':
suite = unittest.TestLoader().loadTestsFromTestCase(AutoFloat16Test)
# suite.debug()
unittest.TextTestRunner().run(suite)

0 comments on commit b2215a0

Please sign in to comment.