Skip to content

Commit

Permalink
Merge pull request #78 from tum-ei-eda/fix-lstm-s8
Browse files Browse the repository at this point in the history
Fix 8-bit LSTM kernels
  • Loading branch information
PhilippvK authored Oct 28, 2024
2 parents 68e34ac + 6c1a819 commit 746d8fa
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 47 deletions.
7 changes: 5 additions & 2 deletions Source/ActivationFunctions/muriscv_nn_activation_s16.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,16 @@ muriscv_nn_status muriscv_nn_activation_s16(const int16_t *input,
break;
}

// int32_t input_multiplier = ((int32_t)3) << left_shift;
const int32_t input_multiplier = (left_shift < 0) ? 3 : 3 << left_shift;
const int32_t abs_left_shift = (left_shift < 0) ? -left_shift : 0;
const int32_t rounding = (abs_left_shift > 0) ? 1 << (abs_left_shift - 1) : 0;
// Use the LUT for sigmoid and take into account, that
// tanh(x) = 2*sigmoid(2*x) - 1
int32_t input_multiplier = ((int32_t)3) << left_shift;

for (int i = 0; i < size; ++i, input++, output++)
{
int32_t input_data = ((*input) * input_multiplier);
const int32_t input_data = ((*input) * input_multiplier + rounding) >> abs_left_shift;

uint32_t abs_input_data = input_data > 0 ? input_data : -input_data;

Expand Down
13 changes: 8 additions & 5 deletions Source/BasicMathFunctions/muriscv_nn_elementwise_mul_s16_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ muriscv_nn_status muriscv_nn_elementwise_mul_s16_s8(const int16_t *input_1_vect,
const int32_t batch_offset)

{
int32_t loop_count = block_size;
for (int i = 0; i < batch_size; i++)
{
int32_t loop_count = block_size;

/*#if defined(ARM_MATH_MVEI)
Expand Down Expand Up @@ -105,19 +107,20 @@ muriscv_nn_status muriscv_nn_elementwise_mul_s16_s8(const int16_t *input_1_vect,
loop_count -= 2;
}
#endif*/
for (int i = 0; i < loop_count; i++)
for (int j = 0; j < loop_count; j++, input_1_vect++, input_2_vect++, output++)
{
/* C = A * B */
int32_t mul_res = input_1_vect[i] * input_2_vect[i];
int32_t mul_res = (*input_1_vect) * (*input_2_vect);
mul_res = muriscv_nn_requantize(mul_res, out_mult, out_shift) + out_offset;

mul_res = CLAMP(mul_res, Q7_MAX, Q7_MIN);

output[i] = (int8_t)mul_res;
*output = (int8_t)mul_res;
}

//#endif

output += (batch_offset - 1) * block_size;
}
return MURISCV_NN_SUCCESS;
}
/**
Expand Down
40 changes: 28 additions & 12 deletions Source/FullyConnectedFunctions/muriscv_nn_vector_sum_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,18 @@ muriscv_nn_status muriscv_nn_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t lhs_offset,
const int32_t *bias_data)
{
#if defined(USE_VEXT)
if (bias_data)
{
memcpy(vector_sum_buf, bias_data, vector_rows * sizeof(int32_t));

Check warning on line 61 in Source/FullyConnectedFunctions/muriscv_nn_vector_sum_s8.c

View workflow job for this annotation

GitHub Actions / Formatting Check (C/C++)

Source/FullyConnectedFunctions/muriscv_nn_vector_sum_s8.c:61:9 [clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling]

Call to function 'memcpy' is insecure as it does not provide security checks introduced in the C11 standard. Replace with analogous functions that support length arguments or provides boundary checks such as 'memcpy_s' in case of C11
}
else
{
memset(vector_sum_buf, 0, vector_rows * sizeof(int32_t));

Check warning on line 65 in Source/FullyConnectedFunctions/muriscv_nn_vector_sum_s8.c

View workflow job for this annotation

GitHub Actions / Formatting Check (C/C++)

Source/FullyConnectedFunctions/muriscv_nn_vector_sum_s8.c:65:9 [clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling]

Call to function 'memset' is insecure as it does not provide security checks introduced in the C11 standard. Replace with analogous functions that support length arguments or provides boundary checks such as 'memset_s' in case of C11
}

if (lhs_offset)
{
// #if defined(USE_VEXT)
//ARM CODE. NEEDS TO BE CONVERTED TO RISCV
/*
const int32_t row_loop_cnt = vector_rows / 4;
Expand Down Expand Up @@ -135,18 +146,23 @@ muriscv_nn_status muriscv_nn_vector_sum_s8(int32_t *vector_sum_buf,
vector_sum_buf[i_row_loop_cnt] = vector_sum_0;
}
*/
return (MURISCV_NN_SUCCESS);


//return (MURISCV_NN_NO_IMPL_ERROR);

#else
(void)vector_sum_buf;
(void)vector_rows;
(void)vector_cols;
(void)vector_data;

return (MURISCV_NN_NO_IMPL_ERROR);
#endif

// #else
for (int i = 0; i < vector_rows; i++)
{
int32_t sum = 0;
for (int j = 0; j < vector_cols; j++)
{
sum += *vector_data++;
}
*vector_sum_buf++ += sum * lhs_offset;
}

// #endif
}
return (MURISCV_NN_SUCCESS);
}

/**
Expand Down
28 changes: 0 additions & 28 deletions Toolchain/download_lite.sh

This file was deleted.

0 comments on commit 746d8fa

Please sign in to comment.