From ee217f2261512573e01af918ac9631bcf6ad5e9d Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Thu, 17 Nov 2022 15:07:23 +0300 Subject: [PATCH] [Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract This commit adds compute/schedule implementation for Hexagon target for QNN ops: qnn.mul, qnn.subtract, qnn.tanh. It works only if QNN canonicalization pass was disabled. --- python/tvm/relay/qnn/op/_qnn.py | 11 +- python/tvm/relay/qnn/strategy/generic.py | 27 +++ python/tvm/relay/qnn/strategy/hexagon.py | 36 ++++ python/tvm/topi/hexagon/qnn/nn.py | 179 ++++++++++++++---- src/relay/qnn/op/add.cc | 3 +- src/relay/qnn/op/mul.cc | 3 +- src/relay/qnn/op/requantize.cc | 3 + src/relay/qnn/op/subtract.cc | 3 +- .../test_wo_qnn_canonicalization.py | 178 +++++++++++++---- 9 files changed, 362 insertions(+), 81 deletions(-) diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index 4e54583a3be0..64ef1ee92a1c 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -66,7 +66,16 @@ def simulated_dequantize_compute(attrs, inputs, output_type): # qnn.add register_strategy("qnn.add", strategy.qnn_add_strategy) -register_pattern("qnn.add", OpPattern.BROADCAST) + +# qnn.subtract +register_strategy("qnn.subtract", strategy.qnn_subtract_strategy) + +# qnn.mul +register_strategy("qnn.mul", strategy.qnn_mul_strategy) + +# qnn.tanh +register_strategy("qnn.tanh", strategy.qnn_tanh_strategy) +register_pattern("qnn.tanh", OpPattern.ELEMWISE) # qnn.concatenate register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy) diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 57a364f7e057..8275cf7f755e 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -213,6 +213,33 @@ def qnn_add_strategy(attrs, inputs, out_type, target): ) +@override_native_generic_func("qnn_subtract_strategy") +def qnn_subtract_strategy(attrs, inputs, out_type, target): + """qnn.subtract generic strategy""" + raise RuntimeError( + "qnn.subtract is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_mul_strategy") +def qnn_mul_strategy(attrs, inputs, out_type, target): + """qnn.mul generic strategy""" + raise RuntimeError( + "qnn.mul is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_tanh_strategy") +def qnn_tanh_strategy(attrs, inputs, out_type, target): + """qnn.tanh generic strategy""" + raise RuntimeError( + "qnn.tanh is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." 
+ ) + + @override_native_generic_func("qnn_concatenate_strategy") def qnn_concatenate_strategy(attrs, inputs, out_type, target): """qnn.concatenate generic strategy""" diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index c7f59cc096fc..d17812e3fbcc 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -71,6 +71,42 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): return strategy +@qnn_subtract_strategy.register("hexagon") +def qnn_subtract_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.subtract strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_compute(topi.hexagon.qnn_subtract), + wrap_topi_schedule(topi.hexagon.schedule_qnn_subtract), + name="qnn_subtract.hexagon", + ) + return strategy + + +@qnn_mul_strategy.register("hexagon") +def qnn_mul_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.mul strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_compute(topi.hexagon.qnn_mul), + wrap_topi_schedule(topi.hexagon.schedule_qnn_mul), + name="qnn_mul.hexagon", + ) + return strategy + + +@qnn_tanh_strategy.register("hexagon") +def qnn_tanh_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.tanh strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_compute(topi.hexagon.qnn_tanh), + wrap_topi_schedule(topi.hexagon.schedule_qnn_tanh), + name="qnn_tanh.hexagon", + ) + return strategy + + @qnn_concatenate_strategy.register("hexagon") def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target): """qnn.concatenate strategy for Hexagon""" diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py index 40cfd0ee96b1..49220d0fd013 100644 --- a/python/tvm/topi/hexagon/qnn/nn.py +++ b/python/tvm/topi/hexagon/qnn/nn.py @@ -19,6 +19,7 @@ import tvm from tvm import te, topi +from ..utils import saturate from ...utils import get_const_tuple from ...nn.utils import get_pad_tuple from ...nn.pad import pad @@ -33,6 +34,11 @@ def clip_cast(val, dtype): return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) +# Return True if given Tensor is scalar constant value. 
+def is_constant(tensor: te.Tensor): + return tensor.ndim == 0 + + def get_qnn_param(param, indices, axis): # Account scalar and 1D quantization parameters: if len(param.shape) == 0: @@ -62,7 +68,7 @@ def default_schedule(outs): return s -def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): +def qnn_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"): """Compute for qnn.quantize Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), @@ -101,7 +107,7 @@ def schedule_qnn_quantize(outs): return default_schedule(outs) -def qnn_dequantize(data, input_scale, input_zero_point, axis): +def qnn_dequantize(data, input_scale, input_zero_point, axis=-1): """Compute for qnn.dequantize fp_output = input_scale * (Q_input - input_zero_point) @@ -134,7 +140,7 @@ def schedule_qnn_dequantize(outs): return default_schedule(outs) -def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype): +def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis=-1, out_dtype="int8"): """Compute for qnn.requantize Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) @@ -177,37 +183,58 @@ def schedule_qnn_requantize(outs): return default_schedule(outs) -def qnn_add( - lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point +def compute_qnn_binary_op( + lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, func ): - """Compute for qnn.add + """Compute for QNN binary operation - Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input)) - + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input)) - - TODO: support 'axis' argument. + Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input - lhs_zp)) + _OP_ round((rhs_scale)/(output_scale) * (rhs_input - rhs_zp)) + where _OP_ is add/subtract """ - assert lhs.dtype == rhs.dtype dtype = lhs.dtype + def _compute_const(x: te.Tensor, iscale, input_zp): + return te.round(te.multiply(te.div(iscale, output_scale), te.subtract(x, input_zp))).astype( + "int32" + ) + + def _compute_tensor(x: te.Tensor, iscale, input_zp): + return te.compute( + x.shape, + lambda *i: te.round( + te.multiply(te.div(iscale, output_scale), te.subtract(x(*i), input_zp)) + ).astype("int32"), + ) + + if is_constant(lhs): + lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp) + else: + lhs_tensor = _compute_tensor(lhs, lhs_scale, lhs_zp) + + if is_constant(rhs): + rhs_tensor = _compute_const(rhs, rhs_scale, rhs_zp) + else: + rhs_tensor = _compute_tensor(rhs, rhs_scale, rhs_zp) + + # Binary op with broadcasting + tensor = func(lhs_tensor, rhs_tensor) + + # Add output zero point and clip+cast. 
     def _compute(*indices):
-        lvalue = lhs(*indices)
-        rvalue = rhs(*indices)
-        q_lv = te.round(
-            te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point))
-        ).astype("int32")
-        q_rv = te.round(
-            te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point))
-        ).astype("int32")
-        val = te.add(te.add(q_lv, q_rv), output_zero_point)
+        return saturate(te.add(tensor(*indices), output_zp), dtype).astype(dtype)
+
+    return te.compute(tensor.shape, _compute)
 
-        # clip + cast:
-        const_min = tvm.tir.min_value(dtype)
-        const_max = tvm.tir.max_value(dtype)
-        return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)
 
-    return te.compute(lhs.shape, _compute)
+def qnn_add(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+    """Compute for qnn.add
 
+    TODO: support 'axis' argument.
+    """
+    return compute_qnn_binary_op(
+        lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.add
+    )
 
 
 def schedule_qnn_add(outs):
@@ -227,19 +254,99 @@ def schedule_qnn_add(outs):
     return default_schedule(outs)
 
 
-def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype):
-    """Requantize tensor"""
+def qnn_subtract(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+    """Compute for qnn.subtract"""
 
-    def _compute(*indices):
-        value = tensor(*indices)
-        mul_value = te.round(
-            te.multiply(te.div(i_scale, o_scale), te.subtract(value, i_zp))
-        ).astype("int32")
-        rq_value = te.add(mul_value, o_zp)
+    return compute_qnn_binary_op(
+        lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.subtract
+    )
 
-        return clip_cast(rq_value, out_dtype)
 
-    return te.compute(tensor.shape, _compute)
+def schedule_qnn_subtract(outs):
+    """Schedule for qnn.subtract
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of qnn.subtract
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return default_schedule(outs)
+
+
+def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+    """Compute for qnn.mul
+
+    mul = (lhs_input - lhs_zp) * (rhs_input - rhs_zp)
+    Q_output = requantize(mul, lhs_scale * rhs_scale, 0, output_scale, output_zp)
+    """
+    assert lhs.dtype == rhs.dtype
+    odtype = lhs.dtype
+
+    if is_constant(lhs):
+        lhs_tensor = lhs - lhs_zp
+    else:
+        lhs_tensor = te.compute(lhs.shape, lambda *i: te.subtract(lhs(*i), lhs_zp))
+
+    if is_constant(rhs):
+        rhs_tensor = rhs - rhs_zp
+    else:
+        rhs_tensor = te.compute(rhs.shape, lambda *i: te.subtract(rhs(*i), rhs_zp))
+
+    # Multiply with broadcasting.
+    mul = topi.multiply(lhs_tensor, rhs_tensor)
+
+    iscale = lhs_scale * rhs_scale
+    return qnn_requantize(mul, iscale, tvm.tir.const(0), output_scale, output_zp, out_dtype=odtype)
+
+
+def schedule_qnn_mul(outs):
+    """Schedule for qnn.mul
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of qnn.mul
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return default_schedule(outs)
+
+
+def qnn_tanh(data, input_scale, input_zp, output_scale, output_zp):
+    """Compute for qnn.tanh
+
+    Q_output = quantize(tanh(dequantize(data)))
+    """
+    dq_tensor = qnn_dequantize(data, input_scale, input_zp)
+    tanh = te.compute(dq_tensor.shape, lambda *i: te.tanh(dq_tensor(*i)))
+    return qnn_quantize(tanh, output_scale, output_zp, out_dtype=data.dtype)
+
+
+def schedule_qnn_tanh(outs):
+    """Schedule for qnn.tanh
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of qnn.tanh
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return default_schedule(outs)
 
 
 def qnn_concatenate(data, axis, out_dtype):
@@ -282,7 +389,7 @@ def qnn_concatenate(data, axis, out_dtype):
         i_zp = data[i + args_num * 2]
 
         # Requantize tensors and add them to the list.
-        args.append(requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype))
+        args.append(qnn_requantize(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype=out_dtype))
 
     # Call x86 implementation of concatenate.
     return concatenate(args, axis)
diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc
index d087d9fa7796..0e0d3fdbc0dd 100644
--- a/src/relay/qnn/op/add.cc
+++ b/src/relay/qnn/op/add.cc
@@ -96,7 +96,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 QNN_REGISTER_BINARY_OP("add")
     .describe("Elementwise add with broadcasting for quantized tensors.")
     .set_support_level(11)
-    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize)
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
 }  // namespace qnn
 }  // namespace relay
diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc
index 6dde61359df6..73c6eed44889 100644
--- a/src/relay/qnn/op/mul.cc
+++ b/src/relay/qnn/op/mul.cc
@@ -162,7 +162,8 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 QNN_REGISTER_BINARY_OP("mul")
     .describe("Elementwise mul with broadcasting for quantized tensors.")
     .set_support_level(11)
-    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize)
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
 }  // namespace qnn
 }  // namespace relay
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 1614652719c6..336daa2fcc4c 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -384,6 +384,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  // Check output scale validity.
+  ICHECK_NE(GetScalarFromConstant<float>(output_scale), 0.0)
+      << "QNN requantize output scale can not be equal to 0.0";
   // Check rounding validity.
ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") << "QNN requantize supports two rounding modes - UPWARD and " diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc index 181501922086..962a3434cb72 100644 --- a/src/relay/qnn/op/subtract.cc +++ b/src/relay/qnn/op/subtract.cc @@ -97,7 +97,8 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array& new_args, QNN_REGISTER_BINARY_OP("subtract") .describe("Elementwise subtract with broadcasting for quantized tensors.") .set_support_level(11) - .set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); + .set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize) + .set_attr("TOpPattern", kBroadcast); } // namespace qnn } // namespace relay diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py index e4edf2919a00..06e738d9b70e 100644 --- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -51,13 +51,33 @@ def test_no_qnn_pass(): assert "qnn.dequantize" in opt_mod_2.astext(show_meta_data=False) -def execute(executor, data_np, weight_np, bias_np=None): - executor.set_input("data", data_np) - executor.set_input("weight", weight_np) - if bias_np is not None: - executor.set_input("bias", bias_np) - executor.run() - return executor.get_output(0) +def execute(mod_executor, inputs: dict): + for input_name, input_data in inputs.items(): + mod_executor.set_input(input_name, input_data) + mod_executor.run() + return mod_executor.get_output(0).numpy() + + +def build_hexagon_module(mod): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): + hexagon_lowered = tvm.relay.build( + mod, + tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET), + executor=Executor("aot"), + ) + + return hexagon_lowered + + +def build_ref_module(mod): + target_llvm = tvm.target.Target("llvm") + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + mod, + tvm.target.Target(target_llvm, host=target_llvm), + executor=Executor("aot"), + ) + return llvm_lowered @tvm.testing.requires_hexagon @@ -90,33 +110,24 @@ def test_qnn_conv2d_rq(hexagon_session: Session): ) relay_mod = tvm.IRModule.from_expr(op5) - target_llvm = tvm.target.Target("llvm") - executor = Executor("aot") - with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): - hexagon_lowered = tvm.relay.build( - relay_mod, - tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET), - executor=executor, - ) + # Compile for Hexagon + hexagon_lowered = build_hexagon_module(relay_mod) - with tvm.transform.PassContext(opt_level=3): - llvm_lowered = tvm.relay.build( - relay_mod, - tvm.target.Target(target_llvm, host=target_llvm), - executor=executor, - ) + # Reference compilation + llvm_lowered = build_ref_module(relay_mod) data_np = np.random.rand(*data_shape) - 0.5 weight_np = np.random.rand(*weight_shape) - 0.5 + inputs = {"data": data_np, "weight": weight_np} hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) - hexagon_output = execute(hx_m, data_np, weight_np) + hexagon_output = execute(hx_m, inputs) dev = tvm.cpu(0) llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev)) - llvm_out = execute(llvm_m, data_np, weight_np) + llvm_out = execute(llvm_m, inputs) - np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy()) + np.testing.assert_equal(hexagon_output, llvm_out) 
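# --- Editorial note (not part of the patch) ----------------------------------
# A minimal NumPy sketch of the fixed-point math that compute_qnn_binary_op
# above implements for qnn.add/qnn.subtract, ignoring rounding-mode details
# (TE's round rounds half away from zero, np.round rounds half to even).
# The helper name and the sample scales, zero points and dtype below are
# illustrative assumptions only.
import numpy as np


def qnn_binary_op_ref(
    lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, out_scale, out_zp, op, dtype="uint8"
):
    # Rescale each input into the output quantization domain and round to int32.
    lhs_q = np.round(lhs_scale / out_scale * (lhs.astype("int32") - lhs_zp)).astype("int32")
    rhs_q = np.round(rhs_scale / out_scale * (rhs.astype("int32") - rhs_zp)).astype("int32")
    # Apply the integer binary op, add the output zero point, then saturate to dtype.
    res = op(lhs_q, rhs_q) + out_zp
    info = np.iinfo(dtype)
    return np.clip(res, info.min, info.max).astype(dtype)


# Example: mirrors the quantization parameters used in TestQnnBinaryOp below.
lhs = np.array([10, 128, 200], dtype="uint8")
rhs = np.array([3, 50, 250], dtype="uint8")
print(qnn_binary_op_ref(lhs, rhs, 0.041, 1, 0.017, 3, 0.039, 2, np.add))
# ------------------------------------------------------------------------------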
@tvm.testing.requires_hexagon @@ -152,34 +163,119 @@ def test_qnn_dense_bias_rq(hexagon_session: Session): ) relay_mod = tvm.IRModule.from_expr(op5) - target_llvm = tvm.target.Target("llvm") - executor = Executor("aot") - with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): - hexagon_lowered = tvm.relay.build( - relay_mod, - tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET), - executor=executor, - ) + # Compile for Hexagon + hexagon_lowered = build_hexagon_module(relay_mod) - with tvm.transform.PassContext(opt_level=3): - llvm_lowered = tvm.relay.build( - relay_mod, - tvm.target.Target(target_llvm, host=target_llvm), - executor=executor, - ) + # Reference compilation + llvm_lowered = build_ref_module(relay_mod) data_np = np.random.rand(*data_shape) - 0.5 weight_np = np.random.rand(*weight_shape) - 0.5 bias_np = np.random.rand(*bias_shape) + inputs = {"data": data_np, "weight": weight_np, "bias": bias_np} hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) - hexagon_output = execute(hx_m, data_np, weight_np, bias_np) + hexagon_output = execute(hx_m, inputs) dev = tvm.cpu(0) llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev)) - llvm_out = execute(llvm_m, data_np, weight_np, bias_np) + llvm_out = execute(llvm_m, inputs) + + np.testing.assert_equal(hexagon_output, llvm_out) + + +class TestQnnBinaryOp: + """QNN binary op test class""" + + operation = tvm.testing.parameter( + relay.qnn.op.add, + relay.qnn.op.subtract, + relay.qnn.op.mul, + ) + dtype = tvm.testing.parameter("uint8", "int8") + input_shape = tvm.testing.parameter([256], [4, 256]) + + @tvm.testing.requires_hexagon + def test_qnn_binary_op_broadcasting( + self, hexagon_session: Session, operation, dtype, input_shape + ): + """qnn binary op test without QNN canonicalization.""" + lhs_shape = [4, 256] + rhs_shape = input_shape + lhs = relay.var("lhs", shape=lhs_shape, dtype=dtype) + rhs = relay.var("rhs", shape=rhs_shape, dtype=dtype) + zp_const1 = 1 + zp_const2 = 3 + + op = operation( + lhs, + rhs, + lhs_scale=relay.const(0.041, "float32"), + lhs_zero_point=relay.const(zp_const1, "int32"), + rhs_scale=relay.const(0.017, "float32"), + rhs_zero_point=relay.const(zp_const2, "int32"), + output_scale=relay.const(0.039, "float32"), + output_zero_point=relay.const(2, "int32"), + ) + mod = tvm.IRModule.from_expr(op) + + # Compile for Hexagon + hexagon_lowered = build_hexagon_module(mod) + + # Reference compilation + llvm_lowered = build_ref_module(mod) + + lhs_np = np.random.randint(np.iinfo(dtype).min + zp_const1, np.iinfo(dtype).max, lhs_shape) + rhs_np = np.random.randint(np.iinfo(dtype).min + zp_const2, np.iinfo(dtype).max, rhs_shape) + inputs = {"lhs": lhs_np, "rhs": rhs_np} + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, inputs) + + dev = tvm.cpu(0) + llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev)) + llvm_output = execute(llvm_m, inputs) + + # Diff by 1 is Ok. 
+ tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1) + + @tvm.testing.requires_hexagon + def test_qnn_binary_op_scalar(self, hexagon_session: Session, operation): + """qnn binary op test without QNN canonicalization.""" + lhs_shape = [4, 256] + lhs = relay.var("lhs", shape=lhs_shape, dtype="uint8") + rhs = relay.const(11, dtype="uint8") + + op = operation( + lhs, + rhs, + lhs_scale=relay.const(0.049, "float32"), + lhs_zero_point=relay.const(1, "int32"), + rhs_scale=relay.const(0.067, "float32"), + rhs_zero_point=relay.const(3, "int32"), + output_scale=relay.const(0.041, "float32"), + output_zero_point=relay.const(2, "int32"), + ) + mod = tvm.IRModule.from_expr(op) + + # Compile for Hexagon + hexagon_lowered = build_hexagon_module(mod) + + # Reference compilation + llvm_lowered = build_ref_module(mod) + + lhs_np = np.random.randint(1, 255, size=lhs_shape) + inputs = {"lhs": lhs_np} + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, inputs) + + dev = tvm.cpu(0) + llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev)) + llvm_output = execute(llvm_m, inputs) - np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy()) + # Diff by 1 is Ok. + tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1) if __name__ == "__main__":