
[Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract #13416

Merged: 1 commit, Nov 18, 2022
11 changes: 10 additions & 1 deletion python/tvm/relay/qnn/op/_qnn.py
@@ -66,7 +66,16 @@ def simulated_dequantize_compute(attrs, inputs, output_type):

# qnn.add
register_strategy("qnn.add", strategy.qnn_add_strategy)
register_pattern("qnn.add", OpPattern.BROADCAST)

# qnn.subtract
register_strategy("qnn.subtract", strategy.qnn_subtract_strategy)

# qnn.mul
register_strategy("qnn.mul", strategy.qnn_mul_strategy)

# qnn.tanh
register_strategy("qnn.tanh", strategy.qnn_tanh_strategy)
register_pattern("qnn.tanh", OpPattern.ELEMWISE)

# qnn.concatenate
register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy)
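For context, a minimal sketch of how one of the newly registered ops can be expressed in Relay (shapes, dtypes, and quantization parameters below are illustrative, not taken from this PR):

import tvm
from tvm import relay

lhs = relay.var("lhs", shape=(1, 8), dtype="int8")
rhs = relay.var("rhs", shape=(1, 8), dtype="int8")
out = relay.qnn.op.mul(
    lhs,
    rhs,
    lhs_scale=relay.const(0.05, "float32"),
    lhs_zero_point=relay.const(0, "int32"),
    rhs_scale=relay.const(0.1, "float32"),
    rhs_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.2, "float32"),
    output_zero_point=relay.const(0, "int32"),
)
mod = tvm.IRModule.from_expr(relay.Function([lhs, rhs], out))

With the strategy registrations above, such a module can be lowered directly on Hexagon; on other targets it still goes through the canonicalization path described in the next file.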
27 changes: 27 additions & 0 deletions python/tvm/relay/qnn/strategy/generic.py
@@ -213,6 +213,33 @@ def qnn_add_strategy(attrs, inputs, out_type, target):
    )


@override_native_generic_func("qnn_subtract_strategy")
def qnn_subtract_strategy(attrs, inputs, out_type, target):
    """qnn.subtract generic strategy"""
    raise RuntimeError(
        "qnn.subtract is currently only supported with Hexagon. "
        "Please run QNN Canonicalize pass to decompose this op into supported ops."
    )


@override_native_generic_func("qnn_mul_strategy")
def qnn_mul_strategy(attrs, inputs, out_type, target):
    """qnn.mul generic strategy"""
    raise RuntimeError(
        "qnn.mul is currently only supported with Hexagon. "
        "Please run QNN Canonicalize pass to decompose this op into supported ops."
    )


@override_native_generic_func("qnn_tanh_strategy")
def qnn_tanh_strategy(attrs, inputs, out_type, target):
    """qnn.tanh generic strategy"""
    raise RuntimeError(
        "qnn.tanh is currently only supported with Hexagon. "
        "Please run QNN Canonicalize pass to decompose this op into supported ops."
    )


@override_native_generic_func("qnn_concatenate_strategy")
def qnn_concatenate_strategy(attrs, inputs, out_type, target):
    """qnn.concatenate generic strategy"""
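A minimal sketch of the fallback path these errors point to, assuming a Relay module `mod` that contains qnn.mul/subtract/tanh (as in the earlier sketch):

from tvm import relay
from tvm.relay.qnn.transform import CanonicalizeOps

# On non-Hexagon targets the generic strategies above raise, so the QNN ops are
# decomposed into plain Relay ops first and the regular schedules take over.
mod = relay.transform.InferType()(mod)
mod = CanonicalizeOps()(mod)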
36 changes: 36 additions & 0 deletions python/tvm/relay/qnn/strategy/hexagon.py
@@ -71,6 +71,42 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target):
    return strategy


@qnn_subtract_strategy.register("hexagon")
def qnn_subtract_strategy_hexagon(attrs, inputs, out_type, target):
    """qnn.subtract strategy for Hexagon"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_topi_compute(topi.hexagon.qnn_subtract),
        wrap_topi_schedule(topi.hexagon.schedule_qnn_subtract),
        name="qnn_subtract.hexagon",
    )
    return strategy


@qnn_mul_strategy.register("hexagon")
def qnn_mul_strategy_hexagon(attrs, inputs, out_type, target):
    """qnn.mul strategy for Hexagon"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_topi_compute(topi.hexagon.qnn_mul),
        wrap_topi_schedule(topi.hexagon.schedule_qnn_mul),
        name="qnn_mul.hexagon",
    )
    return strategy


@qnn_tanh_strategy.register("hexagon")
def qnn_tanh_strategy_hexagon(attrs, inputs, out_type, target):
    """qnn.tanh strategy for Hexagon"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_topi_compute(topi.hexagon.qnn_tanh),
        wrap_topi_schedule(topi.hexagon.schedule_qnn_tanh),
        name="qnn_tanh.hexagon",
    )
    return strategy


@qnn_concatenate_strategy.register("hexagon")
def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target):
    """qnn.concatenate strategy for Hexagon"""
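A hypothetical end-to-end sketch (it assumes a Hexagon-enabled TVM build and reuses `mod` from the qnn.mul sketch above; the CPU version string is illustrative): with these registrations, relay.build selects the qnn_*.hexagon implementations instead of hitting the generic RuntimeError, so no canonicalization is required.

import tvm
from tvm import relay

target = tvm.target.Target(tvm.target.hexagon("v68"), host=tvm.target.hexagon("v68"))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)  # picks qnn_mul.hexagon for the qnn.mul op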
179 changes: 143 additions & 36 deletions python/tvm/topi/hexagon/qnn/nn.py
@@ -19,6 +19,7 @@

import tvm
from tvm import te, topi
from ..utils import saturate
from ...utils import get_const_tuple
from ...nn.utils import get_pad_tuple
from ...nn.pad import pad
@@ -33,6 +34,11 @@ def clip_cast(val, dtype):
    return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)


# Return True if given Tensor is scalar constant value.
def is_constant(tensor: te.Tensor):
    return tensor.ndim == 0


def get_qnn_param(param, indices, axis):
    # Account scalar and 1D quantization parameters:
    if len(param.shape) == 0:
@@ -62,7 +68,7 @@ def default_schedule(outs):
    return s


def qnn_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
    """Compute for qnn.quantize

    Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
@@ -101,7 +107,7 @@ def schedule_qnn_quantize(outs):
    return default_schedule(outs)


def qnn_dequantize(data, input_scale, input_zero_point, axis=-1):
    """Compute for qnn.dequantize

    fp_output = input_scale * (Q_input - input_zero_point)
@@ -134,7 +140,7 @@ def schedule_qnn_dequantize(outs):
    return default_schedule(outs)


def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis=-1, out_dtype="int8"):
    """Compute for qnn.requantize

    Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input))
@@ -177,37 +183,58 @@ def schedule_qnn_requantize(outs):
return default_schedule(outs)


def compute_qnn_binary_op(
    lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, func
):
    """Compute for QNN binary operation

    Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input - lhs_zp))
               _OP_ round((rhs_scale)/(output_scale) * (rhs_input - rhs_zp))
    where _OP_ is add/subtract
    """
    assert lhs.dtype == rhs.dtype
    dtype = lhs.dtype

    def _compute_const(x: te.Tensor, iscale, input_zp):
        return te.round(te.multiply(te.div(iscale, output_scale), te.subtract(x, input_zp))).astype(
            "int32"
        )

    def _compute_tensor(x: te.Tensor, iscale, input_zp):
        return te.compute(
            x.shape,
            lambda *i: te.round(
                te.multiply(te.div(iscale, output_scale), te.subtract(x(*i), input_zp))
            ).astype("int32"),
        )

    if is_constant(lhs):
        lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp)
    else:
        lhs_tensor = _compute_tensor(lhs, lhs_scale, lhs_zp)

    if is_constant(rhs):
        rhs_tensor = _compute_const(rhs, rhs_scale, rhs_zp)
    else:
        rhs_tensor = _compute_tensor(rhs, rhs_scale, rhs_zp)

    # Binary op with broadcasting
    tensor = func(lhs_tensor, rhs_tensor)

    # Add output zero point and clip+cast.
    def _compute(*indices):
        return saturate(te.add(tensor(*indices), output_zp), dtype).astype(dtype)

    return te.compute(tensor.shape, _compute)


def qnn_add(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
    """Compute for qnn.add
    TODO: support 'axis' argument.
    """
    return compute_qnn_binary_op(
        lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.add
    )
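As a sanity check of the formula above, here is the arithmetic traced by hand for qnn.add with illustrative parameters (all numbers are made up for this example):

# lhs_scale=0.5, rhs_scale=0.25, output_scale=1.0, all zero points 0,
# int8 inputs lhs=10 (real 5.0) and rhs=8 (real 2.0).
q_lhs = round(0.5 / 1.0 * (10 - 0))  # 5
q_rhs = round(0.25 / 1.0 * (8 - 0))  # 2
q_out = q_lhs + q_rhs + 0            # 7, i.e. 7.0 at output_scale, then saturated to int8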


def schedule_qnn_add(outs):
Expand All @@ -227,19 +254,99 @@ def schedule_qnn_add(outs):
return default_schedule(outs)


def qnn_subtract(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
    """Compute for qnn.subtract"""

    return compute_qnn_binary_op(
        lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.subtract
    )
def schedule_qnn_subtract(outs):
    """Schedule for qnn.subtract

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of qnn.subtract
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return default_schedule(outs)


def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
    """Compute for qnn.mul

    mul = (lhs_input - lhs_zp) * (rhs_input - rhs_zp)
    Q_output = requantize(mul, lhs_scale * rhs_scale, 0, output_scale, output_zp)
    """
    assert lhs.dtype == rhs.dtype
    odtype = lhs.dtype

    if is_constant(lhs):
        lhs_tensor = lhs - lhs_zp
    else:
        lhs_tensor = te.compute(lhs.shape, lambda *i: te.subtract(lhs(*i), lhs_zp))

    if is_constant(rhs):
        rhs_tensor = rhs - rhs_zp
    else:
        rhs_tensor = te.compute(rhs.shape, lambda *i: te.subtract(rhs(*i), rhs_zp))

    # Multiply with broadcasting.
    mul = topi.multiply(lhs_tensor, rhs_tensor)

    iscale = lhs_scale * rhs_scale
    return qnn_requantize(mul, iscale, tvm.tir.const(0), output_scale, output_zp, out_dtype=odtype)


def schedule_qnn_mul(outs):
    """Schedule for qnn.mul

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of qnn.mul
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return default_schedule(outs)
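The same kind of hand trace for the qnn_mul compute above, again with made-up quantization parameters:

# lhs_scale=0.1, lhs_zp=4, rhs_scale=0.2, rhs_zp=2, output_scale=0.05, output_zp=0,
# int8 inputs lhs=24 (real 2.0) and rhs=12 (real 2.0); the real product is 4.0.
mul = (24 - 4) * (12 - 2)                   # 200, expressed at scale lhs_scale*rhs_scale = 0.02
q_out = round(0.02 / 0.05 * (200 - 0)) + 0  # 80, i.e. 4.0 at output_scale 0.05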


def qnn_tanh(data, input_scale, input_zp, output_scale, output_zp):
    """Compute for qnn.tanh

    Q_output = quantize(tanh(dequantize(data)))
    """
    dq_tensor = qnn_dequantize(data, input_scale, input_zp)
    tanh = te.compute(dq_tensor.shape, lambda *i: te.tanh(dq_tensor(*i)))
    return qnn_quantize(tanh, output_scale, output_zp, out_dtype=data.dtype)


def schedule_qnn_tanh(outs):
    """Schedule for qnn.tanh

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of qnn.tanh
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return default_schedule(outs)
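And a hand trace of the dequantize, tanh, quantize composition in qnn_tanh above (illustrative numbers):

import math

# input_scale=1/64, input_zp=0, output_scale=1/128, output_zp=0, int8 input data=64.
real_in = (64 - 0) / 64.0          # 1.0
real_out = math.tanh(real_in)      # ~0.7616
q_out = round(real_out * 128) + 0  # 97, within the int8 range so no clipping occurs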


def qnn_concatenate(data, axis, out_dtype):
@@ -282,7 +389,7 @@ def qnn_concatenate(data, axis, out_dtype):
        i_zp = data[i + args_num * 2]

        # Requantize tensors and add them to the list.
        args.append(qnn_requantize(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype=out_dtype))

    # Call x86 implementation of concatenate.
    return concatenate(args, axis)
3 changes: 2 additions & 1 deletion src/relay/qnn/op/add.cc
@@ -96,7 +96,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("add")
    .describe("Elementwise add with broadcasting for quantized tensors.")
    .set_support_level(11)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize)
    .set_attr<TOpPattern>("TOpPattern", kBroadcast);

}  // namespace qnn
}  // namespace relay
3 changes: 2 additions & 1 deletion src/relay/qnn/op/mul.cc
@@ -162,7 +162,8 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("mul")
    .describe("Elementwise mul with broadcasting for quantized tensors.")
    .set_support_level(11)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize)
    .set_attr<TOpPattern>("TOpPattern", kBroadcast);

}  // namespace qnn
}  // namespace relay
3 changes: 3 additions & 0 deletions src/relay/qnn/op/requantize.cc
@@ -384,6 +384,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
  // Check output scale validity.
  ICHECK_NE(GetScalarFromConstant<float>(output_scale), 0.0)
      << "QNN requantize output scale can not be equal to 0.0";
  // Check rounding validity.
  ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
      << "QNN requantize supports two rounding modes - UPWARD and "
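The new check guards the division by output_scale that requantize performs; a small sketch of the quantity it protects (simplified, not the actual fixed-point lowering in RequantizeLower):

def requantize_multiplier(input_scale, output_scale):
    # Q_output = output_zp + round(multiplier * (Q_input - input_zp));
    # an output_scale of 0.0 would make this multiplier undefined, which is
    # why the lowering now rejects it up front instead of dividing by zero.
    return input_scale / output_scale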
3 changes: 2 additions & 1 deletion src/relay/qnn/op/subtract.cc
@@ -97,7 +97,8 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("subtract")
    .describe("Elementwise subtract with broadcasting for quantized tensors.")
    .set_support_level(11)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize)
    .set_attr<TOpPattern>("TOpPattern", kBroadcast);

}  // namespace qnn
}  // namespace relay