diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index b719cb6cfc794..9823e8264e17b 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -138,7 +138,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "ConstantOfShape": self._infer_ConstantOfShape,
             "Conv": self._infer_Conv,
             "CumSum": self._pass_on_shape_and_type,
-            "DequantizeLinear": self._infer_DequantizeLinear,
             "Div": self._infer_symbolic_compute_ops,
             "Einsum": self._infer_Einsum,
             "Expand": self._infer_Expand,
@@ -164,7 +163,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "NonZero": self._infer_NonZero,
             "OneHot": self._infer_OneHot,
             "Pad": self._infer_Pad,
-            "QuantizeLinear": self._infer_QuantizeLinear,
             "Range": self._infer_Range,
             "Reciprocal": self._pass_on_shape_and_type,
             "ReduceSum": self._infer_ReduceSum,
@@ -199,6 +197,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "BiasGelu": self._infer_BiasGelu,
             "BiasSplitGelu": self._infer_BiasSplitGelu,
             "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
+            "DequantizeLinear": self._infer_DequantizeLinear,
             "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
             "FastGelu": self._infer_FastGelu,
             "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
@@ -214,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "PackedAttention": self._infer_PackedAttention,
             "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
             "PythonOp": self._infer_PythonOp,
+            "QuantizeLinear": self._infer_QuantizeLinear,
             "QuickGelu": self._infer_FastGelu,
             "RelativePositionBias": self._infer_RelativePositionBias,
             "RemovePadding": self._infer_RemovePadding,
@@ -459,6 +459,8 @@ def _onnx_infer_single_node(self, node):
             "GemmFastGelu",
             "LayerNormalization",
             "LongformerAttention",
+            "DequantizeLinear",
+            "QuantizeLinear",
             "RelativePositionBias",
             "RemovePadding",
             "RestorePadding",
@@ -982,10 +984,27 @@ def _infer_NhwcConv(self, node):  # noqa: N802
         )
 
     def _infer_DequantizeLinear(self, node):  # noqa: N802
-        self._propagate_shape_and_type(node)
+        # Get the output data type from the scale input (index 1, required).
+        output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
 
     def _infer_QuantizeLinear(self, node):  # noqa: N802
-        self._propagate_shape_and_type(node)
+        # Get the output data type from the zero-point input (index 2, optional).
+        # Otherwise, default to uint8
+        output_dtype = onnx.TensorProto.UINT8
+        if len(node.input) > 2 and node.input[2]:
+            output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
 
     def _infer_Einsum(self, node):  # noqa: N802
         # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275