Properly propagate the data types for Q/DQ ops
adrianlizarraga committed Jan 31, 2024
1 parent f7e14db · commit c9a1dad
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -138,7 +138,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "ConstantOfShape": self._infer_ConstantOfShape,
             "Conv": self._infer_Conv,
             "CumSum": self._pass_on_shape_and_type,
-            "DequantizeLinear": self._infer_DequantizeLinear,
             "Div": self._infer_symbolic_compute_ops,
             "Einsum": self._infer_Einsum,
             "Expand": self._infer_Expand,
@@ -164,7 +163,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "NonZero": self._infer_NonZero,
             "OneHot": self._infer_OneHot,
             "Pad": self._infer_Pad,
-            "QuantizeLinear": self._infer_QuantizeLinear,
             "Range": self._infer_Range,
             "Reciprocal": self._pass_on_shape_and_type,
             "ReduceSum": self._infer_ReduceSum,
@@ -199,6 +197,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "BiasGelu": self._infer_BiasGelu,
             "BiasSplitGelu": self._infer_BiasSplitGelu,
             "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
+            "DequantizeLinear": self._infer_DequantizeLinear,
             "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
             "FastGelu": self._infer_FastGelu,
             "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
@@ -214,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "PackedAttention": self._infer_PackedAttention,
             "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
             "PythonOp": self._infer_PythonOp,
+            "QuantizeLinear": self._infer_QuantizeLinear,
             "QuickGelu": self._infer_FastGelu,
             "RelativePositionBias": self._infer_RelativePositionBias,
             "RemovePadding": self._infer_RemovePadding,
@@ -459,6 +459,8 @@ def _onnx_infer_single_node(self, node):
             "GemmFastGelu",
             "LayerNormalization",
             "LongformerAttention",
+            "DequantizeLinear",
+            "QuantizeLinear",
             "RelativePositionBias",
             "RemovePadding",
             "RestorePadding",
@@ -982,10 +984,27 @@ def _infer_NhwcConv(self, node):  # noqa: N802
         )
 
     def _infer_DequantizeLinear(self, node):  # noqa: N802
-        self._propagate_shape_and_type(node)
+        # Get the output data type from the scale input (index 1, required).
+        output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
 
     def _infer_QuantizeLinear(self, node):  # noqa: N802
-        self._propagate_shape_and_type(node)
+        # Get the output data type from the zero-point input (index 2, optional).
+        # Otherwise, default to uint8.
+        output_dtype = onnx.TensorProto.UINT8
+        if len(node.input) > 2 and node.input[2]:
+            output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
 
     def _infer_Einsum(self, node):  # noqa: N802
         # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
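Taken together, the new handlers make the inferred Q/DQ output types follow the ONNX operator spec: DequantizeLinear's output element type comes from its scale (input 1), and QuantizeLinear's from its zero-point (input 2), defaulting to uint8 when no zero-point is given. Registering both ops in the skip list of _onnx_infer_single_node keeps ONNX's built-in inference from filling in those outputs, leaving them to the handlers above. Below is a minimal sketch of the effect (illustrative only, not part of the commit; the graph, tensor names, and the onnxruntime.tools.symbolic_shape_infer import path are assumptions):

# Illustrative sketch -- assumed usage, not part of this diff.
from onnx import TensorProto, helper
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Float tensor "x" is quantized using an int8 zero-point, then dequantized
# back to float via the same scale/zero-point initializers.
graph = helper.make_graph(
    [
        helper.make_node("QuantizeLinear", ["x", "scale", "zp"], ["q"]),
        helper.make_node("DequantizeLinear", ["q", "scale", "zp"], ["y"]),
    ],
    "qdq_example",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N", 4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, ["N", 4])],
    initializer=[
        helper.make_tensor("scale", TensorProto.FLOAT, [], [0.5]),
        helper.make_tensor("zp", TensorProto.INT8, [], [0]),
    ],
)
inferred = SymbolicShapeInference.infer_shapes(helper.make_model(graph))

# With this commit, "q" is inferred as INT8 (taken from the zero-point input);
# the old _propagate_shape_and_type call copied input 0's type, so "q" would
# have been reported as FLOAT.
q_vi = next(vi for vi in inferred.graph.value_info if vi.name == "q")
print(TensorProto.DataType.Name(q_vi.type.tensor_type.elem_type))  # INT8

By the same mechanism, a DequantizeLinear whose scale is float16 now reports a float16 output rather than inheriting its quantized input's type, which appears to be the motivating case for QDQ models that dequantize to half precision.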
