Skip to content

Commit

Permalink
Simplify permutation
Browse files (browse the repository at this point in the history)
  • Loading branch information
PikaCat-OuO committed Apr 14, 2024
1 parent 8514ef0 commit ad2e5c7
Showing 1 changed file with 48 additions and 73 deletions.
121 changes: 48 additions & 73 deletions src/nnue/nnue_feature_transformer.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,51 +48,21 @@ using PSQTWeightType = std::int32_t;
static_assert(PSQTBuckets % 8 == 0,
"Per feature PSQT values cannot be processed at granularity lower than 8 at a time.");

#ifdef USE_AVX512F
#ifdef USE_AVX512
using vec_t = __m512i;
using psqt_vec_t = __m256i;
#define vec_load(a) _mm512_load_si512(a)
#define vec_store(a, b) _mm512_store_si512(a, b)
#define vec_add_16(a, b) \
__builtin_shufflevector(_mm256_add_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
__builtin_shufflevector(b, b, 0, 1, 2, 3)), \
_mm256_add_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
__builtin_shufflevector(b, b, 4, 5, 6, 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_sub_16(a, b) \
__builtin_shufflevector(_mm256_sub_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
__builtin_shufflevector(b, b, 0, 1, 2, 3)), \
_mm256_sub_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
__builtin_shufflevector(b, b, 4, 5, 6, 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_mul_16(a, b) \
__builtin_shufflevector(_mm256_mullo_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
__builtin_shufflevector(b, b, 0, 1, 2, 3)), \
_mm256_mullo_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
__builtin_shufflevector(b, b, 4, 5, 6, 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_add_16(a, b) _mm512_add_epi16(a, b)
#define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
#define vec_mul_16(a, b) _mm512_mullo_epi16(a, b)
#define vec_zero() _mm512_setzero_epi32()
#define vec_set_16(a) _mm512_set1_epi16(a)
#define vec_max_16(a, b) \
__builtin_shufflevector(_mm256_max_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
__builtin_shufflevector(b, b, 0, 1, 2, 3)), \
_mm256_max_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
__builtin_shufflevector(b, b, 4, 5, 6, 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_min_16(a, b) \
__builtin_shufflevector(_mm256_min_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
__builtin_shufflevector(b, b, 0, 1, 2, 3)), \
_mm256_min_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
__builtin_shufflevector(b, b, 4, 5, 6, 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_max_16(a, b) _mm512_max_epi16(a, b)
#define vec_min_16(a, b) _mm512_min_epi16(a, b)
// Inverse permuted at load time
#define vec_msb_pack_16(a, b) \
__builtin_shufflevector( \
_mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), 7), \
_mm256_srli_epi16(__builtin_shufflevector(b, b, 0, 1, 2, 3), 7)), \
_mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), 7), \
_mm256_srli_epi16(__builtin_shufflevector(b, b, 4, 5, 6, 7), 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
_mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7))
#define vec_load_psqt(a) _mm256_load_si256(a)
#define vec_store_psqt(a, b) _mm256_store_si256(a, b)
#define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
Expand All @@ -101,21 +71,32 @@ using psqt_vec_t = __m256i;
#define NumRegistersSIMD 16
#define MaxChunkSize 64

#elif USE_AVX512
#elif USE_AVX512F
using vec_t = __m512i;
using psqt_vec_t = __m256i;
// Applies a two-operand intrinsic `op` to an 8-element vector pair in two
// halves: __builtin_shufflevector(x, x, 0, 1, 2, 3) and (x, x, 4, 5, 6, 7)
// extract the low and high four elements of each operand, `op` is applied to
// each pair of halves, and the outer __builtin_shufflevector(lo, hi, 0..7)
// concatenates the two results back into one 8-element vector.
// In this section it is instantiated with _mm256_* 16-bit intrinsics.
// NOTE(review): `a` and `b` are expanded twice each, so an argument with side
// effects would be evaluated more than once.
#define vec_op(op, a, b) \
__builtin_shufflevector(op(__builtin_shufflevector(a, a, 0, 1, 2, 3), \
__builtin_shufflevector(b, b, 0, 1, 2, 3)), \
op(__builtin_shufflevector(a, a, 4, 5, 6, 7), \
__builtin_shufflevector(b, b, 4, 5, 6, 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_load(a) _mm512_load_si512(a)
#define vec_store(a, b) _mm512_store_si512(a, b)
#define vec_add_16(a, b) _mm512_add_epi16(a, b)
#define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
#define vec_mul_16(a, b) _mm512_mullo_epi16(a, b)
#define vec_add_16(a, b) vec_op(_mm256_add_epi16, a, b)
#define vec_sub_16(a, b) vec_op(_mm256_sub_epi16, a, b)
#define vec_mul_16(a, b) vec_op(_mm256_mullo_epi16, a, b)
#define vec_zero() _mm512_setzero_epi32()
#define vec_set_16(a) _mm512_set1_epi16(a)
#define vec_max_16(a, b) _mm512_max_epi16(a, b)
#define vec_min_16(a, b) _mm512_min_epi16(a, b)
#define vec_max_16(a, b) vec_op(_mm256_max_epi16, a, b)
#define vec_min_16(a, b) vec_op(_mm256_min_epi16, a, b)
// Inverse permuted at load time
#define vec_msb_pack_16(a, b) \
_mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7))
__builtin_shufflevector( \
_mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 0, 1, 2, 3), 7), \
_mm256_srli_epi16(__builtin_shufflevector(b, b, 0, 1, 2, 3), 7)), \
_mm256_packs_epi16(_mm256_srli_epi16(__builtin_shufflevector(a, a, 4, 5, 6, 7), 7), \
_mm256_srli_epi16(__builtin_shufflevector(b, b, 4, 5, 6, 7), 7)), \
0, 1, 2, 3, 4, 5, 6, 7)
#define vec_load_psqt(a) _mm256_load_si256(a)
#define vec_store_psqt(a, b) _mm256_store_si256(a, b)
#define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
Expand Down Expand Up @@ -278,43 +259,37 @@ class FeatureTransformer {
return FeatureSet::HashValue ^ (OutputDimensions * 2);
}

