Enhance ONNXRT FP16 cast (#1220)
Signed-off-by: yuwenzho <[email protected]>
yuwenzho authored Sep 7, 2023
1 parent 779fd32 commit a1b566f
Showing 3 changed files with 81 additions and 16 deletions.
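In short, as read from the diffs below: cast_tensor no longer overwrites a FLOAT initializer in place and returns a bool; it now returns a copy named <original name>_init_cast, and the quantizer adds that copy to the model, points the consuming node's input at it, and removes the original initializer only when no other node still reads it. For large models the copy is built with raw=True. A minimal sketch of the graph shape this protects, mirroring the new test (plain onnx only; graph and tensor names here are illustrative, not from the library):

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# One FLOAT initializer feeding two nodes; only one consumer will be converted to fp16.
shared = numpy_helper.from_array(np.array(0.0, dtype=np.float32), name="repeated_init")
less = helper.make_node("Less", ["input", "repeated_init"], ["less_output"], name="Less")  # converted to fp16
clip = helper.make_node("Clip", ["input", "repeated_init"], ["clip_output"], name="Clip")  # pinned to fp32
graph = helper.make_graph(
    [less, clip],
    "shared_initializer",
    [helper.make_tensor_value_info("input", TensorProto.FLOAT, [])],
    [helper.make_tensor_value_info("less_output", TensorProto.BOOL, []),
     helper.make_tensor_value_info("clip_output", TensorProto.FLOAT, [])],
    initializer=[shared],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)

# Casting "repeated_init" in place (the old behavior) would also change the dtype Clip sees.
# After this commit, Less is rewired to a new FLOAT16 "repeated_init_init_cast" initializer,
# while the FLOAT "repeated_init" stays in the graph for Clip.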
neural_compressor/adaptor/ox_utils/quantizer.py (16 additions, 4 deletions)
@@ -388,10 +388,22 @@ def cast_inputs(self, node, cfg, indices=None):
             if initializer is not None:
                 if initializer.data_type != onnx_proto.TensorProto.FLOAT:
                     continue
-                do_cast = cast_tensor(initializer, cfg)
-                if do_cast:
-                    self.new_value_info[tensor_name] = ValueInfo(tensor_name,
-                        TensorProto.FLOAT, dtype_mapping[cfg])
+                do_cast_new_tensor = cast_tensor(initializer, cfg, self.model.is_large_model)
+                if do_cast_new_tensor:
+                    # add cast initializer and update its name
+                    self.model.add_initializer(do_cast_new_tensor)
+                    node.input[idx] = do_cast_new_tensor.name
+
+                    # if origin initializer is no more used, remove it
+                    self.model.update()
+                    input_name_to_nodes = self.model.input_name_to_nodes
+                    if initializer.name not in input_name_to_nodes or \
+                        len(input_name_to_nodes[initializer.name]) == 0:
+                        self.model.remove_initializer(initializer)
+
+                    self.new_value_info[do_cast_new_tensor.name] = ValueInfo(do_cast_new_tensor.name,
+                                                                             TensorProto.FLOAT,
+                                                                             dtype_mapping[cfg])
             else:
                 if tensor_name in self.value_infos and \
                     self.value_infos[tensor_name].type.HasField('tensor_type') and \
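The removal guard above goes through the model wrapper's update() and input_name_to_nodes shown in the diff. For readers without that wrapper at hand, a plain-onnx equivalent of the "no more used" check might look like this (hypothetical helper, not part of the library):

from onnx import GraphProto

def initializer_is_unused(graph: GraphProto, init_name: str) -> bool:
    """True when no node in `graph` reads `init_name` as an input, i.e. the original
    FLOAT initializer can be dropped once its consumers are rewired to the renamed
    *_init_cast copy."""
    return all(init_name not in node.input for node in graph.node)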
neural_compressor/adaptor/ox_utils/util.py (16 additions, 12 deletions)
@@ -178,16 +178,18 @@ def float_to_bfloat16(tensor):
     return tensor


-def cast_tensor(tensor, dtype):  # pragma: no cover
+def cast_tensor(tensor, dtype, is_large_model=False):  # pragma: no cover
     """Convert tensor float to target dtype.

     Args:
         tensor (TensorProto): TensorProto object
         dtype (int): target data type
+        is_large_model (bool): if is large model, make tensor with raw=True
     """
     if not isinstance(tensor, onnx_proto.TensorProto):
         raise ValueError("Expected input type is an ONNX TensorProto but got %s" % type(tensor))

+    new_tensor = None
     if tensor.data_type == onnx_proto.TensorProto.FLOAT:
         val = numpy_helper.to_array(tensor).copy()
         if dtype == "fp16":
@@ -196,21 +198,23 @@ def cast_tensor(tensor, dtype): # pragma: no cover
             new_val = float_to_bfloat16(val)
         else:
             raise ValueError("Expect fp16 or bf16 but get {}.".format(dtype))
-        try:
+
+        if not is_large_model:
             new_tensor = helper.make_tensor(
-                name=tensor.name,
+                name=tensor.name + "_init_cast",
                 data_type=dtype_mapping[dtype],
                 dims=numpy_helper.to_array(tensor).shape if len(numpy_helper.to_array(tensor).shape) != 0 else [],
-                vals=new_val if len(numpy_helper.to_array(tensor)) != 0 else [numpy_helper.to_array(tensor)],
+                vals=new_val if len(numpy_helper.to_array(tensor).shape) != 0 else [numpy_helper.to_array(tensor)],
             )
-            tensor.CopyFrom(new_tensor)
-        except:
-            tensor.float_data[:] = []
-            tensor.int32_data[:] = []
-            tensor.raw_data = new_val.tostring()
-            tensor.data_type = dtype_mapping[dtype]
-        return True
-    return False
+        else:
+            new_tensor = helper.make_tensor(
+                name=tensor.name + "_init_cast",
+                data_type=dtype_mapping[dtype],
+                dims=numpy_helper.to_array(tensor).shape if len(numpy_helper.to_array(tensor).shape) != 0 else [],
+                vals=new_val.tostring(),
+                raw=True,
+            )
+    return new_tensor


 def remove_init_from_model_input(model):
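With the new return value, cast_tensor can be sanity-checked on its own. A small sketch of the contract implied by the diff above (assumes the module-level dtype_mapping maps "fp16" to TensorProto.FLOAT16, as the diff implies):

import numpy as np
from onnx import TensorProto, numpy_helper
from neural_compressor.adaptor.ox_utils.util import cast_tensor

weight = numpy_helper.from_array(np.ones((2, 2), dtype=np.float32), name="w")

fp16_copy = cast_tensor(weight, "fp16")
assert fp16_copy is not None
assert fp16_copy.name == "w_init_cast"        # renamed copy for the caller to wire in
assert fp16_copy.data_type == TensorProto.FLOAT16
assert weight.data_type == TensorProto.FLOAT  # original initializer is left untouched

# Large (external-data) models take the raw path: same renamed copy, stored as raw bytes.
fp16_raw = cast_tensor(weight, "fp16", is_large_model=True)
assert fp16_raw.raw_data                      # filled via new_val.tostring() with raw=True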
test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py (49 additions, 0 deletions)
@@ -1936,6 +1936,55 @@ def test_bf16(self):
         session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=["DnnlExecutionProvider"])
         outputs = session.run(None, input_data)

+    def test_fp16_with_repeated_init(self):
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [])
+        repeated_init = helper.make_tensor("repeated_init", TensorProto.FLOAT, [], [0])
+
+        less_node = onnx.helper.make_node("Less", ["input", "repeated_init"], ["less_output"], name="Less")
+        cast_node = onnx.helper.make_node("Cast", ["less_output"], ["cast_output"], name="Cast", to=TensorProto.FLOAT)
+        clip_output = helper.make_tensor_value_info("clip_output", TensorProto.FLOAT, [])
+        clip_node = onnx.helper.make_node("Clip", ["cast_output", "repeated_init"], ["clip_output"], name="Clip")
+
+        initializers = [repeated_init]
+        graph = helper.make_graph(
+            [less_node, cast_node, clip_node],
+            "test_fp16_with_repeated_init_model",
+            [input_tensor],
+            [clip_output],
+            initializer=initializers,
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        model.ir_version = 7
+
+        from neural_compressor import MixedPrecisionConfig
+        from neural_compressor.mix_precision import fit
+
+        config = MixedPrecisionConfig(
+            backend="onnxrt_cuda_ep",
+            device="gpu",
+            precision="fp16",
+            op_type_dict={"Clip": {"activation": {"dtype": ["fp32"]}, "weight": {"dtype": ["fp32"]}}},
+        )
+        converted_model = fit(model, config)
+        less_init = converted_model.get_node("Less").input[1]
+        clip_1_init = converted_model.get_node("Clip").input[1]
+        self.assertEqual(clip_1_init + "_init_cast", less_init)
+        self.assertTrue("Cast" in set([i.op_type for i in converted_model.nodes()]))
+        self.assertTrue(10 in set([i.attribute[0].i for i in converted_model.nodes() if i.op_type == "Cast"]))
+
+        config = MixedPrecisionConfig(
+            backend="onnxrt_cuda_ep",
+            device="gpu",
+            precision="fp16",
+        )
+        converted_model = fit(model, config)
+        less_init = converted_model.get_node("Less").input[1]
+        clip_1_init = converted_model.get_node("Clip").input[1]
+        self.assertTrue(less_init.endswith("_init_cast"))
+        self.assertTrue(clip_1_init.endswith("_init_cast"))
+        self.assertTrue("Cast" in set([i.op_type for i in converted_model.nodes()]))
+        self.assertTrue(10 in set([i.attribute[0].i for i in converted_model.nodes() if i.op_type == "Cast"]))
+

 if __name__ == "__main__":
     unittest.main()
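The literal 10 in the Cast-node assertions above is the TensorProto enum value for FLOAT16, i.e. the "to" attribute the inserted Cast nodes carry after fp16 conversion; a quick check with plain onnx confirms the mapping:

from onnx import TensorProto

assert TensorProto.FLOAT16 == 10          # "to=10" above means Cast to FLOAT16
print(TensorProto.DataType.Name(10))      # -> "FLOAT16"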
