diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index db895ea74..f8a49f866 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -1174,7 +1174,12 @@ inline __m256 exp_ps_0_1(const __m256 x) { static const auto log2e = _mm256_set1_ps(v_log2e); static const auto half = _mm256_set1_ps(.5f); - const auto x1 = _mm256_fmadd_ps(x, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f); + static const auto upper_bound = _mm256_set1_ps(88.722838); // log(max_positive_float) + static const auto lower_bound = _mm256_set1_ps(-87.336549); // log(min_positive_float) + __m256 x1 = _mm256_min_ps(x, upper_bound); + x1 = _mm256_max_ps(x1, lower_bound); + + x1 = _mm256_fmadd_ps(x1, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f); const auto z = _mm256_floor_ps(x1); const auto f = _mm256_sub_ps(x1, z); // auto f = x1 - z; diff --git a/neural_speed/core/layers/mha_dense.cpp b/neural_speed/core/layers/mha_dense.cpp index 71f20076b..b0ab5118b 100644 --- a/neural_speed/core/layers/mha_dense.cpp +++ b/neural_speed/core/layers/mha_dense.cpp @@ -74,7 +74,7 @@ bool bestla_reordered_attn_fp32_support(const attn_shape_t* params) { #endif // use avx2 and f16c on avx2 platforms // todo: check avx2 mha on sever - return false; + return !_cd->AVX512F() && _cd->AVX2(); } // kv cache sizes in bytes per layer per batch per beam for; void bestla_reordered_attn_fp32_batch_kv_info(const kv_shape_t* params, kv_cache_info_t* out) {