[CMSIS-NN] Add int16 add and mul operator support #13920

Merged 1 commit on Feb 8, 2023
31 changes: 24 additions & 7 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -340,20 +340,37 @@ def check_qnn_binary_op(pattern):

arg0 = binary_op.args[0]
arg1 = binary_op.args[1]
both_args_scalar = False

# Check arguments are not scalar.
if (
isinstance(arg0, tvm.relay.expr.Constant)
and len(arg0.checked_type.shape) == 0
and isinstance(arg1, tvm.relay.expr.Constant)
and len(arg1.checked_type.shape) == 0
):
both_args_scalar = True
return False

return (
arg0.checked_type.dtype == "int8"
and arg1.checked_type.dtype == "int8"
and not both_args_scalar
)
arg0_type = arg0.checked_type.dtype
arg1_type = arg1.checked_type.dtype

# Check arguments are of valid type.
if arg0_type not in ["int8", "int16"]:
return False

# Check arguments are the same type.
if arg0_type != arg1_type:
return False

# Check that all zero points are zero (arm_elementwise_(add|mul)_s16 does not
# handle non-zero zero points).
if arg0_type == "int16" and str(binary_op.op.name) in ["qnn.add", "qnn.mul"]:
arg_0_zero_point = binary_op.args[3].data.numpy()
arg_1_zero_point = binary_op.args[5].data.numpy()
output_zero_point = binary_op.args[7].data.numpy()
if any([arg_0_zero_point, arg_1_zero_point, output_zero_point]):
return False

return True

return [
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=True), check_qnn_conv2d_pad),
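For context, a minimal sketch (not part of this PR) of what the new check means in practice: an int16 qnn.add whose zero points are all zero is offloaded to CMSIS-NN, while any non-zero zero point leaves the operator in the default Relay flow. The int16_add helper, shapes, and scales below are illustrative only.

import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

def int16_add(zero_point: int) -> tvm.IRModule:
    # Build a single int16 qnn.add with identical scales and the given zero point.
    x = relay.var("x", shape=(1, 16, 16, 3), dtype="int16")
    y = relay.var("y", shape=(1, 16, 16, 3), dtype="int16")
    scale = relay.const(0.0128, "float32")
    zp = relay.const(zero_point, "int32")
    out = relay.qnn.op.add(x, y, scale, zp, scale, zp, scale, zp)
    return tvm.IRModule.from_expr(relay.Function([x, y], out))

# All-zero zero points: the add is partitioned into an external CMSIS-NN function.
offloaded = cmsisnn.partition_for_cmsisnn(int16_add(0))
# Any non-zero zero point: check_qnn_binary_op returns False and nothing is partitioned.
fallback = cmsisnn.partition_for_cmsisnn(int16_add(33))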
104 changes: 40 additions & 64 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -112,6 +112,21 @@ class RelayToTIRVisitor : public MixedModeMutator {
ir_module_->Add(global_var, replacement_func);
}

auto GetIntMinMax(int bit_width) {
const int32_t min =
(bit_width == 8) ? std::numeric_limits<int8_t>::min() : std::numeric_limits<int16_t>::min();
const int32_t max =
(bit_width == 8) ? std::numeric_limits<int8_t>::max() : std::numeric_limits<int16_t>::max();
return std::pair(min, max);
}

auto GetClipMinMax(const ClipAttrs* clip_attrs) {
return std::pair(static_cast<int32_t>(clip_attrs->a_min),
static_cast<int32_t>(clip_attrs->a_max));
}

auto GetClipMinMax(const Call& clip_op) { return GetClipMinMax(clip_op->attrs.as<ClipAttrs>()); }

void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
const CallNode* clip_call = nullptr;
const CallNode* requantize_call = nullptr;
@@ -190,18 +205,8 @@
int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]);
int32_t out_channels = qnn::get_const_int(conv2d_attrs->channels);
std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
int32_t clip_min = std::numeric_limits<int8_t>::min();
int32_t clip_max = std::numeric_limits<int8_t>::max();

if (dtype_bits == 16) {
clip_min = std::numeric_limits<int16_t>::min();
clip_max = std::numeric_limits<int16_t>::max();
}
if (clip_call) {
const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
}
const auto [clip_min, clip_max] =
clip_call ? GetClipMinMax(GetRef<Call>(clip_call)) : GetIntMinMax(dtype_bits);

tvm::Array<PrimExpr> scalar_args = {ToArg(input_offset), ToArg(output_offset), ToArg(stride_w),
ToArg(stride_h), ToArg(padding_w), ToArg(padding_h),
@@ -348,20 +353,8 @@
float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]);
float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
int32_t out_channels = qnn::get_const_int(dense_attrs->units);
int32_t clip_min, clip_max;
if (clip_call) {
const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
if (dtype_bits == 8) {
clip_min = std::numeric_limits<int8_t>::min();
clip_max = std::numeric_limits<int8_t>::max();
} else {
clip_min = std::numeric_limits<int16_t>::min();
clip_max = std::numeric_limits<int16_t>::max();
}
}
const auto [clip_min, clip_max] =
clip_call ? GetClipMinMax(GetRef<Call>(clip_call)) : GetIntMinMax(dtype_bits);

double quantized_multiplier =
static_cast<double>(input_scale) / static_cast<double>(output_scale);
@@ -432,7 +425,6 @@

// prepare cmsis_nn_pool_params
int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
int32_t clip_min, clip_max;
std::string cmsisnn_api;
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
if (dtype_bits == 8) {
@@ -463,19 +455,9 @@
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
}
if (clip.defined()) {
const ClipAttrs* clip_attrs = clip->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
if (dtype_bits == 8) {
clip_min = std::numeric_limits<int8_t>::min();
clip_max = std::numeric_limits<int8_t>::max();
} else {
clip_min = std::numeric_limits<int16_t>::min();
clip_max = std::numeric_limits<int16_t>::max();
}
}

const auto [clip_min, clip_max] =
clip.defined() ? GetClipMinMax(clip) : GetIntMinMax(dtype_bits);

tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};
@@ -587,15 +569,11 @@
}

void EmitMul(const GlobalVar& global_var, const Expr& expr) {
int32_t output_min = std::numeric_limits<int8_t>::min();
int32_t output_max = std::numeric_limits<int8_t>::max();
const auto& pattern = ParseBinaryElementwiseOpClipPattern(expr);
Call mul_call = pattern.binary_op;
if (pattern.clip_op) {
const ClipAttrs* clip_attrs = pattern.clip_op.value()->attrs.as<ClipAttrs>();
output_min = clip_attrs->a_min;
output_max = clip_attrs->a_max;
}
const auto bit_width = mul_call->type_as<TensorTypeNode>()->dtype.bits();
const auto [output_min, output_max] =
pattern.clip_op ? GetClipMinMax(pattern.clip_op.value()) : GetIntMinMax(bit_width);

const float input_0_scale = GetScalarFromConstant<float>(mul_call->args[2]);
const int32_t input_0_zero_point = GetScalarFromConstant<int32_t>(mul_call->args[3]);
@@ -614,17 +592,17 @@
PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();

BufferCreator buffer_creator;
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(bit_width));
tir::Var input_1;
if (mul_call->args[0].same_as(mul_call->args[1])) {
input_1 = input_0;
} else {
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(bit_width));
}
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_mul_s8"),
tir::StringImm("arm_elementwise_mul_s" + std::to_string(bit_width)),
input_0,
input_1,
ToArg(-input_0_zero_point),
@@ -643,15 +621,12 @@
}

void EmitAdd(const GlobalVar& global_var, const Expr& expr) {
int32_t output_min = std::numeric_limits<int8_t>::min();
int32_t output_max = std::numeric_limits<int8_t>::max();
const auto& pattern = ParseBinaryElementwiseOpClipPattern(expr);
Call add_call = pattern.binary_op;
if (pattern.clip_op) {
const ClipAttrs* clip_attrs = pattern.clip_op.value()->attrs.as<ClipAttrs>();
output_min = clip_attrs->a_min;
output_max = clip_attrs->a_max;
}
const auto bit_width = add_call->type_as<TensorTypeNode>()->dtype.bits();

const auto [output_min, output_max] =
pattern.clip_op ? GetClipMinMax(pattern.clip_op.value()) : GetIntMinMax(bit_width);

const float input_0_scale = GetScalarFromConstant<float>(add_call->args[2]);
const int32_t input_0_zero_point = GetScalarFromConstant<int32_t>(add_call->args[3]);
@@ -660,9 +635,10 @@
const float output_scale = GetScalarFromConstant<float>(add_call->args[6]);
const int32_t output_zero_point = GetScalarFromConstant<int32_t>(add_call->args[7]);

const int32_t left_shift = 20;
const int32_t left_shift = (bit_width == 16) ? 15 : 20;
const int32_t input_0_offset = -input_0_zero_point;
const int32_t input_1_offset = -input_1_zero_point;
const int32_t output_offset = output_zero_point;

const float max_input_scale = std::max(input_0_scale, input_1_scale);
const double twice_max_input_scale = 2 * static_cast<double>(max_input_scale);
@@ -689,17 +665,17 @@
PrimExpr tensor_size = add_call->type_as<TensorTypeNode>()->Size();

BufferCreator buffer_creator;
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(bit_width));
tir::Var input_1;
if (add_call->args[0].same_as(add_call->args[1])) {
input_1 = input_0;
} else {
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(bit_width));
}
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_add_s8"),
tir::StringImm("arm_elementwise_add_s" + std::to_string(bit_width)),
input_0,
input_1,
ToArg(input_0_offset),
@@ -710,7 +686,7 @@
ToArg(input_1_shift),
ToArg(left_shift),
output,
ToArg(output_zero_point),
ToArg(output_offset),
ToArg(output_multiplier),
ToArg(output_shift),
ToArg(output_min),
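A note on the left_shift change in EmitAdd above: the CMSIS-NN elementwise-add kernels follow a TFLite-style rescaling in which both inputs are promoted by a fixed left shift (20 for int8, 15 for int16) before the per-input multipliers are applied. The Python sketch below is illustrative only and is not taken from this diff; the elided multiplier/shift computation in relay_to_tir.cc may differ in detail.

def add_rescale_params(input_0_scale, input_1_scale, output_scale, bit_width):
    # Illustrative TFLite-style pre-scaling for a quantized elementwise add.
    left_shift = 15 if bit_width == 16 else 20  # matches the change above
    twice_max = 2.0 * max(input_0_scale, input_1_scale)
    real_input_0_multiplier = input_0_scale / twice_max
    real_input_1_multiplier = input_1_scale / twice_max
    real_output_multiplier = twice_max / ((1 << left_shift) * output_scale)
    # Each real-valued multiplier is then converted into a fixed-point
    # (multiplier, shift) pair before being passed to the CMSIS-NN call.
    return real_input_0_multiplier, real_input_1_multiplier, real_output_multiplier

# Example: int16 add with equal input and output scales.
print(add_rescale_params(0.0128, 0.0128, 0.0128, bit_width=16))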
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -108,7 +108,9 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
}
std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" ||
cmsis_func_name == "arm_elementwise_add_s8") {
cmsis_func_name == "arm_elementwise_add_s8" ||
cmsis_func_name == "arm_elementwise_mul_s16" ||
cmsis_func_name == "arm_elementwise_add_s16") {
CodeGenC::VisitExpr_(op, os);
} else if (cmsis_func_name == "arm_convolve_wrapper_s8" ||
cmsis_func_name == "arm_convolve_wrapper_s16" ||
127 changes: 126 additions & 1 deletion tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -153,6 +153,131 @@ def test_op_int8(
)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
@pytest.mark.parametrize(
[
"input_0_scale",
"input_1_scale",
"output_scale",
],
[
[0.256, 0.256, 0.256],
[0.0128, 0.0128, 0.0128],
[0.0128, 0.256, 0.256],
],
)
@pytest.mark.parametrize(
"compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")]
)
def test_op_int16(
op,
relu_type,
input_0_scale,
input_1_scale,
output_scale,
compiler_cpu,
cpu_flags,
):
"""Tests QNN 16bit binary operators for CMSIS-NN"""
interface_api = "c"
use_unpacked_api = True

dtype = "int16"
shape = [1, 16, 16, 3]
model = make_model(
op,
generate_variable("input_0", dtype),
generate_variable("input_1", dtype),
input_0_scale,
0,
input_1_scale,
0,
relu_type,
output_scale,
0,
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_dtype_range(dtype)
inputs = {
"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
}
output_list = generate_ref_data(orig_mod["main"], inputs)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
output_tolerance=1,
),
create_test_runner(compiler_cpu, cpu_flags),
interface_api,
use_unpacked_api,
)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
@pytest.mark.parametrize(
[
"input_0_scale",
"input_0_zero_point",
"input_1_scale",
"input_1_zero_point",
"output_scale",
"output_zero_point",
],
[
[0.256, 0, 0.256, 33, 0.256, 33],
[0.0128, -64, 0.0128, 0, 0.0128, -64],
[0.0128, -64, 0.256, 33, 0.256, 0],
],
)
def test_op_int16_cannot_partition(
op,
relu_type,
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
output_scale,
output_zero_point,
):
"""Tests QNN 16bit binary operators for CMSIS-NN in the edge case of
non-zero zero points"""

model = make_model(
op,
generate_variable("input_0", "int16"),
generate_variable("input_1", "int16"),
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
relu_type,
output_scale,
output_zero_point,
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# arm_elementwise_(mul|add)_s16 does not support non-zero zero points in any
# argument
assert_no_external_function(cmsisnn_mod)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@@ -320,7 +445,7 @@ def test_both_scalar_inputs_int8(
@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["uint16"]])
def test_invalid_parameters(
op,
input_dtype,
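To exercise just the new int16 coverage locally, something along these lines should work, assuming a TVM build with CMSIS-NN enabled and the AOT test infrastructure available; the -k expression is only a convenience filter on test ids.

import pytest

# Select the int16 tests added above by substring match on the test id.
pytest.main(["tests/python/contrib/test_cmsisnn/test_binary_ops.py", "-k", "int16"])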