diff --git a/src/booster/arm/generic_kernels.cpp b/src/booster/arm/generic_kernels.cpp index 150a52b..06c4b87 100644 --- a/src/booster/arm/generic_kernels.cpp +++ b/src/booster/arm/generic_kernels.cpp @@ -173,7 +173,7 @@ template void add_relu(float* dst, const float* A, const float* B, const void vsub(float* dst, float* A, float* B, size_t len, size_t num_threads) { #pragma omp parallel for num_threads(num_threads) schedule(static) - for (int i = 0; i < len - 4; ++i) + for (int i = 0; i < len; i += 4) { float32x4_t vA = vld1q_f32(A + i); float32x4_t vB = vld1q_f32(B + i); @@ -188,7 +188,7 @@ void vsub(float* dst, float* A, float* B, size_t len, size_t num_threads) void vmul(float* dst, float* A, float* B, size_t len, size_t num_threads) { #pragma omp parallel for num_threads(num_threads) schedule(static) - for (int i = 0; i < len - 4; ++i) + for (int i = 0; i < len; i += 4) { float32x4_t vA = vld1q_f32(A + i); float32x4_t vB = vld1q_f32(B + i); @@ -459,4 +459,4 @@ void reluVecOpenmp(float* arr, int len, int nThreads) for (int i = aLen; i < len; i++) if (arr[i] < 0) arr[i] = 0; } -}; \ No newline at end of file +}; diff --git a/src/booster/avx/generic_kernels.cpp b/src/booster/avx/generic_kernels.cpp index 978d77b..cd743d4 100644 --- a/src/booster/avx/generic_kernels.cpp +++ b/src/booster/avx/generic_kernels.cpp @@ -173,7 +173,7 @@ template void add_relu(float* dst, const float* A, const float* B, const void vsub(float* dst, float* A, float* B, size_t len, size_t num_threads) { #pragma omp parallel for num_threads(num_threads) schedule(static) - for (int i = 0; i < len - 4; ++i) + for (int i = 0; i < len; i += 4) { __m128 vA = _mm_load_ps(A + i); __m128 vB = _mm_load_ps(B + i); @@ -188,7 +188,7 @@ void vsub(float* dst, float* A, float* B, size_t len, size_t num_threads) void vmul(float* dst, float* A, float* B, size_t len, size_t num_threads) { #pragma omp parallel for num_threads(num_threads) schedule(static) - for (int i = 0; i < len - 4; ++i) + for (int i = 0; i < len; i += 4) { __m128 vA = _mm_load_ps(A + i); __m128 vB = _mm_load_ps(B + i);