Skip to content

Commit

Permalink
kernel: Refactor 32f_x2_subtract_32f kernel
Browse files Browse the repository at this point in the history
This kernel had a lot of superfluous lines that are removed now. This
should make it easier to understand the code. Also, the generic kernel
moved to the top. It serves as a reference for anyone looking into the
source code.

Signed-off-by: Johannes Demel <[email protected]>
  • Loading branch information
jdemel committed Nov 5, 2023
1 parent d5b317c commit ddef4af
Showing 1 changed file with 80 additions and 133 deletions.
213 changes: 80 additions & 133 deletions kernels/volk/volk_32f_x2_subtract_32f.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@
#include <inttypes.h>
#include <stdio.h>


#ifdef LV_HAVE_GENERIC

static inline void volk_32f_x2_subtract_32f_generic(float* cVector,
const float* aVector,
const float* bVector,
unsigned int num_points)
{
for (unsigned int number = 0; number < num_points; number++) {
*cVector++ = (*aVector++) - (*bVector++);
}
}
#endif /* LV_HAVE_GENERIC */


#ifdef LV_HAVE_AVX512F
#include <immintrin.h>

Expand All @@ -69,32 +84,23 @@ static inline void volk_32f_x2_subtract_32f_a_avx512f(float* cVector,
const float* bVector,
unsigned int num_points)
{
unsigned int number = 0;
const unsigned int sixteenthPoints = num_points / 16;

float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;

__m512 aVal, bVal, cVal;
for (; number < sixteenthPoints; number++) {

aVal = _mm512_load_ps(aPtr);
bVal = _mm512_load_ps(bPtr);
for (unsigned int number = 0; number < sixteenthPoints; number++) {
__m512 aVal = _mm512_load_ps(aVector);
__m512 bVal = _mm512_load_ps(bVector);

cVal = _mm512_sub_ps(aVal, bVal);
__m512 cVal = _mm512_sub_ps(aVal, bVal);

_mm512_store_ps(cPtr, cVal); // Store the results back into the C container
_mm512_store_ps(cVector, cVal); // Store the results back into the C container

aPtr += 16;
bPtr += 16;
cPtr += 16;
aVector += 16;
bVector += 16;
cVector += 16;
}

number = sixteenthPoints * 16;
for (; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
}
volk_32f_x2_subtract_32f_generic(
cVector, aVector, bVector, num_points - sixteenthPoints * 16);
}
#endif /* LV_HAVE_AVX512F */

Expand All @@ -106,32 +112,23 @@ static inline void volk_32f_x2_subtract_32f_a_avx(float* cVector,
const float* bVector,
unsigned int num_points)
{
unsigned int number = 0;
const unsigned int eighthPoints = num_points / 8;

float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;

__m256 aVal, bVal, cVal;
for (; number < eighthPoints; number++) {

aVal = _mm256_load_ps(aPtr);
bVal = _mm256_load_ps(bPtr);
for (unsigned int number = 0; number < eighthPoints; number++) {
__m256 aVal = _mm256_load_ps(aVector);
__m256 bVal = _mm256_load_ps(bVector);

cVal = _mm256_sub_ps(aVal, bVal);
__m256 cVal = _mm256_sub_ps(aVal, bVal);

_mm256_store_ps(cPtr, cVal); // Store the results back into the C container
_mm256_store_ps(cVector, cVal); // Store the results back into the C container

aPtr += 8;
bPtr += 8;
cPtr += 8;
aVector += 8;
bVector += 8;
cVector += 8;
}

number = eighthPoints * 8;
for (; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
}
volk_32f_x2_subtract_32f_generic(
cVector, aVector, bVector, num_points - eighthPoints * 8);
}
#endif /* LV_HAVE_AVX */

Expand All @@ -143,55 +140,27 @@ static inline void volk_32f_x2_subtract_32f_a_sse(float* cVector,
const float* bVector,
unsigned int num_points)
{
unsigned int number = 0;
const unsigned int quarterPoints = num_points / 4;

float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;

__m128 aVal, bVal, cVal;
for (; number < quarterPoints; number++) {

aVal = _mm_load_ps(aPtr);
bVal = _mm_load_ps(bPtr);
for (unsigned int number = 0; number < quarterPoints; number++) {
__m128 aVal = _mm_load_ps(aVector);
__m128 bVal = _mm_load_ps(bVector);

cVal = _mm_sub_ps(aVal, bVal);
__m128 cVal = _mm_sub_ps(aVal, bVal);

_mm_store_ps(cPtr, cVal); // Store the results back into the C container
_mm_store_ps(cVector, cVal); // Store the results back into the C container

aPtr += 4;
bPtr += 4;
cPtr += 4;
aVector += 4;
bVector += 4;
cVector += 4;
}

number = quarterPoints * 4;
for (; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
}
volk_32f_x2_subtract_32f_generic(
cVector, aVector, bVector, num_points - quarterPoints * 4);
}
#endif /* LV_HAVE_SSE */


#ifdef LV_HAVE_GENERIC

static inline void volk_32f_x2_subtract_32f_generic(float* cVector,
const float* aVector,
const float* bVector,
unsigned int num_points)
{
float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;
unsigned int number = 0;

for (number = 0; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
}
}
#endif /* LV_HAVE_GENERIC */


#ifdef LV_HAVE_NEON
#include <arm_neon.h>

Expand All @@ -200,27 +169,23 @@ static inline void volk_32f_x2_subtract_32f_neon(float* cVector,
const float* bVector,
unsigned int num_points)
{
float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;
unsigned int number = 0;
unsigned int quarter_points = num_points / 4;

float32x4_t a_vec, b_vec, c_vec;

for (number = 0; number < quarter_points; number++) {
a_vec = vld1q_f32(aPtr);
b_vec = vld1q_f32(bPtr);
c_vec = vsubq_f32(a_vec, b_vec);
vst1q_f32(cPtr, c_vec);
aPtr += 4;
bPtr += 4;
cPtr += 4;
}
const unsigned int quarterPoints = num_points / 4;

for (unsigned int number = 0; number < quarterPoints; number++) {
float32x4_t a_vec = vld1q_f32(aVector);
float32x4_t b_vec = vld1q_f32(bVector);

float32x4_t c_vec = vsubq_f32(a_vec, b_vec);

for (number = quarter_points * 4; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
vst1q_f32(cVector, c_vec);

aVector += 4;
bVector += 4;
cVector += 4;
}

volk_32f_x2_subtract_32f_generic(
cVector, aVector, bVector, num_points - quarterPoints * 4);
}
#endif /* LV_HAVE_NEON */

Expand Down Expand Up @@ -258,32 +223,23 @@ static inline void volk_32f_x2_subtract_32f_u_avx512f(float* cVector,
const float* bVector,
unsigned int num_points)
{
unsigned int number = 0;
const unsigned int sixteenthPoints = num_points / 16;

float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;

__m512 aVal, bVal, cVal;
for (; number < sixteenthPoints; number++) {
for (unsigned int number = 0; number < sixteenthPoints; number++) {
__m512 aVal = _mm512_loadu_ps(aVector);
__m512 bVal = _mm512_loadu_ps(bVector);

aVal = _mm512_loadu_ps(aPtr);
bVal = _mm512_loadu_ps(bPtr);
__m512 cVal = _mm512_sub_ps(aVal, bVal);

cVal = _mm512_sub_ps(aVal, bVal);
_mm512_storeu_ps(cVector, cVal); // Store the results back into the C container

_mm512_storeu_ps(cPtr, cVal); // Store the results back into the C container

aPtr += 16;
bPtr += 16;
cPtr += 16;
aVector += 16;
bVector += 16;
cVector += 16;
}

number = sixteenthPoints * 16;
for (; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
}
volk_32f_x2_subtract_32f_generic(
cVector, aVector, bVector, num_points - sixteenthPoints * 16);
}
#endif /* LV_HAVE_AVX512F */

Expand All @@ -296,32 +252,23 @@ static inline void volk_32f_x2_subtract_32f_u_avx(float* cVector,
const float* bVector,
unsigned int num_points)
{
unsigned int number = 0;
const unsigned int eighthPoints = num_points / 8;

float* cPtr = cVector;
const float* aPtr = aVector;
const float* bPtr = bVector;

__m256 aVal, bVal, cVal;
for (; number < eighthPoints; number++) {
for (unsigned int number = 0; number < eighthPoints; number++) {
__m256 aVal = _mm256_loadu_ps(aVector);
__m256 bVal = _mm256_loadu_ps(bVector);

aVal = _mm256_loadu_ps(aPtr);
bVal = _mm256_loadu_ps(bPtr);
__m256 cVal = _mm256_sub_ps(aVal, bVal);

cVal = _mm256_sub_ps(aVal, bVal);
_mm256_storeu_ps(cVector, cVal); // Store the results back into the C container

_mm256_storeu_ps(cPtr, cVal); // Store the results back into the C container

aPtr += 8;
bPtr += 8;
cPtr += 8;
aVector += 8;
bVector += 8;
cVector += 8;
}

number = eighthPoints * 8;
for (; number < num_points; number++) {
*cPtr++ = (*aPtr++) - (*bPtr++);
}
volk_32f_x2_subtract_32f_generic(
cVector, aVector, bVector, num_points - eighthPoints * 8);
}
#endif /* LV_HAVE_AVX */

Expand Down

0 comments on commit ddef4af

Please sign in to comment.