// Permutes the 64-bit words at `v` in place into the order the branch comments
// below call "_mm512_packs_epi16 ordering" / "_mm256_packs_epi16 ordering".
// When none of the tested macros is defined the body is empty.
// NOTE(review): this span is a rendered diff without +/- markers. Both the
// `static constexpr` variant (scalar moves / std::swap) and the `static`
// variant (__builtin_shufflevector) appear; presumably the first is the removed
// code and the second its replacement - confirm against the commit.
static constexpr void order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX512) // _mm512_packs_epi16 ordering
// Three-way rotation of word pairs: (2,3) <- (8,9) <- (4,5) <- old (2,3).
uint64_t tmp0, tmp1;
tmp0 = v[2], tmp1 = v[3];
v[2] = v[8], v[3] = v[9];
v[8] = v[4], v[9] = v[5];
v[4] = tmp0, v[5] = tmp1;
// Second rotation: (6,7) <- (10,11) <- (12,13) <- old (6,7).
tmp0 = v[6], tmp1 = v[7];
v[6] = v[10], v[7] = v[11];
v[10] = v[12], v[11] = v[13];
v[12] = tmp0, v[13] = tmp1;
#elif defined(USE_AVX2) // _mm256_packs_epi16 ordering
// Exchange word pair (2,3) with word pair (4,5).
std::swap(v[2], v[4]);
std::swap(v[3], v[5]);
static void order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX2)
// View the buffer as two vec_t values and copy both out before writing, so
// each output vector is assembled from the original contents.
// NOTE(review): dereferencing the cast pointer presumably requires `v` to be
// aligned for vec_t - confirm at the call site.
vec_t* vec = reinterpret_cast<vec_t*>(v);
vec_t vec0 = vec[0], vec1 = vec[1];
#if defined(USE_AVX512) || defined(USE_AVX512F) // _mm512_packs_epi16 ordering
// Shuffle indices 0-7 select elements of vec0, 8-15 select elements of vec1.
// The result is the same word order as the scalar rotation above.
vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 8, 9, 2, 3, 10, 11);
vec[1] = __builtin_shufflevector(vec0, vec1, 4, 5, 12, 13, 6, 7, 14, 15);
#else // _mm256_packs_epi16 ordering
// Shuffle indices 0-3 select elements of vec0, 4-7 select elements of vec1
// (assumes vec_t holds four 64-bit elements in this configuration - TODO confirm).
vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5);
vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7);
#endif
#endif
}

// Permutes the 64-bit words at `v` in place using the ordering the branch
// comments below call "Inverse _mm512_packs_epi16 ordering" /
// "Inverse _mm256_packs_epi16 ordering".
// When none of the tested macros is defined the body is empty.
// NOTE(review): this span is a rendered diff without +/- markers. Both the
// `static constexpr` variant (scalar moves / std::swap) and the `static`
// variant (__builtin_shufflevector) appear; presumably the first is the removed
// code and the second its replacement - confirm against the commit.
static constexpr void inverse_order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX512) // Inverse _mm512_packs_epi16 ordering
// Three-way rotation of word pairs: (2,3) <- (4,5) <- (8,9) <- old (2,3).
uint64_t tmp0, tmp1;
tmp0 = v[2], tmp1 = v[3];
v[2] = v[4], v[3] = v[5];
v[4] = v[8], v[5] = v[9];
v[8] = tmp0, v[9] = tmp1;
// Second rotation: (6,7) <- (12,13) <- (10,11) <- old (6,7).
tmp0 = v[6], tmp1 = v[7];
v[6] = v[12], v[7] = v[13];
v[12] = v[10], v[13] = v[11];
v[10] = tmp0, v[11] = tmp1;
#elif defined(USE_AVX2) // Inverse _mm256_packs_epi16 ordering
// Exchange word pair (2,3) with word pair (4,5); a swap is its own inverse.
std::swap(v[2], v[4]);
std::swap(v[3], v[5]);
static void inverse_order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX2)
// View the buffer as two vec_t values and copy both out before writing, so
// each output vector is assembled from the original contents.
// NOTE(review): dereferencing the cast pointer presumably requires `v` to be
// aligned for vec_t - confirm at the call site.
vec_t* vec = reinterpret_cast<vec_t*>(v);
vec_t vec0 = vec[0], vec1 = vec[1];
#if defined(USE_AVX512) || defined(USE_AVX512F) // Inverse _mm512_packs_epi16 ordering
// Shuffle indices 0-7 select elements of vec0, 8-15 select elements of vec1.
// The result is the same word order as the scalar rotation above.
vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5, 8, 9, 12, 13);
vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7, 10, 11, 14, 15);
#else // Inverse _mm256_packs_epi16 ordering
// Shuffle indices 0-3 select elements of vec0, 4-7 select elements of vec1
// (assumes vec_t holds four 64-bit elements in this configuration - TODO confirm).
vec[0] = __builtin_shufflevector(vec0, vec1, 0, 1, 4, 5);
vec[1] = __builtin_shufflevector(vec0, vec1, 2, 3, 6, 7);
#endif
#endif
}

void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) const {
#if defined(USE_AVX2)
#if defined(USE_AVX512)
#if defined(USE_AVX512) || defined(USE_AVX512F)
constexpr IndexType di = 16;
#else
constexpr IndexType di = 8;
Expand Down

0 comments on commit ad2e5c7

Please sign in to comment.