diff --git a/tensorflow/lite/micro/kernels/lstm_eval_test.cc b/tensorflow/lite/micro/kernels/lstm_eval_test.cc index 82c09ede033..eaba2c4ac2f 100644 --- a/tensorflow/lite/micro/kernels/lstm_eval_test.cc +++ b/tensorflow/lite/micro/kernels/lstm_eval_test.cc @@ -34,7 +34,7 @@ namespace { // Test Settings constexpr float kTestFloatTolerance = 1e-6f; } // namespace -#endif // !defined(XTENSA) +#endif // !defined(XTENSA) TF_LITE_MICRO_TESTS_BEGIN // TODO(b/230666079) enable below tests for xtensa when the xtensa @@ -454,6 +454,6 @@ TF_LITE_MICRO_TEST(TestLSTMEvalInt16) { cell_state_tolerance, int16_node_contents); } -#endif // !defined(XTENSA) +#endif // !defined(XTENSA) TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc index 5a671a31aa1..659f72d736b 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc @@ -225,11 +225,10 @@ void FullyConnected(const FullyConnectedParams& params, params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data); } -#else // #if !(defined(HIFI5) || defined(HIFI4)) +#else // #if !(defined(HIFI5) || defined(HIFI4)) void Sigmoid(int16_t* data, int32_t data_size) { WORD32 err; - err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0, - data_size); + err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0, data_size); } void Sigmoid(float* data, int32_t data_size) { @@ -245,7 +244,8 @@ void Tanh(int32_t cell_state_scale_power, int16_t* input_data, if (tanh_input_left_shift < 0) /* handling negative shift value */ { tanh_input_left_shift = -tanh_input_left_shift; -#if (defined(USE_HIFI_ACT_TIE) && (defined(AE_TANH16X4X2) || defined(AE_TANH16X4))) +#if (defined(USE_HIFI_ACT_TIE) && \ + (defined(AE_TANH16X4X2) || defined(AE_TANH16X4))) input_multiplier = 1; #else input_multiplier = 3; @@ -253,25 +253,24 @@ void Tanh(int32_t cell_state_scale_power, int16_t* input_data, } WORD32 err; err = xa_nn_vec_tanh_sym16s_sym16s(output_data, input_data, input_multiplier, - tanh_input_left_shift, data_size); + tanh_input_left_shift, data_size); } void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data, int32_t data_size) { int data_dims[2] = {1, data_size}; RuntimeShape data_shape(2, reinterpret_cast(data_dims)); - reference_ops::Tanh(data_shape, input_data, data_shape, - output_data); + reference_ops::Tanh(data_shape, input_data, data_shape, output_data); } // Input and output have the same shape in LSTM void Mul(const ArithmeticParams& params, const int16_t* input1_data, const int16_t* input2_data, int8_t* output_data, int32_t data_size) { WORD32 err; - err = xa_nn_elm_mul_sym16sxsym16s_asym8s(output_data, params.output_offset, - params.output_shift, params.output_multiplier, - params.quantized_activation_min, params.quantized_activation_max, - input1_data, input2_data, data_size); + err = xa_nn_elm_mul_sym16sxsym16s_asym8s( + output_data, params.output_offset, params.output_shift, + params.output_multiplier, params.quantized_activation_min, + params.quantized_activation_max, input1_data, input2_data, data_size); } // Input and output have the same shape in LSTM @@ -279,8 +278,8 @@ void Mul(const ArithmeticParams& params, const int16_t* input1_data, const int16_t* input2_data, int16_t* output_data, int32_t data_size) { int dims_4D[4] = {1, 1, 1, data_size}; WORD32 err; - err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s(output_data, - dims_4D, params.output_shift, params.output_multiplier, + err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s( + output_data, dims_4D, params.output_shift, params.output_multiplier, params.quantized_activation_min, params.quantized_activation_max, input1_data, dims_4D, input2_data, dims_4D); return; @@ -301,8 +300,8 @@ void FullyConnected(const FullyConnectedParams& params, const int num_batches, const int output_depth, const int accum_depth) { WORD32 err; -#pragma loop_count min=1 - for(int b = 0; b < num_batches; b++) { +#pragma loop_count min = 1 + for (int b = 0; b < num_batches; b++) { err = xa_nn_matXvec_out_stride_sym8sxasym8s_16( output_data + b * output_depth, filter_data, input_data + b * accum_depth, bias_data, output_depth, accum_depth, @@ -319,18 +318,18 @@ void FullyConnected(const FullyConnectedParams& params, const int accum_depth) { WORD32 err; - err = xa_nn_matmul_sym8sxsym16s_sym16s(output_data, filter_data, input_data, - bias_data, output_depth, accum_depth, accum_depth, num_batches, - accum_depth, output_depth, 1, params.input_offset, - params.output_multiplier, params.output_shift, params.output_offset); + err = xa_nn_matmul_sym8sxsym16s_sym16s( + output_data, filter_data, input_data, bias_data, output_depth, + accum_depth, accum_depth, num_batches, accum_depth, output_depth, 1, + params.input_offset, params.output_multiplier, params.output_shift, + params.output_offset); return; } -void FullyConnected(const FullyConnectedParams& params, - const float* input_data, const float* filter_data, - const float* bias_data, float* output_data, - const int num_batches, const int output_depth, - const int accum_depth) { +void FullyConnected(const FullyConnectedParams& params, const float* input_data, + const float* filter_data, const float* bias_data, + float* output_data, const int num_batches, + const int output_depth, const int accum_depth) { int input_dims[2] = {num_batches, output_depth}; RuntimeShape input_shape(2, reinterpret_cast(input_dims)); RuntimeShape bias_shape(1, bias_data == NULL ? 0 : output_depth); @@ -342,7 +341,7 @@ void FullyConnected(const FullyConnectedParams& params, params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data); } -#endif // #if !(defined(HIFI5) || defined(HIFI4)) +#endif // #if !(defined(HIFI5) || defined(HIFI4)) void Clipping(const int v_size, const CellStateInfo& cell_state_info, int16_t* vector) { @@ -372,18 +371,16 @@ void UpdateLstmCell(const LstmStepManager& step_info, const ArithmeticParams& forget_cell_mul_params, const ArithmeticParams& input_mul_params, const CellStateInfo& cell_state_info, int16_t* buffer) { - auto cell_state_shape = step_info.StateShape(); // Check offset validity to avoid memory overflow - TFLITE_DCHECK_LE( - step_info.CellStateOffset() + cell_state_shape.FlatSize(), - tflite::micro::GetTensorShape(cell_state).FlatSize()); + TFLITE_DCHECK_LE(step_info.CellStateOffset() + cell_state_shape.FlatSize(), + tflite::micro::GetTensorShape(cell_state).FlatSize()); WORD32 err; // Multiplier is equivalent to 0.5 here so adding 1 to shifts err = xa_nn_lstm_cell_state_update_16( tflite::micro::GetTensorData(cell_state) + - step_info.CellStateOffset(), + step_info.CellStateOffset(), forget_gate_output, cell_gate_output, input_gate_output, forget_cell_mul_params.output_shift - 1, input_mul_params.output_shift - 1, cell_state_info.quantized_cell_clip, @@ -393,8 +390,7 @@ void UpdateLstmCell(const LstmStepManager& step_info, void UpdateLstmCell(const LstmStepManager& step_info, TfLiteEvalTensor* cell_state, // Gate outputs - float* forget_gate_output, - const float* input_gate_output, + float* forget_gate_output, const float* input_gate_output, const float* cell_gate_output, // Mul parameters const ArithmeticParams& forget_cell_mul_params, @@ -432,7 +428,7 @@ void UpdateLstmCell(const LstmStepManager& step_info, step_info.CellStateOffset()); } } -#endif // #if defined(HIFI5) || defined(HIFI4) +#endif // #if defined(HIFI5) || defined(HIFI4) // Increment the data offset so the sigle time step invocation call can access // the corresponding input/output tensor data at the time step @@ -491,6 +487,5 @@ RuntimeShape LstmStepManager::StateShape() const { return RuntimeShape(2, dims_data); } - } // namespace lstm_internal } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h index 1162522e86a..5522d1253c3 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h @@ -26,8 +26,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/lstm_shared.h" -#include "tensorflow/lite/micro/micro_log.h" #include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" namespace tflite { @@ -200,7 +200,7 @@ void FullyConnected(const FullyConnectedParams& params, const RuntimeShape& filter_shape, const float* filter_data, const RuntimeShape& bias_shape, const float* bias_data, const RuntimeShape& output_shape, float* output_data); -#else // #if !(defined(HIFI5) || defined(HIFI4)) +#else // #if !(defined(HIFI5) || defined(HIFI4)) void Sigmoid(int16_t* data, int32_t data_size); void Sigmoid(float* data, int32_t data_size); @@ -208,8 +208,8 @@ void Sigmoid(float* data, int32_t data_size); void Tanh(int32_t cell_state_scale_power, int16_t* input_data, int16_t* output_data, int32_t data_size); -void Tanh(int32_t cell_state_scale_power, float* input_data, - float* output_data, int32_t data_size); +void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data, + int32_t data_size); void Mul(const ArithmeticParams& params, const int16_t* input1_data, const int16_t* input2_data, int8_t* output_data, int32_t data_size); @@ -232,12 +232,11 @@ void FullyConnected(const FullyConnectedParams& params, const int num_batches, const int output_depth, const int accum_depth); -void FullyConnected(const FullyConnectedParams& params, - const float* input_data, const float* filter_data, - const float* bias_data, float* output_data, - const int num_batches, const int output_depth, - const int accum_depth); -#endif // #if !(defined(HIFI5) || defined(HIFI4)) +void FullyConnected(const FullyConnectedParams& params, const float* input_data, + const float* filter_data, const float* bias_data, + float* output_data, const int num_batches, + const int output_depth, const int accum_depth); +#endif // #if !(defined(HIFI5) || defined(HIFI4)) void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input, int16_t* output); @@ -407,7 +406,7 @@ void UpdateLstmCell(const LstmStepManager& step_info, step_info.CellStateOffset()); } } -#else // #if !defined(HIFI5) || defined(HIFI4) +#else // #if !defined(HIFI5) || defined(HIFI4) template void CalculateLstmGate( @@ -424,7 +423,6 @@ void CalculateLstmGate( CellType* fc_output_buffer, const TfLiteFusedActivation activation, const int num_batches, const int input_dimension, const int state_dimension) { - // RuntimeShape step_input_shape = step_info.InputShape(); // RuntimeShape input_shape = tflite::micro::GetTensorShape(input); // RuntimeShape step_state_shape = step_info.StateShape(); @@ -433,10 +431,10 @@ void CalculateLstmGate( // Moved these to LstmStep function // Check offset validity to avoid memory overflow // TFLITE_DCHECK_LE(step_info.InputOffset() + step_input_shape.FlatSize(), - // input_shape.FlatSize()); + // input_shape.FlatSize()); // TFLITE_DCHECK_LE( - // step_info.HiddenStateOffset() + step_state_shape.FlatSize(), - // recurrent_shape.FlatSize()); + // step_info.HiddenStateOffset() + step_state_shape.FlatSize(), + // recurrent_shape.FlatSize()); // Input FC FullyConnected(gate_params.input_fc_params, @@ -492,14 +490,13 @@ void UpdateLstmCell(const LstmStepManager& step_info, void UpdateLstmCell(const LstmStepManager& step_info, TfLiteEvalTensor* cell_state, // Gate outputs - float* forget_gate_output, - const float* input_gate_output, + float* forget_gate_output, const float* input_gate_output, const float* cell_gate_output, // Mul parameters const ArithmeticParams& forget_cell_mul_params, const ArithmeticParams& input_mul_params, const CellStateInfo& cell_state_info, float* buffer); -#endif // #if defined(HIFI5) || defined(HIFI4) +#endif // #if defined(HIFI5) || defined(HIFI4) // Update the hidden state of the LSTM kernel using the following formula: // updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means @@ -533,8 +530,7 @@ void UpdateLstmHidden(const LstmStepManager& step_info, step_info.HiddenStateOffset()); #else int32_t cell_state_size = cell_state_shape.FlatSize(); - Tanh(cell_state_scale_power, cell_state_data, buffer, - cell_state_size); + Tanh(cell_state_scale_power, cell_state_data, buffer, cell_state_size); // Update the hidden state Mul(mul_params, buffer, output_gate_output, tflite::micro::GetTensorData(hidden_state) + @@ -655,13 +651,12 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, step_info.HiddenStateOffset(), step_info.StateShape().FlatSize() * sizeof(ActivationType)); } -#else // #if !(defined(HIFI5) || defined(HIFI4)) +#else // #if !(defined(HIFI5) || defined(HIFI4)) template void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, LSTMKernelContents& kernel_content, const LSTMBuffers& buffers) { - const TfLiteEvalTensor* input = kernel_content.GetInternalTensor(tflite::kLstmInputTensor); TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor(); @@ -684,11 +679,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, CalculateLstmGate( step_info, op_data.forget_gate_parameters, // Input FC - input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), + input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor), kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor), // Recurrent FC - recurrent, // kernel_content.HiddenStateTensor(), + recurrent, // kernel_content.HiddenStateTensor(), kernel_content.GetInternalTensor( tflite::kLstmRecurrentToForgetWeightsTensor), /*recurrent_bias*/ nullptr, @@ -703,11 +698,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, CalculateLstmGate( step_info, op_data.input_gate_parameters, // Input FC - input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), + input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor), kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor), // Recurrent FC - recurrent, // kernel_content.HiddenStateTensor(), + recurrent, // kernel_content.HiddenStateTensor(), kernel_content.GetInternalTensor( tflite::kLstmRecurrentToInputWeightsTensor), /*recurrent_bias*/ nullptr, @@ -722,11 +717,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, CalculateLstmGate( step_info, op_data.cell_gate_parameters, // Input FC - input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), + input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor), kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor), // Recurrent FC - recurrent, // kernel_content.HiddenStateTensor(), + recurrent, // kernel_content.HiddenStateTensor(), kernel_content.GetInternalTensor( tflite::kLstmRecurrentToCellWeightsTensor), /*recurrent_bias*/ nullptr, @@ -741,22 +736,21 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, CellType* updated_input_buffer = buffers.buffer1; // reuse buffer UpdateLstmCell(step_info, kernel_content.CellStateTensor(), - forget_gate_output, input_gate_output, - cell_gate_output, - inter_gate_params.forget_cell_mul_params, - inter_gate_params.input_mul_params, - op_data.cell_state_info, updated_input_buffer); + forget_gate_output, input_gate_output, cell_gate_output, + inter_gate_params.forget_cell_mul_params, + inter_gate_params.input_mul_params, op_data.cell_state_info, + updated_input_buffer); /*Step3: update the hidden state */ CellType* output_gate_output = buffers.buffer1; // reuse buffer CalculateLstmGate( step_info, op_data.output_gate_parameters, // Input FC - input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), + input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor), kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor), kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor), // Recurrent FC - recurrent, // kernel_content.HiddenStateTensor(), + recurrent, // kernel_content.HiddenStateTensor(), kernel_content.GetInternalTensor( tflite::kLstmRecurrentToOutputWeightsTensor), /*recurrent_bias*/ nullptr, @@ -768,8 +762,8 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, CellType* tanh_activated_cell_buffer = buffers.buffer0; // reuse buffer tflite::lstm_internal::UpdateLstmHidden( - step_info, kernel_content.CellStateTensor(), - recurrent, /* kernel_content.HiddenStateTensor(), */ output_gate_output, + step_info, kernel_content.CellStateTensor(), recurrent, + /* kernel_content.HiddenStateTensor(), */ output_gate_output, inter_gate_params.output_mul_params, op_data.cell_state_info.cell_state_scale_power, tanh_activated_cell_buffer); @@ -788,7 +782,7 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, step_info.HiddenStateOffset(), step_info.StateShape().FlatSize() * sizeof(ActivationType)); } -#endif // #if !(defined(HIFI5) || defined(HIFI4)) +#endif // #if !(defined(HIFI5) || defined(HIFI4)) } // namespace lstm_internal diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc index a5b5fc782cb..d44162ae247 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc @@ -12,55 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if defined(HIFI4) || defined(HIFI5) + +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h" #include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" -#include namespace tflite { #if defined(HIFI5) #if TFLITE_SINGLE_ROUNDING -#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, right_shift) \ -{ \ - ae_int64 out64_0, out64_1; \ - ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0,1)); \ - ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift); \ - AE_MUL32X2S_HH_LL(out64_0, out64_1, inp, AE_MOVDA32(multiplier)); \ - out64_0 = AE_ADD64S(out64_0, round_val); \ - out64_1 = AE_ADD64S(out64_1, round_val); \ - out = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift); \ -} +#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, \ + right_shift) \ + { \ + ae_int64 out64_0, out64_1; \ + ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0, 1)); \ + ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift); \ + AE_MUL32X2S_HH_LL(out64_0, out64_1, inp, AE_MOVDA32(multiplier)); \ + out64_0 = AE_ADD64S(out64_0, round_val); \ + out64_1 = AE_ADD64S(out64_1, round_val); \ + out = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift); \ + } -#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, left_shift, right_shift) \ -{ \ - ae_int64 out64_0, out64_1, out64_2, out64_3; \ - ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0,1)); \ - ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift); \ - AE_MUL32X2S_HH_LL(out64_0, out64_1, inp1, AE_MOVDA32(multiplier)); \ - AE_MUL32X2S_HH_LL(out64_2, out64_3, inp2, AE_MOVDA32(multiplier)); \ - out64_0 = AE_ADD64S(out64_0, round_val); \ - out64_1 = AE_ADD64S(out64_1, round_val); \ - out64_2 = AE_ADD64S(out64_2, round_val); \ - out64_3 = AE_ADD64S(out64_3, round_val); \ - out1 = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift); \ - out2 = AE_TRUNCA32X2F64S(out64_2, out64_3, 1 + left_shift); \ -} +#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \ + left_shift, right_shift) \ + { \ + ae_int64 out64_0, out64_1, out64_2, out64_3; \ + ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0, 1)); \ + ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift); \ + AE_MUL32X2S_HH_LL(out64_0, out64_1, inp1, AE_MOVDA32(multiplier)); \ + AE_MUL32X2S_HH_LL(out64_2, out64_3, inp2, AE_MOVDA32(multiplier)); \ + out64_0 = AE_ADD64S(out64_0, round_val); \ + out64_1 = AE_ADD64S(out64_1, round_val); \ + out64_2 = AE_ADD64S(out64_2, round_val); \ + out64_3 = AE_ADD64S(out64_3, round_val); \ + out1 = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift); \ + out2 = AE_TRUNCA32X2F64S(out64_2, out64_3, 1 + left_shift); \ + } #else /* #if TFLITE_SINGLE_ROUNDING */ -#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, right_shift) \ - out = AE_SLAA32(inp, left_shift); \ - out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier)); \ +#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, \ + right_shift) \ + out = AE_SLAA32(inp, left_shift); \ + out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier)); \ out = AE_SRAA32SYMS(out, right_shift); -#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, left_shift, right_shift) \ -{ \ - ae_int32x2 d_ls = AE_MOVDA32(1< 0 ? 0 : -shift; #endif /* #if TFLITE_SINGLE_ROUNDING */ @@ -460,10 +468,10 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0); AE_MUL16X4(data_ab_2, data_ab_3, data_a_1, data_b_1); - MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, - data_ab_0, data_ab_1, multiplier, left_shift, right_shift); - MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_2, data_ab_3, - data_ab_2, data_ab_3, multiplier, left_shift, right_shift); + MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, data_ab_0, data_ab_1, + multiplier, left_shift, right_shift); + MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_2, data_ab_3, data_ab_2, data_ab_3, + multiplier, left_shift, right_shift); data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1); data_c_1 = AE_SAT16X4(data_ab_2, data_ab_3); data_c_0 = AE_SUB16S(data_c_0, d_zp); @@ -482,8 +490,8 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, AE_L16_IP(data_b_0, (ae_int16*)tmp_input_2, 2); AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0); - MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, - multiplier, left_shift, right_shift); + MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, multiplier, left_shift, + right_shift); data_c_0 = AE_SAT16X4(data_ab_0, data_ab_0); data_c_0 = AE_SUB16S(data_c_0, d_zp); data_c = AE_SAT8X8X16(data_c_0, data_c_0); @@ -493,71 +501,80 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, #elif defined(HIFI4) #if TFLITE_SINGLE_ROUNDING #define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, l_shift, r_shift) \ -{ \ - ae_int64 out64_0, out64_1; \ - out64_0 = AE_MUL32_HH(inp, AE_MOVDA32(multiplier)); \ - out64_1 = AE_MUL32_LL(inp, AE_MOVDA32(multiplier)); \ - out64_0 = AE_SLAA64S(out64_0, 1 + l_shift); \ - out64_1 = AE_SLAA64S(out64_1, 1 + l_shift); \ - out = AE_ROUND32X2F64SASYM(out64_0, out64_1); \ -} + { \ + ae_int64 out64_0, out64_1; \ + out64_0 = AE_MUL32_HH(inp, AE_MOVDA32(multiplier)); \ + out64_1 = AE_MUL32_LL(inp, AE_MOVDA32(multiplier)); \ + out64_0 = AE_SLAA64S(out64_0, 1 + l_shift); \ + out64_1 = AE_SLAA64S(out64_1, 1 + l_shift); \ + out = AE_ROUND32X2F64SASYM(out64_0, out64_1); \ + } -#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, l_shift, r_shift) \ -{ \ - ae_int64 out64_0, out64_1, out64_2, out64_3; \ - out64_0 = AE_MUL32_HH(inp1, AE_MOVDA32(multiplier)); \ - out64_1 = AE_MUL32_LL(inp1, AE_MOVDA32(multiplier)); \ - out64_2 = AE_MUL32_HH(inp2, AE_MOVDA32(multiplier)); \ - out64_3 = AE_MUL32_LL(inp2, AE_MOVDA32(multiplier)); \ - out64_0 = AE_SLAA64S(out64_0, 1 + l_shift); \ - out64_1 = AE_SLAA64S(out64_1, 1 + l_shift); \ - out64_2 = AE_SLAA64S(out64_2, 1 + l_shift); \ - out64_3 = AE_SLAA64S(out64_3, 1 + l_shift); \ - out1 = AE_ROUND32X2F64SASYM(out64_0, out64_1); \ - out2 = AE_ROUND32X2F64SASYM(out64_2, out64_3); \ -} +#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \ + l_shift, r_shift) \ + { \ + ae_int64 out64_0, out64_1, out64_2, out64_3; \ + out64_0 = AE_MUL32_HH(inp1, AE_MOVDA32(multiplier)); \ + out64_1 = AE_MUL32_LL(inp1, AE_MOVDA32(multiplier)); \ + out64_2 = AE_MUL32_HH(inp2, AE_MOVDA32(multiplier)); \ + out64_3 = AE_MUL32_LL(inp2, AE_MOVDA32(multiplier)); \ + out64_0 = AE_SLAA64S(out64_0, 1 + l_shift); \ + out64_1 = AE_SLAA64S(out64_1, 1 + l_shift); \ + out64_2 = AE_SLAA64S(out64_2, 1 + l_shift); \ + out64_3 = AE_SLAA64S(out64_3, 1 + l_shift); \ + out1 = AE_ROUND32X2F64SASYM(out64_0, out64_1); \ + out2 = AE_ROUND32X2F64SASYM(out64_2, out64_3); \ + } #else /* #if TFLITE_SINGLE_ROUNDING */ #define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, l_shift, r_shift) \ - out = AE_SLAA32(inp, l_shift); \ - out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier)); \ - out = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out), r_shift), AE_SRAA64(AE_CVT64F32_L(out), r_shift)); - -#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, l_shift, r_shift) \ -{ \ - ae_int32x2 d_ls = AE_MOVDA32(1< 0 ? 0 : -shift; #endif /* #if TFLITE_SINGLE_ROUNDING */ @@ -970,8 +985,8 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, AE_LA16X4_IP(data_b_0, align_src_input_2, tmp_input_2); AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0); - MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, - data_ab_0, data_ab_1, multiplier, left_shift, right_shift); + MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, data_ab_0, data_ab_1, + multiplier, left_shift, right_shift); data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1); data_c_0 = AE_SUB16S(data_c_0, d_zp); AE_MINMAX16(data_c_0, d_min8, d_max8); @@ -982,7 +997,7 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, *output++ = AE_MOVAD16_0(data_c_0); } -// residue iterations + // residue iterations #pragma concurrent #pragma loop_count max = 3 for (int j = 0; j < ((num_elms)&3); j++) { @@ -990,8 +1005,8 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, AE_L16_IP(data_b_0, (ae_int16*)tmp_input_2, 2); AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0); - MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, - multiplier, left_shift, right_shift); + MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, multiplier, left_shift, + right_shift); data_c_0 = AE_SAT16X4(data_ab_0, data_ab_0); data_c_0 = AE_SUB16S(data_c_0, d_zp); AE_MINMAX16(data_c_0, d_min8, d_max8); @@ -1002,3 +1017,5 @@ void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1, #endif // defined(HIFI5) } // namespace tflite + +#endif // defined(HIFI4) || defined(HIFI5)