Skip to content

Commit

Permalink
move test case to util
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyu-wang committed Feb 7, 2022
1 parent c1a8f0d commit 7a32847
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ static bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, con
const auto& op_type = node_unit.OpType();
auto quant_op_type = GetQuantizedOpType(node_unit);

ORT_ENFORCE(quant_op_type != QuantizedOpType::QLinearMatMul, "[", op_type, "] is not a quantized op");
ORT_ENFORCE(quant_op_type != QuantizedOpType::Unknown, "[", op_type, "] is not a quantized op");

bool is_quant_conv = IsQuantizedConv(quant_op_type);
bool is_quant_matmul = (quant_op_type == QuantizedOpType::QLinearMatMul);
Expand Down
48 changes: 48 additions & 0 deletions onnxruntime/test/optimizer/qdq_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,53 @@ GetQDQTestCaseFn BuildQDQResizeTestCase(const std::vector<int64_t>& input_shape,
const std::string& mode = "nearest",
const std::string& coordinate_transformation_mode = "half_pixel");

// Returns a test-case builder for a QDQ-wrapped binary operator.
//
// The returned callable populates a ModelTestBuilder with this graph:
//   input1 -> QuantizeLinear<Input1Type> -> DequantizeLinear<Input1Type> --+
//                                                                         +--> op_type
//   input2 -> QuantizeLinear<Input2Type> -> DequantizeLinear<Input2Type> --+
//   op_type -> QuantizeLinear<OutputType> -> DequantizeLinear<OutputType> -> output
//
// input_shape: shape passed to MakeInput<float> for both graph inputs.
// op_type:     operator type string handed to AddNode for the binary node.
//
// Both arguments are captured by value, so the returned callable does not
// refer back to the caller's objects.
template <typename Input1Type, typename Input2Type, typename OutputType>
GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector<int64_t>& input_shape,
                                       const std::string& op_type) {
  return [input_shape, op_type](ModelTestBuilder& builder) {
    // Every Q/DQ node of a given element type uses half of that type's
    // maximum value as its zero point.
    const auto lhs_zero_point = std::numeric_limits<Input1Type>::max() / 2;
    const auto rhs_zero_point = std::numeric_limits<Input2Type>::max() / 2;
    const auto out_zero_point = std::numeric_limits<OutputType>::max() / 2;

    // Graph inputs and output. The creation order of args is kept as-is.
    auto* lhs_input = builder.MakeInput<float>(input_shape, -1.f, 1.f);
    auto* rhs_input = builder.MakeInput<float>(input_shape, -1.f, 1.f);
    auto* graph_output = builder.MakeOutput();

    // Q -> DQ pair on the first operand.
    auto* lhs_quantized = builder.MakeIntermediate();
    auto* lhs_dequantized = builder.MakeIntermediate();
    builder.AddQuantizeLinearNode<Input1Type>(lhs_input, .004f, lhs_zero_point, lhs_quantized);
    builder.AddDequantizeLinearNode<Input1Type>(lhs_quantized, .0039f, lhs_zero_point, lhs_dequantized);

    // Q -> DQ pair on the second operand.
    auto* rhs_quantized = builder.MakeIntermediate();
    auto* rhs_dequantized = builder.MakeIntermediate();
    builder.AddQuantizeLinearNode<Input2Type>(rhs_input, .004f, rhs_zero_point, rhs_quantized);
    builder.AddDequantizeLinearNode<Input2Type>(rhs_quantized, .0039f, rhs_zero_point, rhs_dequantized);

    // The binary operator consumes both dequantized operands.
    auto* op_result = builder.MakeIntermediate();
    builder.AddNode(op_type, {lhs_dequantized, rhs_dequantized}, {op_result});

    // Q -> DQ pair on the operator result, ending at the graph output.
    auto* result_quantized = builder.MakeIntermediate();
    builder.AddQuantizeLinearNode<OutputType>(op_result, .0038f, out_zero_point, result_quantized);
    builder.AddDequantizeLinearNode<OutputType>(result_quantized, .0039f, out_zero_point, graph_output);
  };
}
} // namespace test
} // namespace onnxruntime
50 changes: 2 additions & 48 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,51 +290,6 @@ TEST(QDQTransformerTests, AveragePool_U8S8) {
template <typename Input1Type, typename Input2Type, typename OutputType>
void QDQTransformerBinaryOpTests(const std::string& op_type) {
auto test_case = [&](const std::vector<int64_t>& input_shape) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
auto* input2_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
auto* output_arg = builder.MakeOutput();

// add QDQ 1
auto* q1_output = builder.MakeIntermediate();
auto* dq1_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<Input1Type>(input1_arg,
.004f,
std::numeric_limits<Input1Type>::max() / 2,
q1_output);
builder.AddDequantizeLinearNode<Input1Type>(q1_output,
.0039f,
std::numeric_limits<Input1Type>::max() / 2,
dq1_output);

// add QDQ 2
auto* q2_output = builder.MakeIntermediate();
auto* dq2_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<Input2Type>(input2_arg,
.004f,
std::numeric_limits<Input2Type>::max() / 2,
q2_output);
builder.AddDequantizeLinearNode<Input2Type>(q2_output,
.0039f,
std::numeric_limits<Input2Type>::max() / 2,
dq2_output);

// add binary operator
auto* binary_op_output = builder.MakeIntermediate();
builder.AddNode(op_type, {dq1_output, dq2_output}, {binary_op_output});

// add QDQ output
auto* q3_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<OutputType>(binary_op_output,
.0038f,
std::numeric_limits<OutputType>::max() / 2,
q3_output);
builder.AddDequantizeLinearNode<OutputType>(q3_output,
.0039f,
std::numeric_limits<OutputType>::max() / 2,
output_arg);
};

auto check_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
if (std::is_same<Input1Type, Input2Type>::value &&
Expand All @@ -351,7 +306,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) {
}
};

TransformerTester(build_test_case,
TransformerTester(BuildBinaryOpTestCase<Input1Type, Input2Type, OutputType>(input_shape, op_type),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
Expand Down Expand Up @@ -614,8 +569,7 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one

auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>)&&
(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>)&&(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
(std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
EXPECT_EQ(op_to_count["Gemm"], 0);
Expand Down

0 comments on commit 7a32847

Please sign in to comment.