Add CMSIS-NN int16 add and mul operator support
The CMSIS-NN backend will now partition quantized `int16` addition and
multiplication operators and emit calls to `arm_elementwise_add_s16`
and `arm_elementwise_mul_s16`, respectively, during codegen.

Because `arm_elementwise_mul_s16` and `arm_elementwise_add_s16` do not
currently handle non-zero shift parameters, partitioning fails for
non-zero zero-point values (which map directly to shifts in the
CMSIS-NN backend) and we fall back to the regular C codegen path.

This patch also adds `int16` tests, including coverage of the non-zero
zero-point edge case described above.
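As a usage illustration of the behaviour described above — a minimal sketch that is not part of this commit, where the shape, scales and the example zero point of 33 are arbitrary values chosen only for the demonstration — an int16 qnn.add with all-zero zero points is offloaded to CMSIS-NN, while the same graph with non-zero zero points stays on the default codegen path:

import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn


def int16_qnn_add(zero_point):
    # Build a single int16 qnn.add; all quantization parameters share the
    # same (arbitrary) scale and the given zero point.
    shape = (1, 16, 16, 3)
    input_0 = relay.var("input_0", shape=shape, dtype="int16")
    input_1 = relay.var("input_1", shape=shape, dtype="int16")
    out = relay.qnn.op.add(
        input_0,
        input_1,
        lhs_scale=relay.const(0.0128, "float32"),
        lhs_zero_point=relay.const(zero_point, "int32"),
        rhs_scale=relay.const(0.0128, "float32"),
        rhs_zero_point=relay.const(zero_point, "int32"),
        output_scale=relay.const(0.0128, "float32"),
        output_zero_point=relay.const(zero_point, "int32"),
    )
    return tvm.IRModule.from_expr(relay.Function([input_0, input_1], out))


# Zero zero points: the add is partitioned into an external CMSIS-NN function
# and later lowered to a call to arm_elementwise_add_s16.
offloaded = cmsisnn.partition_for_cmsisnn(int16_qnn_add(0))

# Non-zero zero points: partitioning is rejected and the operator remains on
# the regular C codegen path.
fallback = cmsisnn.partition_for_cmsisnn(int16_qnn_add(33))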
FranklandJack authored and jacfra01 (generated by with_the_same_user script) committed Feb 6, 2023
1 parent f7aeaf1 commit 55e4b4d
Showing 4 changed files with 172 additions and 23 deletions.
31 changes: 24 additions & 7 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -340,20 +340,37 @@ def check_qnn_binary_op(pattern):

arg0 = binary_op.args[0]
arg1 = binary_op.args[1]
both_args_scalar = False

# Check arguments are not scalar.
if (
isinstance(arg0, tvm.relay.expr.Constant)
and len(arg0.checked_type.shape) == 0
and isinstance(arg1, tvm.relay.expr.Constant)
and len(arg1.checked_type.shape) == 0
):
both_args_scalar = True
return False

return (
arg0.checked_type.dtype == "int8"
and arg1.checked_type.dtype == "int8"
and not both_args_scalar
)
arg0_type = arg0.checked_type.dtype
arg1_type = arg1.checked_type.dtype

# Check arguments are of valid type.
if arg0_type not in ["int8", "int16"]:
return False

# Check arguments are the same type.
if arg0_type != arg1_type:
return False

# Check zero points are all zero (arm_elementwise_(add|mul)_s16 does not
# handle non-zero zero points).
if arg0_type == "int16" and str(binary_op.op.name) in ["qnn.add", "qnn.mul"]:
arg_0_zero_point = binary_op.args[3].data.numpy()
arg_1_zero_point = binary_op.args[5].data.numpy()
output_zero_point = binary_op.args[7].data.numpy()
if any([arg_0_zero_point, arg_1_zero_point, output_zero_point]):
return False

return True

return [
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=True), check_qnn_conv2d_pad),
33 changes: 19 additions & 14 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -587,10 +587,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
}

void EmitMul(const GlobalVar& global_var, const Expr& expr) {
int32_t output_min = std::numeric_limits<int8_t>::min();
int32_t output_max = std::numeric_limits<int8_t>::max();
const auto& pattern = ParseBinaryElementwiseOpClipPattern(expr);
Call mul_call = pattern.binary_op;
const auto bit_width = mul_call->type_as<TensorTypeNode>()->dtype.bits();
int32_t output_min = -(1 << (bit_width - 1));
int32_t output_max = (1 << (bit_width - 1)) - 1;
if (pattern.clip_op) {
const ClipAttrs* clip_attrs = pattern.clip_op.value()->attrs.as<ClipAttrs>();
output_min = clip_attrs->a_min;
@@ -614,17 +615,17 @@
PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();

BufferCreator buffer_creator;
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(bit_width));
tir::Var input_1;
if (mul_call->args[0].same_as(mul_call->args[1])) {
input_1 = input_0;
} else {
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(bit_width));
}
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_mul_s8"),
tir::StringImm("arm_elementwise_mul_s" + std::to_string(bit_width)),
input_0,
input_1,
ToArg(-input_0_zero_point),
@@ -643,10 +644,13 @@
}

void EmitAdd(const GlobalVar& global_var, const Expr& expr) {
int32_t output_min = std::numeric_limits<int8_t>::min();
int32_t output_max = std::numeric_limits<int8_t>::max();
const auto& pattern = ParseBinaryElementwiseOpClipPattern(expr);
Call add_call = pattern.binary_op;
const auto bit_width = add_call->type_as<TensorTypeNode>()->dtype.bits();

int32_t output_min = -(1 << (bit_width - 1));
int32_t output_max = (1 << (bit_width - 1)) - 1;

if (pattern.clip_op) {
const ClipAttrs* clip_attrs = pattern.clip_op.value()->attrs.as<ClipAttrs>();
output_min = clip_attrs->a_min;
@@ -660,9 +664,10 @@
const float output_scale = GetScalarFromConstant<float>(add_call->args[6]);
const int32_t output_zero_point = GetScalarFromConstant<int32_t>(add_call->args[7]);

const int32_t left_shift = 20;
const int32_t left_shift = (bit_width == 16) ? 15 : 20;
const int32_t input_0_offset = -input_0_zero_point;
const int32_t input_1_offset = -input_1_zero_point;
const int32_t output_offset = output_zero_point;

const float max_input_scale = std::max(input_0_scale, input_1_scale);
const double twice_max_input_scale = 2 * static_cast<double>(max_input_scale);
@@ -689,17 +694,17 @@
PrimExpr tensor_size = add_call->type_as<TensorTypeNode>()->Size();

BufferCreator buffer_creator;
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(bit_width));
tir::Var input_1;
if (add_call->args[0].same_as(add_call->args[1])) {
input_1 = input_0;
} else {
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(bit_width));
}
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_add_s8"),
tir::StringImm("arm_elementwise_add_s" + std::to_string(bit_width)),
input_0,
input_1,
ToArg(input_0_offset),
Expand All @@ -710,7 +715,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
ToArg(input_1_shift),
ToArg(left_shift),
output,
ToArg(output_zero_point),
ToArg(output_offset),
ToArg(output_multiplier),
ToArg(output_shift),
ToArg(output_min),
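For reference, the int16 path in EmitAdd above keeps the familiar TFLite-style fixed-point requantization scheme already used by the int8 path, only with left_shift = 15 instead of 20. A rough Python sketch of that derivation — not part of the commit; quantize_multiplier below is a stand-in helper written for this sketch, approximating the backend's internal multiplier/shift quantization:

import math


def quantize_multiplier(real_multiplier):
    # Split a real-valued multiplier into a Q31 fixed-point multiplier and a
    # base-2 shift, TFLite-style.
    if real_multiplier == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(real_multiplier)
    multiplier = round(mantissa * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the mantissa up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift


def add_requant_params(input_0_scale, input_1_scale, output_scale, bit_width=16):
    # Mirrors the scale arithmetic visible in EmitAdd for the int16 case.
    left_shift = 15 if bit_width == 16 else 20
    twice_max_input_scale = 2.0 * max(input_0_scale, input_1_scale)
    input_0 = quantize_multiplier(input_0_scale / twice_max_input_scale)
    input_1 = quantize_multiplier(input_1_scale / twice_max_input_scale)
    output = quantize_multiplier(
        twice_max_input_scale / ((1 << left_shift) * output_scale)
    )
    return left_shift, input_0, input_1, output


# Example with one of the scale combinations used in the new int16 tests.
print(add_requant_params(0.0128, 0.256, 0.256))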
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -108,7 +108,9 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
}
std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" ||
cmsis_func_name == "arm_elementwise_add_s8") {
cmsis_func_name == "arm_elementwise_add_s8" ||
cmsis_func_name == "arm_elementwise_mul_s16" ||
cmsis_func_name == "arm_elementwise_add_s16") {
CodeGenC::VisitExpr_(op, os);
} else if (cmsis_func_name == "arm_convolve_wrapper_s8" ||
cmsis_func_name == "arm_convolve_wrapper_s16" ||
127 changes: 126 additions & 1 deletion tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -153,6 +153,131 @@ def test_op_int8(
)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
@pytest.mark.parametrize(
[
"input_0_scale",
"input_1_scale",
"output_scale",
],
[
[0.256, 0.256, 0.256],
[0.0128, 0.0128, 0.0128],
[0.0128, 0.256, 0.256],
],
)
@pytest.mark.parametrize(
"compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")]
)
def test_op_int16(
op,
relu_type,
input_0_scale,
input_1_scale,
output_scale,
compiler_cpu,
cpu_flags,
):
"""Tests QNN 16bit binary operators for CMSIS-NN"""
interface_api = "c"
use_unpacked_api = True

dtype = "int16"
shape = [1, 16, 16, 3]
model = make_model(
op,
generate_variable("input_0", dtype),
generate_variable("input_1", dtype),
input_0_scale,
0,
input_1_scale,
0,
relu_type,
output_scale,
0,
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_dtype_range(dtype)
inputs = {
"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
}
output_list = generate_ref_data(orig_mod["main"], inputs)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
output_tolerance=1,
),
create_test_runner(compiler_cpu, cpu_flags),
interface_api,
use_unpacked_api,
)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
@pytest.mark.parametrize(
[
"input_0_scale",
"input_0_zero_point",
"input_1_scale",
"input_1_zero_point",
"output_scale",
"output_zero_point",
],
[
[0.256, 0, 0.256, 33, 0.256, 33],
[0.0128, -64, 0.0128, 0, 0.0128, -64],
[0.0128, -64, 0.256, 33, 0.256, 0],
],
)
def test_op_int16_cannot_partition(
op,
relu_type,
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
output_scale,
output_zero_point,
):
"""Tests QNN 16bit binary operators for CMSIS-NN in the edge case of
non-zero zero points"""

model = make_model(
op,
generate_variable("input_0", "int16"),
generate_variable("input_1", "int16"),
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
relu_type,
output_scale,
output_zero_point,
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# arm_elementwise_(mul|add)_s16 does not support non-zero zero points
# (which map to shifts), so these operators must not be offloaded
assert_no_external_function(cmsisnn_mod)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@@ -320,7 +445,7 @@ def test_both_scalar_inputs_int8(
@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["uint16"]])
def test_invalid_parameters(
op,
input_dtype,
