Skip to content

Commit

Permalink
Fix style and remove hifi code from VP6 build
Browse files Browse the repository at this point in the history
  • Loading branch information
rascani committed Sep 20, 2023
1 parent 63947fb commit 02480ed
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 181 deletions.
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/lstm_eval_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace {
// Test Settings
constexpr float kTestFloatTolerance = 1e-6f;
} // namespace
#endif // !defined(XTENSA)
#endif // !defined(XTENSA)

TF_LITE_MICRO_TESTS_BEGIN
// TODO(b/230666079) enable below tests for xtensa when the xtensa
Expand Down Expand Up @@ -454,6 +454,6 @@ TF_LITE_MICRO_TEST(TestLSTMEvalInt16) {
cell_state_tolerance,
int16_node_contents);
}
#endif // !defined(XTENSA)
#endif // !defined(XTENSA)

TF_LITE_MICRO_TESTS_END
63 changes: 29 additions & 34 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,10 @@ void FullyConnected(const FullyConnectedParams& params,
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}
#else // #if !(defined(HIFI5) || defined(HIFI4))
#else // #if !(defined(HIFI5) || defined(HIFI4))
void Sigmoid(int16_t* data, int32_t data_size) {
WORD32 err;
err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0,
data_size);
err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0, data_size);
}

void Sigmoid(float* data, int32_t data_size) {
Expand All @@ -245,42 +244,42 @@ void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
if (tanh_input_left_shift < 0) /* handling negative shift value */
{
tanh_input_left_shift = -tanh_input_left_shift;
#if (defined(USE_HIFI_ACT_TIE) && (defined(AE_TANH16X4X2) || defined(AE_TANH16X4)))
#if (defined(USE_HIFI_ACT_TIE) && \
(defined(AE_TANH16X4X2) || defined(AE_TANH16X4)))
input_multiplier = 1;
#else
input_multiplier = 3;
#endif
}
WORD32 err;
err = xa_nn_vec_tanh_sym16s_sym16s(output_data, input_data, input_multiplier,
tanh_input_left_shift, data_size);
tanh_input_left_shift, data_size);
}

void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
int32_t data_size) {
int data_dims[2] = {1, data_size};
RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
reference_ops::Tanh(data_shape, input_data, data_shape,
output_data);
reference_ops::Tanh(data_shape, input_data, data_shape, output_data);
}

// Input and output have the same shape in LSTM
void Mul(const ArithmeticParams& params, const int16_t* input1_data,
const int16_t* input2_data, int8_t* output_data, int32_t data_size) {
WORD32 err;
err = xa_nn_elm_mul_sym16sxsym16s_asym8s(output_data, params.output_offset,
params.output_shift, params.output_multiplier,
params.quantized_activation_min, params.quantized_activation_max,
input1_data, input2_data, data_size);
err = xa_nn_elm_mul_sym16sxsym16s_asym8s(
output_data, params.output_offset, params.output_shift,
params.output_multiplier, params.quantized_activation_min,
params.quantized_activation_max, input1_data, input2_data, data_size);
}

