diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index d185d143c7a6..62f0f4b2dd25 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1406,7 +1406,10 @@ def _impl_v1(cls, inputs, attr, params):
             inputs[0] *= _expr.const(alpha, dtype=dtype)
         out = _op.nn.dense(inputs[0], inputs[1], units=channels)
         if len(inputs) == 3:
-            out = out + _expr.const(beta, dtype=dtype) * inputs[2]
+            if beta != 1.0:
+                out += _expr.const(float(beta), dtype=dtype) * inputs[2]
+            else:
+                out += inputs[2]
         return out
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 46bdd94ace1a..84b1f33e98cc 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -502,6 +502,37 @@ def register_binary_qnn(op_name, op):
 
     def binary(expr, type_map):
         left, right, left_t, right_t, out_t = get_binary_types(expr, type_map)
+
+        if (
+            op_name == "add"
+            and approx_equal(left_t.scale, right_t.scale)
+            and approx_equal(left_t.zero_point, right_t.zero_point)
+            and tvm.ir.structural_equal(left_t.dtype, right_t.dtype)
+            and left_t.dtype == "int32"
+            and approx_equal(left_t.scale, out_t.scale)
+            and approx_equal(left_t.zero_point, out_t.zero_point)
+            and np.all(out_t.zero_point.data.numpy() == 0)
+        ):
+            # If this add op comes after conv2d or dense, out_t.scale and out_t.zero_point
+            # can be a vector, which is not supported by QNN binary operators.
+            # In particular, the pattern of an `add` op following `dense`, where the addition is
+            # really a bias addition, can come up often. We identify that pattern and convert it to
+            # `qnn.dense` -> `add`.
+            # To avoid overflow, we do this conversion only when the input data type is 32 bit (bias
+            # addition is typically done in 32 bit).
+            return [left + right, left_t]
+
+        assert (
+            len(out_t.scale.data.shape) == 0
+        ), "The output scale needs to be a scalar, but got a tensor of shape {}".format(
+            out_t.scale.data.shape
+        )
+        assert (
+            len(out_t.zero_point.data.shape) == 0
+        ), "The output zero point needs to be a scalar, but got a tensor of shape {}".format(
+            out_t.zero_point.data.shape
+        )
+
         out = op(
             left,
             right,
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
index eb176df5c978..31353d5aa25e 100644
--- a/src/relay/transforms/fake_quantization_to_integer.cc
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -193,7 +193,7 @@ class SubgraphMutator : public ExprMutator {
       return Mutate(expr);
     } catch (std::exception& e) {
       if (hard_fail_) {
-        throw e;
+        LOG(FATAL) << e.what();
       } else {
         DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
         return expr;
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index 569bd9d7d653..d384635e42e5 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -154,6 +154,41 @@ def test_fake_quantize_dense_per_channel():
     compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True)
 
 
+def test_fake_quantize_dense_bias():
+    out_dtype = "int8"
+    x = relay.var("x", shape=[128, 64], dtype="int8")
+    w = relay.var("w", shape=[256, 64], dtype="int8")
+    bias = relay.var("bias", shape=[256], dtype="int32")
+    one = relay.const(1.0)
+    zero = relay.const(0)
+    w_scale = np.random.random([256]).astype("float32")
+
+    op = relay.op.nn.dense(
+        relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+        relay.qnn.op.dequantize(
+            w,
+            relay.const(w_scale),
+            zero,
+            axis=0,
+        ),
+        units=256,
+    )
+
+    op += relay.qnn.op.dequantize(
+        bias,
+        relay.const(2.0 * w_scale),
+        zero,
+    )
+
+    op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype)
+
+    x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8")
+    w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8")
+    bias_np = np.random.randint(-128, 127, size=[256], dtype="int32")
+
+    compare_fq_to_int(op, [x_np, w_np, bias_np], allow_rounding_error=True)
+
+
 def test_fake_quantize_batch_matmul():
     for out_dtype in ["int8", "uint8"]:
         x = relay.var("x", shape=[1, 128, 64], dtype="int8")
@@ -976,15 +1011,9 @@ def test_fq_qat_positive_nothing_to_do():
     op1 = relay.qnn.op.quantize(
         relay.const(1.0), relay.const(12.0), relay.const(0), out_dtype="int32"
     )
-    op2 = relay.qnn.op.add(
+    op2 = relay.op.add(
         op0,
         op1,
-        relay.const(12.0),
-        relay.const(0),
-        relay.const(12.0),
-        relay.const(0),
-        relay.const(12.0),
-        relay.const(0),
     )
     expected_expr = relay.qnn.op.requantize(
         op2, relay.const(12.0), relay.const(0), relay.const(1.0), relay.const(0), out_dtype="int8"