[Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract
This commit adds compute/schedule implementations for the Hexagon target for
the QNN ops qnn.mul, qnn.subtract, and qnn.tanh. These ops are lowered this way
only if the QNN canonicalization pass is disabled.
ibsidorenko committed Nov 17, 2022
1 parent f9ed60a commit f8743c4
Showing 9 changed files with 361 additions and 81 deletions.
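
The note above that these ops are handled only when QNN canonicalization is disabled can be illustrated with a minimal sketch (not part of this commit). It assumes a Hexagon-enabled TVM build; the pass name passed to disabled_pass and all quantization parameters are illustrative assumptions rather than values taken from this change.

```python
# Hypothetical sketch: build a single qnn.mul for Hexagon while keeping the op
# in QNN form, so it is lowered through the Hexagon TOPI strategy instead of
# being decomposed by canonicalization. Pass name and parameters are assumed.
import tvm
from tvm import relay

lhs = relay.var("lhs", shape=(1, 8), dtype="int8")
rhs = relay.var("rhs", shape=(1, 8), dtype="int8")
out = relay.qnn.op.mul(
    lhs,
    rhs,
    lhs_scale=relay.const(0.05, "float32"),
    lhs_zero_point=relay.const(0, "int32"),
    rhs_scale=relay.const(0.05, "float32"),
    rhs_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.1, "float32"),
    output_zero_point=relay.const(0, "int32"),
)
mod = tvm.IRModule.from_expr(relay.Function([lhs, rhs], out))

hexagon_target = tvm.target.hexagon("v68")
with tvm.transform.PassContext(opt_level=3, disabled_pass=["QnnCanonicalize"]):
    # With canonicalization disabled, qnn.mul reaches the new Hexagon strategy.
    lib = relay.build(mod, target=tvm.target.Target(hexagon_target, host=hexagon_target))
```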
11 changes: 10 additions & 1 deletion python/tvm/relay/qnn/op/_qnn.py
@@ -66,7 +66,16 @@ def simulated_dequantize_compute(attrs, inputs, output_type):

# qnn.add
register_strategy("qnn.add", strategy.qnn_add_strategy)
register_pattern("qnn.add", OpPattern.BROADCAST)

# qnn.subtract
register_strategy("qnn.subtract", strategy.qnn_subtract_strategy)

# qnn.mul
register_strategy("qnn.mul", strategy.qnn_mul_strategy)

# qnn.tanh
register_strategy("qnn.tanh", strategy.qnn_tanh_strategy)
register_pattern("qnn.tanh", OpPattern.ELEMWISE)

# qnn.concatenate
register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy)
27 changes: 27 additions & 0 deletions python/tvm/relay/qnn/strategy/generic.py
@@ -213,6 +213,33 @@ def qnn_add_strategy(attrs, inputs, out_type, target):
)


@override_native_generic_func("qnn_subtract_strategy")
def qnn_subtract_strategy(attrs, inputs, out_type, target):
"""qnn.subtract generic strategy"""
raise RuntimeError(
"qnn.subtract is currently only supported with Hexagon. "
"Please run QNN Canonicalize pass to decompose this op into supported ops."
)


@override_native_generic_func("qnn_mul_strategy")
def qnn_mul_strategy(attrs, inputs, out_type, target):
"""qnn.mul generic strategy"""
raise RuntimeError(
"qnn.mul is currently only supported with Hexagon. "
"Please run QNN Canonicalize pass to decompose this op into supported ops."
)


@override_native_generic_func("qnn_tanh_strategy")
def qnn_tanh_strategy(attrs, inputs, out_type, target):
"""qnn.tanh generic strategy"""
raise RuntimeError(
"qnn.tanh is currently only supported with Hexagon. "
"Please run QNN Canonicalize pass to decompose this op into supported ops."
)


@override_native_generic_func("qnn_concatenate_strategy")
def qnn_concatenate_strategy(attrs, inputs, out_type, target):
"""qnn.concatenate generic strategy"""
36 changes: 36 additions & 0 deletions python/tvm/relay/qnn/strategy/hexagon.py
@@ -71,6 +71,42 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target):
return strategy


@qnn_subtract_strategy.register("hexagon")
def qnn_subtract_strategy_hexagon(attrs, inputs, out_type, target):
"""qnn.subtract strategy for Hexagon"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_topi_compute(topi.hexagon.qnn_subtract),
wrap_topi_schedule(topi.hexagon.schedule_qnn_subtract),
name="qnn_subtract.hexagon",
)
return strategy


@qnn_mul_strategy.register("hexagon")
def qnn_mul_strategy_hexagon(attrs, inputs, out_type, target):
"""qnn.mul strategy for Hexagon"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_topi_compute(topi.hexagon.qnn_mul),
wrap_topi_schedule(topi.hexagon.schedule_qnn_mul),
name="qnn_mul.hexagon",
)
return strategy


@qnn_tanh_strategy.register("hexagon")
def qnn_tanh_strategy_hexagon(attrs, inputs, out_type, target):
"""qnn.tanh strategy for Hexagon"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_topi_compute(topi.hexagon.qnn_tanh),
wrap_topi_schedule(topi.hexagon.schedule_qnn_tanh),
name="qnn_tanh.hexagon",
)
return strategy


@qnn_concatenate_strategy.register("hexagon")
def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target):
"""qnn.concatenate strategy for Hexagon"""
179 changes: 143 additions & 36 deletions python/tvm/topi/hexagon/qnn/nn.py
@@ -19,6 +19,7 @@

import tvm
from tvm import te, topi
from ..utils import saturate
from ...utils import get_const_tuple
from ...nn.utils import get_pad_tuple
from ...nn.pad import pad
@@ -33,6 +34,11 @@ def clip_cast(val, dtype):
return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)


# Return True if given Tensor is scalar constant value.
def is_constant(tensor: te.Tensor):
return tensor.ndim == 0


def get_qnn_param(param, indices, axis):
# Account scalar and 1D quantization parameters:
if len(param.shape) == 0:
@@ -62,7 +68,7 @@ def default_schedule(outs):
return s


def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype):
def qnn_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
"""Compute for qnn.quantize
Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
@@ -101,7 +107,7 @@ def schedule_qnn_quantize(outs):
return default_schedule(outs)


def qnn_dequantize(data, input_scale, input_zero_point, axis):
def qnn_dequantize(data, input_scale, input_zero_point, axis=-1):
"""Compute for qnn.dequantize
fp_output = input_scale * (Q_input - input_zero_point)
@@ -134,7 +140,7 @@ def schedule_qnn_dequantize(outs):
return default_schedule(outs)


def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype):
def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis=-1, out_dtype="int8"):
"""Compute for qnn.requantize
Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input))
@@ -177,37 +183,58 @@ def schedule_qnn_requantize(outs):
return default_schedule(outs)


def qnn_add(
lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
def compute_qnn_binary_op(
lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, func
):
"""Compute for qnn.add
"""Compute for QNN binary operation
Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input))
+ round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input))
TODO: support 'axis' argument.
Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input - lhs_zp))
_OP_ round((rhs_scale)/(output_scale) * (rhs_input - rhs_zp))
where _OP_ is add/subtract
"""

assert lhs.dtype == rhs.dtype
dtype = lhs.dtype

def _compute_const(x: te.Tensor, iscale, input_zp):
return te.round(te.multiply(te.div(iscale, output_scale), te.subtract(x, input_zp))).astype(
"int32"
)

def _compute_tensor(x: te.Tensor, iscale, input_zp):
return te.compute(
x.shape,
lambda *i: te.round(
te.multiply(te.div(iscale, output_scale), te.subtract(x(*i), input_zp))
).astype("int32"),
)

if is_constant(lhs):
lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp)
else:
lhs_tensor = _compute_tensor(lhs, lhs_scale, lhs_zp)

if is_constant(rhs):
rhs_tensor = _compute_const(rhs, rhs_scale, rhs_zp)
else:
rhs_tensor = _compute_tensor(rhs, rhs_scale, rhs_zp)

# Binary op with broadcasting
tensor = func(lhs_tensor, rhs_tensor)

# Add output zero point and clip+cast.
def _compute(*indices):
lvalue = lhs(*indices)
rvalue = rhs(*indices)
q_lv = te.round(
te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point))
).astype("int32")
q_rv = te.round(
te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point))
).astype("int32")
val = te.add(te.add(q_lv, q_rv), output_zero_point)
return saturate(te.add(tensor(*indices), output_zp), dtype).astype(dtype)

return te.compute(tensor.shape, _compute)

# clip + cast:
const_min = tvm.tir.min_value(dtype)
const_max = tvm.tir.max_value(dtype)
return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)

return te.compute(lhs.shape, _compute)
def qnn_add(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
"""Compute for qnn.add
TODO: support 'axis' argument.
"""
return compute_qnn_binary_op(
lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.add
)
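
Not part of the diff: a NumPy sketch of the requantize-then-combine arithmetic that the new compute_qnn_binary_op / qnn_add above express with TE. The quantization parameters are made up for illustration.

```python
import numpy as np

def qnn_binary_ref(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, out_scale, out_zp, op):
    # Rescale both int8 inputs into the output quantization domain (int32).
    q_l = np.round(lhs_scale / out_scale * (lhs.astype(np.int32) - lhs_zp)).astype(np.int32)
    q_r = np.round(rhs_scale / out_scale * (rhs.astype(np.int32) - rhs_zp)).astype(np.int32)
    # Combine, add the output zero point, then saturate back to int8.
    acc = op(q_l, q_r) + out_zp
    return np.clip(acc, -128, 127).astype(np.int8)

a = np.array([10, -5], dtype=np.int8)
b = np.array([4, 4], dtype=np.int8)
print(qnn_binary_ref(a, b, 0.05, 0, 0.05, 0, 0.1, 0, np.add))       # [7 0]
print(qnn_binary_ref(a, b, 0.05, 0, 0.05, 0, 0.1, 0, np.subtract))  # [3 -4]
```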


def schedule_qnn_add(outs):
@@ -227,19 +254,99 @@ def schedule_qnn_add(outs):
return default_schedule(outs)


def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype):
"""Requantize tensor"""
def qnn_subtract(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
"""Compute for qnn.subtract"""

def _compute(*indices):
value = tensor(*indices)
mul_value = te.round(
te.multiply(te.div(i_scale, o_scale), te.subtract(value, i_zp))
).astype("int32")
rq_value = te.add(mul_value, o_zp)
return compute_qnn_binary_op(
lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.subtract
)

return clip_cast(rq_value, out_dtype)

return te.compute(tensor.shape, _compute)
def schedule_qnn_subtract(outs):
"""Schedule for qnn.subtract
Parameters
----------
outs: Array of Tensor
The computation graph description of qnn.subtract
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return default_schedule(outs)


def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
"""Compute for qnn.mul
mul = (lhs_input - lhs_zp) * (rhs_input - rhs_zp)
Q_output = requantize(mul, lhs_scale * rhs_scale, 0, output_scale, output_zp)
"""
assert lhs.dtype == rhs.dtype
odtype = lhs.dtype

if is_constant(lhs):
lhs_tensor = lhs - lhs_zp
else:
lhs_tensor = te.compute(lhs.shape, lambda *i: te.subtract(lhs(*i), lhs_zp))

if is_constant(rhs):
rhs_tensor = rhs - rhs_zp
else:
rhs_tensor = te.compute(rhs.shape, lambda *i: te.subtract(rhs(*i), rhs_zp))

# Multiply with broadcasting.
mul = topi.multiply(lhs_tensor, rhs_tensor)

iscale = lhs_scale * rhs_scale
return qnn_requantize(mul, iscale, tvm.tir.const(0), output_scale, output_zp, out_dtype=odtype)
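
Not part of the diff: the qnn.mul arithmetic above worked through on NumPy scalars, with made-up quantization parameters.

```python
import numpy as np

lhs, rhs = np.int32(20), np.int32(-10)       # int8 inputs, widened for the math
lhs_scale, lhs_zp = 0.02, 0                  # lhs real value: 0.02 * 20 = 0.4
rhs_scale, rhs_zp = 0.05, 5                  # rhs real value: 0.05 * (-10 - 5) = -0.75
out_scale, out_zp = 0.01, 0

prod = (lhs - lhs_zp) * (rhs - rhs_zp)       # integer product at scale lhs_scale * rhs_scale
q_out = np.round(lhs_scale * rhs_scale / out_scale * prod) + out_zp
print(np.clip(q_out, -128, 127).astype(np.int8))  # -30, i.e. 0.4 * -0.75 = -0.3 at scale 0.01
```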


def schedule_qnn_mul(outs):
"""Schedule for qnn.mul
Parameters
----------
outs: Array of Tensor
The computation graph description of qnn.mul
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return default_schedule(outs)


def qnn_tanh(data, input_scale, input_zp, output_scale, output_zp):
"""Compute for qnn.tanh
Q_output = quantize(tanh(dequantize(data)))
"""
dq_tensor = qnn_dequantize(data, input_scale, input_zp)
tanh = te.compute(dq_tensor.shape, lambda *i: te.tanh(dq_tensor(*i)))
return qnn_quantize(tanh, output_scale, output_zp, out_dtype=data.dtype)
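
Not part of the diff: the dequantize -> tanh -> quantize composition above on a NumPy scalar, with made-up quantization parameters.

```python
import numpy as np

x, in_scale, in_zp = np.int8(64), 1.0 / 64, 0     # real input: 1.0
out_scale, out_zp = 1.0 / 128, 0

real = in_scale * (np.int32(x) - in_zp)           # dequantize -> 1.0
q = np.round(np.tanh(real) / out_scale) + out_zp  # tanh(1.0) ~ 0.7616, then quantize
print(np.clip(q, -128, 127).astype(np.int8))      # 97
```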


def schedule_qnn_tanh(outs):
"""Schedule for qnn.tanh
Parameters
----------
outs: Array of Tensor
The computation graph description of qnn.tanh
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return default_schedule(outs)


def qnn_concatenate(data, axis, out_dtype):
@@ -282,7 +389,7 @@ def qnn_concatenate(data, axis, out_dtype):
i_zp = data[i + args_num * 2]

# Requantize tensors and add them to the list.
args.append(requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype))
args.append(qnn_requantize(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype=out_dtype))

# Call x86 implementation of concatenate.
return concatenate(args, axis)
3 changes: 2 additions & 1 deletion src/relay/qnn/op/add.cc
@@ -96,7 +96,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("add")
.describe("Elementwise add with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

} // namespace qnn
} // namespace relay
3 changes: 2 additions & 1 deletion src/relay/qnn/op/mul.cc
@@ -162,7 +162,8 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("mul")
.describe("Elementwise mul with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

} // namespace qnn
} // namespace relay
3 changes: 3 additions & 0 deletions src/relay/qnn/op/requantize.cc
@@ -384,6 +384,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
// Check output scale validity.
ICHECK_NE(GetScalarFromConstant<float>(output_scale), 0.0)
<< "QNN requantize output scale can not be equal to 0.0";
// Check rounding validity.
ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
3 changes: 2 additions & 1 deletion src/relay/qnn/op/subtract.cc
@@ -97,7 +97,8 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("subtract")
.describe("Elementwise subtract with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

} // namespace qnn
} // namespace relay