diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h
index 0a104c14..c7381137 100644
--- a/src/nnue/nnue_feature_transformer.h
+++ b/src/nnue/nnue_feature_transformer.h
@@ -48,51 +48,21 @@ using PSQTWeightType = std::int32_t;
 static_assert(PSQTBuckets % 8 == 0,
               "Per feature PSQT values cannot be processed at granularity lower than 8 at a time.");
 
-#ifdef USE_AVX512F
+#ifdef USE_AVX512
 using vec_t      = __m512i;
 using psqt_vec_t = __m256i;
     #define vec_load(a) _mm512_load_si512(a)
     #define vec_store(a, b) _mm512_store_si512(a, b)
-    #define vec_add_16(a, b) \
-        __builtin_shufflevector(_mm256_add_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
-                                                 __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
-                                _mm256_add_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
-                                                 __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
-                                0, 1, 2, 3, 4, 5, 6, 7)
-    #define vec_sub_16(a, b) \
-        __builtin_shufflevector(_mm256_sub_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
-                                                 __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
-                                _mm256_sub_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
-                                                 __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
-                                0, 1, 2, 3, 4, 5, 6, 7)
-    #define vec_mul_16(a, b) \
-        __builtin_shufflevector(_mm256_mullo_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
-                                                   __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
-                                _mm256_mullo_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
-                                                   __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
-                                0, 1, 2, 3, 4, 5, 6, 7)
+    #define vec_add_16(a, b) _mm512_add_epi16(a, b)
+    #define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
+    #define vec_mul_16(a, b) _mm512_mullo_epi16(a, b)
     #define vec_zero() _mm512_setzero_epi32()
     #define vec_set_16(a) _mm512_set1_epi16(a)
-    #define vec_max_16(a, b) \
-        __builtin_shufflevector(_mm256_max_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
-                                                 __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
-                                _mm256_max_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
-                                                 __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
-                                0, 1, 2, 3, 4, 5, 6, 7)
-    #define vec_min_16(a, b) \
-        __builtin_shufflevector(_mm256_min_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
-                                                 __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
-                                _mm256_min_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
-                                                 __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
-                                0, 1, 2, 3, 4, 5, 6, 7)
+    #define vec_max_16(a, b) _mm512_max_epi16(a, b)
+    #define vec_min_16(a, b) _mm512_min_epi16(a, b)
     // Inverse permuted at load time
     #define vec_msb_pack_16(a, b) \
-        __builtin_shufflevector( \
-            _mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), 7), \
-                               _mm256_srli_epi16(__builtin_shufflevector(b, b, 0, 1, 2, 3), 7)), \
-            _mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), 7), \
-                               _mm256_srli_epi16(__builtin_shufflevector(b, b, 4, 5, 6, 7), 7)), \
-            0, 1, 2, 3, 4, 5, 6, 7)
+        _mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7))
     #define vec_load_psqt(a) _mm256_load_si256(a)
     #define vec_store_psqt(a, b) _mm256_store_si256(a, b)
     #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
@@ -101,21 +71,32 @@ using psqt_vec_t = __m256i;
     #define NumRegistersSIMD 16
     #define MaxChunkSize 64
 
-#elif USE_AVX512
+#elif USE_AVX512F
 using vec_t      = __m512i;
 using psqt_vec_t = __m256i;
+    #define vec_op(op, a, b) \
+        __builtin_shufflevector(op(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
+                                   __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
+                                op(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
+                                   __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
+                                0, 1, 2, 3, 4, 5, 6, 7)
     #define vec_load(a) _mm512_load_si512(a)
    #define vec_store(a, b) _mm512_store_si512(a, b)
-    #define vec_add_16(a, b) _mm512_add_epi16(a, b)
-    #define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
-    #define vec_mul_16(a, b) _mm512_mullo_epi16(a, b)
+    #define vec_add_16(a, b) vec_op(_mm256_add_epi16, a, b)
+    #define vec_sub_16(a, b) vec_op(_mm256_sub_epi16, a, b)
+    #define vec_mul_16(a, b) vec_op(_mm256_mullo_epi16, a, b)
     #define vec_zero() _mm512_setzero_epi32()
     #define vec_set_16(a) _mm512_set1_epi16(a)
-    #define vec_max_16(a, b) _mm512_max_epi16(a, b)
-    #define vec_min_16(a, b) _mm512_min_epi16(a, b)
+    #define vec_max_16(a, b) vec_op(_mm256_max_epi16, a, b)
+    #define vec_min_16(a, b) vec_op(_mm256_min_epi16, a, b)
     // Inverse permuted at load time
     #define vec_msb_pack_16(a, b) \
-        _mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7))
+        __builtin_shufflevector( \
+            _mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), 7), \
+                               _mm256_srli_epi16(__builtin_shufflevector(b, b, 0, 1, 2, 3), 7)), \
+            _mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), 7), \
+                               _mm256_srli_epi16(__builtin_shufflevector(b, b, 4, 5, 6, 7), 7)), \
+            0, 1, 2, 3, 4, 5, 6, 7)
     #define vec_load_psqt(a) _mm256_load_si256(a)
     #define vec_store_psqt(a, b) _mm256_store_si256(a, b)
     #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
@@ -278,43 +259,37 @@ class FeatureTransformer {
         return FeatureSet::HashValue ^ (OutputDimensions * 2);
     }
 
-    static constexpr void order_packs([[maybe_unused]] uint64_t* v) {
-#if defined(USE_AVX512)  // _mm512_packs_epi16 ordering
-        uint64_t tmp0, tmp1;
-        tmp0 = v[2], tmp1 = v[3];
-        v[2] = v[8], v[3] = v[9];
-        v[8] = v[4], v[9] = v[5];
-        v[4] = tmp0, v[5] = tmp1;
-        tmp0 = v[6], tmp1 = v[7];
-        v[6] = v[10], v[7] = v[11];
-        v[10] = v[12], v[11] = v[13];
-        v[12] = tmp0, v[13] = tmp1;
-#elif defined(USE_AVX2)  // _mm256_packs_epi16 ordering
-        std::swap(v[2], v[4]);
-        std::swap(v[3], v[5]);
+    static void order_packs([[maybe_unused]] uint64_t* v) {
+#if defined(USE_AVX2)
+        vec_t* vec  = reinterpret_cast<vec_t*>(v);
+        vec_t  vec0 = vec[0], vec1 = vec[1];
+    #if defined(USE_AVX512) || defined(USE_AVX512F)  // _mm512_packs_epi16 ordering
+        vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 8, 9, 2, 3, 10, 11);
+        vec[1] = __builtin_shufflevector(vec0, vec1, 4, 5, 12, 13, 6, 7, 14, 15);
+    #else  // _mm256_packs_epi16 ordering
+        vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5);
+        vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7);
+    #endif
 #endif
     }
 
-    static constexpr void inverse_order_packs([[maybe_unused]] uint64_t* v) {
-#if defined(USE_AVX512)  // Inverse _mm512_packs_epi16 ordering
-        uint64_t tmp0, tmp1;
-        tmp0 = v[2], tmp1 = v[3];
-        v[2] = v[4], v[3] = v[5];
-        v[4] = v[8], v[5] = v[9];
-        v[8] = tmp0, v[9] = tmp1;
-        tmp0 = v[6], tmp1 = v[7];
-        v[6] = v[12], v[7] = v[13];
-        v[12] = v[10], v[13] = v[11];
-        v[10] = tmp0, v[11] = tmp1;
-#elif defined(USE_AVX2)  // Inverse _mm256_packs_epi16 ordering
-        std::swap(v[2], v[4]);
-        std::swap(v[3], v[5]);
+    static void inverse_order_packs([[maybe_unused]] uint64_t* v) {
+#if defined(USE_AVX2)
+        vec_t* vec  = reinterpret_cast<vec_t*>(v);
+        vec_t  vec0 = vec[0], vec1 = vec[1];
+    #if defined(USE_AVX512) || defined(USE_AVX512F)  // Inverse _mm512_packs_epi16 ordering
+        vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5, 8, 9, 12, 13);
+        vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7, 10, 11, 14, 15);
+    #else  // Inverse _mm256_packs_epi16 ordering
+        vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5);
+        vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7);
+    #endif
 #endif
     }
 
     void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) const {
 #if defined(USE_AVX2)
-    #if defined(USE_AVX512)
+    #if defined(USE_AVX512) || defined(USE_AVX512F)
         constexpr IndexType di = 16;
     #else
         constexpr IndexType di = 8;
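
Note (illustrative sketch, not part of the patch): the vec_op macro added in the
USE_AVX512F branch emulates 16-bit operations that belong to AVX512BW, which an
AVX512F-only CPU lacks. It splits each 512-bit register into its two 256-bit
halves (the shufflevector indices address 64-bit lanes), applies the AVX2
instruction to each half, and concatenates the results; a good compiler should
lower the splits/concatenation to simple vextracti64x4/vinserti64x4 moves. A
minimal standalone check of that identity, assuming clang or GCC 12+ for
__builtin_shufflevector and compiling with -mavx512f (the test harness below is
hypothetical; vec_op is copied verbatim from the patch):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Split each 512-bit input into 256-bit halves (indices pick 64-bit lanes),
// apply the 256-bit op to each half, then concatenate the two results.
#define vec_op(op, a, b) \
    __builtin_shufflevector(op(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
                               __builtin_shufflevector(b, b, 0, 1, 2, 3)), \
                            op(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
                               __builtin_shufflevector(b, b, 4, 5, 6, 7)), \
                            0, 1, 2, 3, 4, 5, 6, 7)

int main() {
    alignas(64) std::int16_t x[32], y[32], z[32];
    for (int i = 0; i < 32; ++i)
        x[i] = std::int16_t(i), y[i] = std::int16_t(1000 + 2 * i);

    __m512i a = _mm512_load_si512(x);
    __m512i b = _mm512_load_si512(y);
    __m512i c = vec_op(_mm256_add_epi16, a, b);  // stands in for _mm512_add_epi16
    _mm512_store_si512(z, c);

    for (int i = 0; i < 32; ++i)  // compare against scalar 16-bit addition
        if (z[i] != std::int16_t(x[i] + y[i])) {
            std::printf("mismatch at lane %d\n", i);
            return 1;
        }
    std::puts("vec_op add ok");
    return 0;
}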
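
A second sketch, also not part of the patch, for the rewritten permutations:
_mm512_packs_epi16 packs within each 128-bit lane, so packing two registers
produces lane-interleaved output. order_packs pre-permutes the 64-bit weight
blocks once at load time to compensate (the "// Inverse permuted at load time"
comment), so vec_msb_pack_16 needs no cross-lane fixup in the hot path. The
standalone test below (hypothetical, same compiler assumptions as above) checks
that the new inverse_order_packs shuffle indices really invert order_packs in
the 512-bit case:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

using vec_t = __m512i;  // 8 x 64-bit lanes per register; two registers = 16 blocks

// Same shuffles as the patch: interleave the 128-bit pair-blocks of vec0/vec1
// to match the lane order produced by _mm512_packs_epi16.
static void order_packs(std::uint64_t* v) {
    vec_t* vec  = reinterpret_cast<vec_t*>(v);
    vec_t  vec0 = vec[0], vec1 = vec[1];
    vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 8, 9, 2, 3, 10, 11);
    vec[1] = __builtin_shufflevector(vec0, vec1, 4, 5, 12, 13, 6, 7, 14, 15);
}

static void inverse_order_packs(std::uint64_t* v) {
    vec_t* vec  = reinterpret_cast<vec_t*>(v);
    vec_t  vec0 = vec[0], vec1 = vec[1];
    vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5, 8, 9, 12, 13);
    vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7, 10, 11, 14, 15);
}

int main() {
    alignas(64) std::uint64_t v[16];
    for (int i = 0; i < 16; ++i)
        v[i] = std::uint64_t(i);
    order_packs(v);          // permute into packs order...
    inverse_order_packs(v);  // ...and back
    for (int i = 0; i < 16; ++i)
        if (v[i] != std::uint64_t(i)) {
            std::printf("round trip failed at block %d\n", i);
            return 1;
        }
    std::puts("round trip ok");
    return 0;
}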