// Input and output have the same shape in LSTM
void Mul(const ArithmeticParams& params, const int16_t* input1_data,
const int16_t* input2_data, int16_t* output_data, int32_t data_size) {
int dims_4D[4] = {1, 1, 1, data_size};
WORD32 err;
err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s(output_data,
dims_4D, params.output_shift, params.output_multiplier,
err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s(
output_data, dims_4D, params.output_shift, params.output_multiplier,
params.quantized_activation_min, params.quantized_activation_max,
input1_data, dims_4D, input2_data, dims_4D);
return;
Expand All @@ -301,8 +300,8 @@ void FullyConnected(const FullyConnectedParams& params,
const int num_batches, const int output_depth,
const int accum_depth) {
WORD32 err;
#pragma loop_count min=1
for(int b = 0; b < num_batches; b++) {
#pragma loop_count min = 1
for (int b = 0; b < num_batches; b++) {
err = xa_nn_matXvec_out_stride_sym8sxasym8s_16(
output_data + b * output_depth, filter_data,
input_data + b * accum_depth, bias_data, output_depth, accum_depth,
Expand All @@ -319,18 +318,18 @@ void FullyConnected(const FullyConnectedParams& params,
const int accum_depth) {
WORD32 err;

err = xa_nn_matmul_sym8sxsym16s_sym16s(output_data, filter_data, input_data,
bias_data, output_depth, accum_depth, accum_depth, num_batches,
accum_depth, output_depth, 1, params.input_offset,
params.output_multiplier, params.output_shift, params.output_offset);
err = xa_nn_matmul_sym8sxsym16s_sym16s(
output_data, filter_data, input_data, bias_data, output_depth,
accum_depth, accum_depth, num_batches, accum_depth, output_depth, 1,
params.input_offset, params.output_multiplier, params.output_shift,
params.output_offset);
return;
}

void FullyConnected(const FullyConnectedParams& params,
const float* input_data, const float* filter_data,
const float* bias_data, float* output_data,
const int num_batches, const int output_depth,
const int accum_depth) {
void FullyConnected(const FullyConnectedParams& params, const float* input_data,
const float* filter_data, const float* bias_data,
float* output_data, const int num_batches,
const int output_depth, const int accum_depth) {
int input_dims[2] = {num_batches, output_depth};
RuntimeShape input_shape(2, reinterpret_cast<const int32_t*>(input_dims));
RuntimeShape bias_shape(1, bias_data == NULL ? 0 : output_depth);
Expand All @@ -342,7 +341,7 @@ void FullyConnected(const FullyConnectedParams& params,
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}
#endif // #if !(defined(HIFI5) || defined(HIFI4))
#endif // #if !(defined(HIFI5) || defined(HIFI4))

void Clipping(const int v_size, const CellStateInfo& cell_state_info,
int16_t* vector) {
Expand Down Expand Up @@ -372,18 +371,16 @@ void UpdateLstmCell(const LstmStepManager& step_info,
const ArithmeticParams& forget_cell_mul_params,
const ArithmeticParams& input_mul_params,
const CellStateInfo& cell_state_info, int16_t* buffer) {

auto cell_state_shape = step_info.StateShape();
// Check offset validity to avoid memory overflow
TFLITE_DCHECK_LE(
step_info.CellStateOffset() + cell_state_shape.FlatSize(),
tflite::micro::GetTensorShape(cell_state).FlatSize());
TFLITE_DCHECK_LE(step_info.CellStateOffset() + cell_state_shape.FlatSize(),
tflite::micro::GetTensorShape(cell_state).FlatSize());

WORD32 err;
// Multiplier is equivalent to 0.5 here so adding 1 to shifts
err = xa_nn_lstm_cell_state_update_16(
tflite::micro::GetTensorData<int16_t>(cell_state) +
step_info.CellStateOffset(),
step_info.CellStateOffset(),
forget_gate_output, cell_gate_output, input_gate_output,
forget_cell_mul_params.output_shift - 1,
input_mul_params.output_shift - 1, cell_state_info.quantized_cell_clip,
Expand All @@ -393,8 +390,7 @@ void UpdateLstmCell(const LstmStepManager& step_info,
void UpdateLstmCell(const LstmStepManager& step_info,
TfLiteEvalTensor* cell_state,
// Gate outputs
float* forget_gate_output,
const float* input_gate_output,
float* forget_gate_output, const float* input_gate_output,
const float* cell_gate_output,
// Mul parameters
const ArithmeticParams& forget_cell_mul_params,
Expand Down Expand Up @@ -432,7 +428,7 @@ void UpdateLstmCell(const LstmStepManager& step_info,
step_info.CellStateOffset());
}
}
#endif // #if defined(HIFI5) || defined(HIFI4)
#endif // #if defined(HIFI5) || defined(HIFI4)

// Increment the data offset so the single time step invocation call can access
// the corresponding input/output tensor data at the time step
Expand Down Expand Up @@ -491,6 +487,5 @@ RuntimeShape LstmStepManager::StateShape() const {
return RuntimeShape(2, dims_data);
}


} // namespace lstm_internal
} // namespace tflite
70 changes: 32 additions & 38 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/lstm_shared.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {

Expand Down Expand Up @@ -200,16 +200,16 @@ void FullyConnected(const FullyConnectedParams& params,
const RuntimeShape& filter_shape, const float* filter_data,
const RuntimeShape& bias_shape, const float* bias_data,
const RuntimeShape& output_shape, float* output_data);
#else // #if !(defined(HIFI5) || defined(HIFI4))
#else // #if !(defined(HIFI5) || defined(HIFI4))
void Sigmoid(int16_t* data, int32_t data_size);

void Sigmoid(float* data, int32_t data_size);

void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
int16_t* output_data, int32_t data_size);

void Tanh(int32_t cell_state_scale_power, float* input_data,
float* output_data, int32_t data_size);
void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
int32_t data_size);

void Mul(const ArithmeticParams& params, const int16_t* input1_data,
const int16_t* input2_data, int8_t* output_data, int32_t data_size);
Expand All @@ -232,12 +232,11 @@ void FullyConnected(const FullyConnectedParams& params,
const int num_batches, const int output_depth,
const int accum_depth);

void FullyConnected(const FullyConnectedParams& params,
const float* input_data, const float* filter_data,
const float* bias_data, float* output_data,
const int num_batches, const int output_depth,
const int accum_depth);
#endif // #if !(defined(HIFI5) || defined(HIFI4))
void FullyConnected(const FullyConnectedParams& params, const float* input_data,
const float* filter_data, const float* bias_data,
float* output_data, const int num_batches,
const int output_depth, const int accum_depth);
#endif // #if !(defined(HIFI5) || defined(HIFI4))

void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
int n_input, int16_t* output);
Expand Down Expand Up @@ -407,7 +406,7 @@ void UpdateLstmCell(const LstmStepManager& step_info,
step_info.CellStateOffset());
}
}
#else // #if !defined(HIFI5) || defined(HIFI4)
#else // #if !defined(HIFI5) || defined(HIFI4)
template <typename ActivationType, typename WeightType, typename CellType,
typename BiasType>
void CalculateLstmGate(
Expand All @@ -424,7 +423,6 @@ void CalculateLstmGate(
CellType* fc_output_buffer, const TfLiteFusedActivation activation,
const int num_batches, const int input_dimension,
const int state_dimension) {

// RuntimeShape step_input_shape = step_info.InputShape();
// RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
// RuntimeShape step_state_shape = step_info.StateShape();
Expand All @@ -433,10 +431,10 @@ void CalculateLstmGate(
// Moved these to LstmStep function
// Check offset validity to avoid memory overflow
// TFLITE_DCHECK_LE(step_info.InputOffset() + step_input_shape.FlatSize(),
// input_shape.FlatSize());
// input_shape.FlatSize());
// TFLITE_DCHECK_LE(
// step_info.HiddenStateOffset() + step_state_shape.FlatSize(),
// recurrent_shape.FlatSize());
// step_info.HiddenStateOffset() + step_state_shape.FlatSize(),
// recurrent_shape.FlatSize());

// Input FC
FullyConnected(gate_params.input_fc_params,
Expand Down Expand Up @@ -492,14 +490,13 @@ void UpdateLstmCell(const LstmStepManager& step_info,
void UpdateLstmCell(const LstmStepManager& step_info,
TfLiteEvalTensor* cell_state,
// Gate outputs
float* forget_gate_output,
const float* input_gate_output,
float* forget_gate_output, const float* input_gate_output,
const float* cell_gate_output,
// Mul parameters
const ArithmeticParams& forget_cell_mul_params,
const ArithmeticParams& input_mul_params,
const CellStateInfo& cell_state_info, float* buffer);
#endif // #if defined(HIFI5) || defined(HIFI4)
#endif // #if defined(HIFI5) || defined(HIFI4)

// Update the hidden state of the LSTM kernel using the following formula:
// updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means
Expand Down Expand Up @@ -533,8 +530,7 @@ void UpdateLstmHidden(const LstmStepManager& step_info,
step_info.HiddenStateOffset());
#else
int32_t cell_state_size = cell_state_shape.FlatSize();
Tanh(cell_state_scale_power, cell_state_data, buffer,
cell_state_size);
Tanh(cell_state_scale_power, cell_state_data, buffer, cell_state_size);
// Update the hidden state
Mul(mul_params, buffer, output_gate_output,
tflite::micro::GetTensorData<ActivationType>(hidden_state) +
Expand Down Expand Up @@ -655,13 +651,12 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
step_info.HiddenStateOffset(),
step_info.StateShape().FlatSize() * sizeof(ActivationType));
}
#else // #if !(defined(HIFI5) || defined(HIFI4))
#else // #if !(defined(HIFI5) || defined(HIFI4))
template <typename ActivationType, typename WeightType, typename CellType,
typename BiasType>
void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
LSTMKernelContents& kernel_content,
const LSTMBuffers<CellType>& buffers) {

const TfLiteEvalTensor* input =
kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();
Expand All @@ -684,11 +679,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data.forget_gate_parameters,
// Input FC
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor),
kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor),
// Recurrent FC
recurrent, // kernel_content.HiddenStateTensor(),
recurrent, // kernel_content.HiddenStateTensor(),
kernel_content.GetInternalTensor(
tflite::kLstmRecurrentToForgetWeightsTensor),
/*recurrent_bias*/ nullptr,
Expand All @@ -703,11 +698,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data.input_gate_parameters,
// Input FC
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor),
kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor),
// Recurrent FC
recurrent, // kernel_content.HiddenStateTensor(),
recurrent, // kernel_content.HiddenStateTensor(),
kernel_content.GetInternalTensor(
tflite::kLstmRecurrentToInputWeightsTensor),
/*recurrent_bias*/ nullptr,
Expand All @@ -722,11 +717,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data.cell_gate_parameters,
// Input FC
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor),
kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor),
// Recurrent FC
recurrent, // kernel_content.HiddenStateTensor(),
recurrent, // kernel_content.HiddenStateTensor(),
kernel_content.GetInternalTensor(
tflite::kLstmRecurrentToCellWeightsTensor),
/*recurrent_bias*/ nullptr,
Expand All @@ -741,22 +736,21 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
CellType* updated_input_buffer = buffers.buffer1; // reuse buffer

UpdateLstmCell(step_info, kernel_content.CellStateTensor(),
forget_gate_output, input_gate_output,
cell_gate_output,
inter_gate_params.forget_cell_mul_params,
inter_gate_params.input_mul_params,
op_data.cell_state_info, updated_input_buffer);
forget_gate_output, input_gate_output, cell_gate_output,
inter_gate_params.forget_cell_mul_params,
inter_gate_params.input_mul_params, op_data.cell_state_info,
updated_input_buffer);

/*Step3: update the hidden state */
CellType* output_gate_output = buffers.buffer1; // reuse buffer
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data.output_gate_parameters,
// Input FC
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor),
kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor),
// Recurrent FC
recurrent, // kernel_content.HiddenStateTensor(),
recurrent, // kernel_content.HiddenStateTensor(),
kernel_content.GetInternalTensor(
tflite::kLstmRecurrentToOutputWeightsTensor),
/*recurrent_bias*/ nullptr,
Expand All @@ -768,8 +762,8 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,

CellType* tanh_activated_cell_buffer = buffers.buffer0; // reuse buffer
tflite::lstm_internal::UpdateLstmHidden<CellType, ActivationType>(
step_info, kernel_content.CellStateTensor(),
recurrent, /* kernel_content.HiddenStateTensor(), */ output_gate_output,
step_info, kernel_content.CellStateTensor(), recurrent,
/* kernel_content.HiddenStateTensor(), */ output_gate_output,
inter_gate_params.output_mul_params,
op_data.cell_state_info.cell_state_scale_power,
tanh_activated_cell_buffer);
Expand All @@ -788,7 +782,7 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
step_info.HiddenStateOffset(),
step_info.StateShape().FlatSize() * sizeof(ActivationType));
}
#endif // #if !(defined(HIFI5) || defined(HIFI4))
#endif // #if !(defined(HIFI5) || defined(HIFI4))

} // namespace lstm_internal

Expand Down
Loading

0 comments on commit 02480ed

Please sign in to comment.