From 27ec6a2f97a7425772720f2e3c4bcd6d1c94f217 Mon Sep 17 00:00:00 2001
From: Jack Frankland
Date: Mon, 6 Feb 2023 11:12:49 +0000
Subject: [PATCH] Add CMSIS-NN int16 add and mul operator support

The CMSIS-NN backend will now partition quantized `int16` addition and
multiplication operators and emit calls to `arm_elementwise_add_s16` and
`arm_elementwise_mul_s16` respectively during codegen.

Because `arm_elementwise_mul_s16` and `arm_elementwise_add_s16` do not
currently handle non-zero offset parameters, partitioning fails for
non-zero zero point values (zero points map directly to offsets in the
CMSIS-NN backend) and we fall back to the regular C codegen path.
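For illustration, a minimal sketch of the partitioning behaviour (not
part of the patch; the shape and quantization parameters are arbitrary,
illustrative values):

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

# Build a quantized int16 addition with all zero points set to zero.
shape = (1, 16, 16, 3)
x = relay.var("x", shape=shape, dtype="int16")
y = relay.var("y", shape=shape, dtype="int16")
out = relay.qnn.op.add(
    x,
    y,
    lhs_scale=relay.const(0.0128, "float32"),
    lhs_zero_point=relay.const(0, "int32"),
    rhs_scale=relay.const(0.0128, "float32"),
    rhs_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.0128, "float32"),
    output_zero_point=relay.const(0, "int32"),
)
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

# With every zero point equal to zero the addition is partitioned into an
# external CMSIS-NN function; changing any zero point to a non-zero value
# leaves the operator in the main module for the regular C codegen path.
partitioned = cmsisnn.partition_for_cmsisnn(mod)
```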
This patch also adds `int16` tests, including coverage of the non-zero
zero point edge case described above.
---
 python/tvm/relay/op/contrib/cmsisnn.py        |  31 ++++-
 .../backend/contrib/cmsisnn/relay_to_tir.cc   | 104 ++++++--------
 .../backend/contrib/cmsisnn/tir_to_runtime.cc |   4 +-
 .../contrib/test_cmsisnn/test_binary_ops.py   | 127 +++++++++++++++++-
 4 files changed, 193 insertions(+), 73 deletions(-)

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 4581378dcd24..cf3294744693 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -340,20 +340,37 @@ def check_qnn_binary_op(pattern):
         arg0 = binary_op.args[0]
         arg1 = binary_op.args[1]
-        both_args_scalar = False
+
+        # Check that the arguments are not both scalar constants.
         if (
             isinstance(arg0, tvm.relay.expr.Constant)
             and len(arg0.checked_type.shape) == 0
             and isinstance(arg1, tvm.relay.expr.Constant)
             and len(arg1.checked_type.shape) == 0
         ):
-            both_args_scalar = True
+            return False
 
-        return (
-            arg0.checked_type.dtype == "int8"
-            and arg1.checked_type.dtype == "int8"
-            and not both_args_scalar
-        )
+        arg0_type = arg0.checked_type.dtype
+        arg1_type = arg1.checked_type.dtype
+
+        # Check that the arguments are of a valid type.
+        if arg0_type not in ["int8", "int16"]:
+            return False
+
+        # Check that the arguments are of the same type.
+        if arg0_type != arg1_type:
+            return False
+
+        # Check that all zero points are zero; arm_elementwise_(add|mul)_s16
+        # does not handle non-zero zero points.
+        if arg0_type == "int16" and str(binary_op.op.name) in ["qnn.add", "qnn.mul"]:
+            arg_0_zero_point = binary_op.args[3].data.numpy()
+            arg_1_zero_point = binary_op.args[5].data.numpy()
+            output_zero_point = binary_op.args[7].data.numpy()
+            if any([arg_0_zero_point, arg_1_zero_point, output_zero_point]):
+                return False
+
+        return True
 
     return [
         ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=True), check_qnn_conv2d_pad),
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index f8685dc4df47..73d479e6944e 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -112,6 +112,21 @@ class RelayToTIRVisitor : public MixedModeMutator {
     ir_module_->Add(global_var, replacement_func);
   }
 
+  auto GetIntMinMax(int bit_width) {
+    const int32_t min = (bit_width == 8) ? std::numeric_limits<int8_t>::min()
+                                         : std::numeric_limits<int16_t>::min();
+    const int32_t max = (bit_width == 8) ? std::numeric_limits<int8_t>::max()
+                                         : std::numeric_limits<int16_t>::max();
+    return std::pair<int32_t, int32_t>(min, max);
+  }
+
+  auto GetClipMinMax(const ClipAttrs* clip_attrs) {
+    return std::pair<int32_t, int32_t>(static_cast<int32_t>(clip_attrs->a_min),
+                                       static_cast<int32_t>(clip_attrs->a_max));
+  }
+
+  auto GetClipMinMax(const Call& clip_op) { return GetClipMinMax(clip_op->attrs.as<ClipAttrs>()); }
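+
+  // If a Clip op follows the emitted operator, its attributes provide the
+  // output bounds; otherwise the output is clamped to the full range of the
+  // int8 or int16 element type via GetIntMinMax.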
+
   void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
     const CallNode* clip_call = nullptr;
     const CallNode* requantize_call = nullptr;
@@ -190,18 +205,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
     int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]);
     int32_t out_channels = qnn::get_const_int(conv2d_attrs->channels);
     std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
-    int32_t clip_min = std::numeric_limits<int8_t>::min();
-    int32_t clip_max = std::numeric_limits<int8_t>::max();
-
-    if (dtype_bits == 16) {
-      clip_min = std::numeric_limits<int16_t>::min();
-      clip_max = std::numeric_limits<int16_t>::max();
-    }
-    if (clip_call) {
-      const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
-      clip_min = clip_attrs->a_min;
-      clip_max = clip_attrs->a_max;
-    }
+    const auto [clip_min, clip_max] =
+        clip_call ? GetClipMinMax(GetRef<Call>(clip_call)) : GetIntMinMax(dtype_bits);
 
     tvm::Array<PrimExpr> scalar_args = {ToArg(input_offset), ToArg(output_offset), ToArg(stride_w),
                                         ToArg(stride_h),     ToArg(padding_w),     ToArg(padding_h),
@@ -348,20 +353,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
     float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]);
     float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
     int32_t out_channels = qnn::get_const_int(dense_attrs->units);
-    int32_t clip_min, clip_max;
-    if (clip_call) {
-      const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
-      clip_min = clip_attrs->a_min;
-      clip_max = clip_attrs->a_max;
-    } else {
-      if (dtype_bits == 8) {
-        clip_min = std::numeric_limits<int8_t>::min();
-        clip_max = std::numeric_limits<int8_t>::max();
-      } else {
-        clip_min = std::numeric_limits<int16_t>::min();
-        clip_max = std::numeric_limits<int16_t>::max();
-      }
-    }
+    const auto [clip_min, clip_max] =
+        clip_call ? GetClipMinMax(GetRef<Call>(clip_call)) : GetIntMinMax(dtype_bits);
 
     double quantized_multiplier =
         static_cast<double>(input_scale) / static_cast<double>(output_scale);
@@ -432,7 +425,6 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     // prepare cmsis_nn_pool_params
     int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
-    int32_t clip_min, clip_max;
     std::string cmsisnn_api;
     if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
       if (dtype_bits == 8) {
@@ -463,19 +455,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
       pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
       pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
     }
-    if (clip.defined()) {
-      const ClipAttrs* clip_attrs = clip->attrs.as<ClipAttrs>();
-      clip_min = clip_attrs->a_min;
-      clip_max = clip_attrs->a_max;
-    } else {
-      if (dtype_bits == 8) {
-        clip_min = std::numeric_limits<int8_t>::min();
-        clip_max = std::numeric_limits<int8_t>::max();
-      } else {
-        clip_min = std::numeric_limits<int16_t>::min();
-        clip_max = std::numeric_limits<int16_t>::max();
-      }
-    }
+
+    const auto [clip_min, clip_max] =
+        clip.defined() ? GetClipMinMax(clip) : GetIntMinMax(dtype_bits);
 
     tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h),  ToArg(stride_w), ToArg(padding_h),
                                         ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};
@@ -587,15 +569,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
   }
 
   void EmitMul(const GlobalVar& global_var, const Expr& expr) {
-    int32_t output_min = std::numeric_limits<int8_t>::min();
-    int32_t output_max = std::numeric_limits<int8_t>::max();
     const auto& pattern = ParseBinaryElementwiseOpClipPattern(expr);
     Call mul_call = pattern.binary_op;
-    if (pattern.clip_op) {
-      const ClipAttrs* clip_attrs = pattern.clip_op.value()->attrs.as<ClipAttrs>();
-      output_min = clip_attrs->a_min;
-      output_max = clip_attrs->a_max;
-    }
+    const auto bit_width = mul_call->type_as<TensorTypeNode>()->dtype.bits();
+    const auto [output_min, output_max] =
+        pattern.clip_op ? GetClipMinMax(pattern.clip_op.value()) : GetIntMinMax(bit_width);
 
     const float input_0_scale = GetScalarFromConstant<float>(mul_call->args[2]);
     const int32_t input_0_zero_point = GetScalarFromConstant<int32_t>(mul_call->args[3]);
@@ -614,17 +592,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
     PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();
 
     BufferCreator buffer_creator;
-    tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
+    tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(bit_width));
     tir::Var input_1;
     if (mul_call->args[0].same_as(mul_call->args[1])) {
       input_1 = input_0;
     } else {
-      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(bit_width));
     }
-    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));
 
     tvm::Array<PrimExpr> args = {
-        tir::StringImm("arm_elementwise_mul_s8"),
+        tir::StringImm("arm_elementwise_mul_s" + std::to_string(bit_width)),
         input_0,
         input_1,
         ToArg(-input_0_zero_point),
@@ -643,15 +621,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
   }
 
   void EmitAdd(const GlobalVar& global_var, const Expr& expr) {
-    int32_t output_min = std::numeric_limits<int8_t>::min();
-    int32_t output_max = std::numeric_limits<int8_t>::max();
     const auto& pattern = ParseBinaryElementwiseOpClipPattern(expr);
     Call add_call = pattern.binary_op;
-    if (pattern.clip_op) {
-      const ClipAttrs* clip_attrs = pattern.clip_op.value()->attrs.as<ClipAttrs>();
-      output_min = clip_attrs->a_min;
-      output_max = clip_attrs->a_max;
-    }
+    const auto bit_width = add_call->type_as<TensorTypeNode>()->dtype.bits();
+
+    const auto [output_min, output_max] =
+        pattern.clip_op ? GetClipMinMax(pattern.clip_op.value()) : GetIntMinMax(bit_width);
 
     const float input_0_scale = GetScalarFromConstant<float>(add_call->args[2]);
     const int32_t input_0_zero_point = GetScalarFromConstant<int32_t>(add_call->args[3]);
@@ -660,9 +635,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
     const float output_scale = GetScalarFromConstant<float>(add_call->args[6]);
     const int32_t output_zero_point = GetScalarFromConstant<int32_t>(add_call->args[7]);
 
-    const int32_t left_shift = 20;
+    const int32_t left_shift = (bit_width == 16) ? 15 : 20;
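+    // The inputs are scaled up by 2^left_shift before requantization; a
+    // smaller shift is used for 16-bit operands, which leave less headroom
+    // in the int32 intermediate than 8-bit ones.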
     const int32_t input_0_offset = -input_0_zero_point;
     const int32_t input_1_offset = -input_1_zero_point;
+    const int32_t output_offset = output_zero_point;
 
     const float max_input_scale = std::max(input_0_scale, input_1_scale);
     const double twice_max_input_scale = 2 * static_cast<double>(max_input_scale);
@@ -689,17 +665,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
     PrimExpr tensor_size = add_call->type_as<TensorTypeNode>()->Size();
 
     BufferCreator buffer_creator;
-    tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
+    tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(bit_width));
    tir::Var input_1;
     if (add_call->args[0].same_as(add_call->args[1])) {
       input_1 = input_0;
     } else {
-      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(bit_width));
     }
-    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));
 
     tvm::Array<PrimExpr> args = {
-        tir::StringImm("arm_elementwise_add_s8"),
+        tir::StringImm("arm_elementwise_add_s" + std::to_string(bit_width)),
         input_0,
         input_1,
         ToArg(input_0_offset),
@@ -710,7 +686,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
         ToArg(input_1_shift),
         ToArg(left_shift),
         output,
-        ToArg(output_zero_point),
+        ToArg(output_offset),
         ToArg(output_multiplier),
         ToArg(output_shift),
         ToArg(output_min),
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index 1d53373ba833..a592eb74b4fc 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -108,7 +108,9 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     }
     std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
     if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" ||
-        cmsis_func_name == "arm_elementwise_add_s8") {
+        cmsis_func_name == "arm_elementwise_add_s8" ||
+        cmsis_func_name == "arm_elementwise_mul_s16" ||
+        cmsis_func_name == "arm_elementwise_add_s16") {
       CodeGenC::VisitExpr_(op, os);
     } else if (cmsis_func_name == "arm_convolve_wrapper_s8" ||
                cmsis_func_name == "arm_convolve_wrapper_s16" ||
diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
index 663a1bd45d5c..8c0da922f093 100644
--- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py
+++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -153,6 +153,131 @@ def test_op_int8(
     )
 
 
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
+@pytest.mark.parametrize(
+    [
+        "input_0_scale",
+        "input_1_scale",
+        "output_scale",
+    ],
+    [
+        [0.256, 0.256, 0.256],
+        [0.0128, 0.0128, 0.0128],
+        [0.0128, 0.256, 0.256],
+    ],
+)
+@pytest.mark.parametrize(
+    "compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")]
+)
+def test_op_int16(
+    op,
+    relu_type,
+    input_0_scale,
+    input_1_scale,
+    output_scale,
+    compiler_cpu,
+    cpu_flags,
+):
+    """Tests 16-bit QNN binary operators for CMSIS-NN."""
+    interface_api = "c"
+    use_unpacked_api = True
+
+    dtype = "int16"
+    shape = [1, 16, 16, 3]
+    model = make_model(
+        op,
+        generate_variable("input_0", dtype),
+        generate_variable("input_1", dtype),
+        input_0_scale,
+        0,
+        input_1_scale,
+        0,
+        relu_type,
+        output_scale,
+        0,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # Validate the output.
+    in_min, in_max = get_dtype_range(dtype)
+    inputs = {
+        "input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+        "input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+    }
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=1,
+        ),
+        create_test_runner(compiler_cpu, cpu_flags),
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
+@pytest.mark.parametrize(
+    [
+        "input_0_scale",
+        "input_0_zero_point",
+        "input_1_scale",
+        "input_1_zero_point",
+        "output_scale",
+        "output_zero_point",
+    ],
+    [
+        [0.256, 0, 0.256, 33, 0.256, 33],
+        [0.0128, -64, 0.0128, 0, 0.0128, -64],
+        [0.0128, -64, 0.256, 33, 0.256, 0],
+    ],
+)
+def test_op_int16_cannot_partition(
+    op,
+    relu_type,
+    input_0_scale,
+    input_0_zero_point,
+    input_1_scale,
+    input_1_zero_point,
+    output_scale,
+    output_zero_point,
+):
+    """Tests that 16-bit QNN binary operators with non-zero zero points
+    are not partitioned for CMSIS-NN."""
+
+    model = make_model(
+        op,
+        generate_variable("input_0", "int16"),
+        generate_variable("input_1", "int16"),
+        input_0_scale,
+        input_0_zero_point,
+        input_1_scale,
+        input_1_zero_point,
+        relu_type,
+        output_scale,
+        output_zero_point,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # arm_elementwise_(mul|add)_s16 does not support a non-zero zero point
+    # (offset) on any argument, so nothing should be offloaded.
+    assert_no_external_function(cmsisnn_mod)
+
+
 @skip_if_no_reference_system
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@@ -320,7 +445,7 @@ def test_both_scalar_inputs_int8(
 @skip_if_no_reference_system
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
-@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
+@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["uint16"]])
 def test_invalid_parameters(
     op,
     input_dtype,