Skip to content

Commit

Permalink
Fix undefined behavior (pytorch#2587)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2587

For a signed integer type, it is undefined behavior in C++ (prior to C++20) to left-shift a bit into the sign-bit position, where the result would become a negative number.

Rewrite several `(((long long)1) << X)` expressions to `1ULL << X` using unsigned long long to avoid this undefined behavior, fixing UBSan complaints. This rewrites all instances of the pattern for consistency (even in places that wouldn't end up shifting into the highest bit position).

Reviewed By: jspark1105

Differential Revision: D57281091

fbshipit-source-id: c1be27825103916749af540c3e420edb4fb1bcb2
  • Loading branch information
MatzeB authored and facebook-github-bot committed May 20, 2024
1 parent 006833a commit 7792908
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 41 deletions.
6 changes: 3 additions & 3 deletions src/EmbeddingSpMDMAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ template <
typename T,
typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __mmask16 mask_from_rem(int rem) {
__mmask16 mask_rem_v = (((long long)1) << rem) - 1;
__mmask16 mask_rem_v = (1ULL << rem) - 1;
return mask_rem_v;
}

template <
typename T,
typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __mmask8 mask_from_rem(int rem) {
__mmask8 mask_rem_v = (((long long)1) << rem) - 1;
__mmask8 mask_rem_v = (1ULL << rem) - 1;
return mask_rem_v;
}

Expand Down Expand Up @@ -307,7 +307,7 @@ static inline void mymemcpy(char* src, char* dest, int len) {
}
int rem = len - i;
if (rem > 0) {
__mmask64 mask_rem_v = (((long long)1) << rem) - 1;
__mmask64 mask_rem_v = (1ULL << rem) - 1;
auto src_v = _mm512_maskz_loadu_epi8(mask_rem_v, src + i);
_mm512_mask_storeu_epi8(dest + i, mask_rem_v, src_v);
}
Expand Down
4 changes: 2 additions & 2 deletions src/FbgemmSparseDenseAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void SparseDenseMMAvx512(
int r2_rem = N - VLEN - j;
r2_rem = (r2_rem <= VLEN) ? r2_rem : (VLEN);
r2_rem = (r2_rem < 0) ? 0 : r2_rem;
__mmask16 mask_v = (((long long)1) << r2_rem) - 1;
__mmask16 mask_v = (1ULL << r2_rem) - 1;
for (int i = 0; i < M; ++i) {
__m512 c_v_r1;
__m512 c_v_r2;
Expand Down Expand Up @@ -97,7 +97,7 @@ void SparseDenseMMAvx512(
if (rem > 0) {
for (int i = 0; i < M; ++i) {
__m512 c_v;
__mmask16 mask_v = (((long long)1) << rem) - 1;
__mmask16 mask_v = (1ULL << rem) - 1;
if (accum) {
c_v = _mm512_maskz_loadu_ps(mask_v, C + i * ldc + j);
} else {
Expand Down
6 changes: 3 additions & 3 deletions src/FbgemmSparseDenseInt8Avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ static inline void interleave4RowsTile(
}
} else {
int rem_int8 = N - col_start;
__mmask64 mask_int8_v = (((long long)1) << rem_int8) - 1;
__mmask64 mask_int8_v = (1ULL << rem_int8) - 1;
__m512i br_v[4];
int i = 0;
for (; i < kBlocks; ++i) {
Expand Down Expand Up @@ -469,8 +469,8 @@ void SparseDenseInt8MMAvx512(
break;
}

__mmask16 mask_int32_v = (((long long)1) << rem_int32) - 1;
__mmask64 mask_int8_v = (((long long)1) << rem_int8) - 1;
__mmask16 mask_int32_v = (1ULL << rem_int32) - 1;
__mmask64 mask_int8_v = (1ULL << rem_int8) - 1;
for (int i = 0; i < M; ++i) {
__m512i c_v[4] = {};
if (accum || kt > 0) {
Expand Down
6 changes: 3 additions & 3 deletions src/FbgemmSparseDenseVectorInt8Avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ static inline void requantizeForMV(
}
int rem_int32 = len - i;
if (rem_int32 > 0) {
__mmask64 mask_int8_v = (((long long)1) << rem_int32) - 1;
__mmask16 mask_int32_v = (((long long)1) << rem_int32) - 1;
__mmask64 mask_int8_v = (1ULL << rem_int32) - 1;
__mmask16 mask_int32_v = (1ULL << rem_int32) - 1;
__m512i x_v = _mm512_maskz_loadu_epi32(mask_int32_v, src + i);

if (!ACT_ZP_0) {
Expand Down Expand Up @@ -197,7 +197,7 @@ void SparseDenseInt8MVAvx512(

int rem = cur_row_ptr[i + 1] - r;
if (rem > 0) {
__mmask16 mask_int32_v = (((long long)1) << (rem)) - 1;
__mmask16 mask_int32_v = (1ULL << rem) - 1;
__m512i a_v =
_mm512_maskz_loadu_epi32(mask_int32_v, values + r * block_size);
__m512i b_idx = _mm512_maskz_loadu_epi32(mask_int32_v, col_idx + r);
Expand Down
60 changes: 30 additions & 30 deletions src/UtilsAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ static inline void load_with_remainders_i16(
int nrem) {
__m512i t[16];
if (nrem < 16) {
__mmask32 mask_nrem_v = (((long long)1) << nrem) - 1;
__mmask32 mask_nrem_v = (1ULL << nrem) - 1;
for (int i = 0; i < mrem; ++i) {
// mask load
t[i] = _mm512_maskz_loadu_epi16(mask_nrem_v, src + i * ld_src);
Expand Down Expand Up @@ -537,7 +537,7 @@ static inline void load_with_remainders_i8(
int nrem) {
__m512i t[16];
if (nrem < 32) {
__mmask64 mask_nrem_v = (((long long)1) << nrem) - 1;
__mmask64 mask_nrem_v = (1ULL << nrem) - 1;
for (int i = 0; i < mrem; ++i) {
// mask load
t[i] = _mm512_maskz_loadu_epi8(mask_nrem_v, src + i * ld_src);
Expand Down Expand Up @@ -566,7 +566,7 @@ static inline void store_with_remainders_i16(
int mrem,
int nrem) {
if (mrem < 16) {
__mmask32 mask_mrem_v = (((long long)1) << mrem) - 1;
__mmask32 mask_mrem_v = (1ULL << mrem) - 1;
int i = 0;

for (; i < nrem / 2 * 2; i += 2) {
Expand Down Expand Up @@ -616,7 +616,7 @@ static inline void store_with_remainders_i8(
int mrem,
int nrem) {
if (mrem < 16) {
__mmask64 mask_mrem_v = (((long long)1) << mrem) - 1;
__mmask64 mask_mrem_v = (1ULL << mrem) - 1;
int i = 0;
for (; i < nrem / 4 * 4; i += 4) {
// mask store
Expand Down Expand Up @@ -743,7 +743,7 @@ static inline void transpose_contiguous_4x16_block(
__m512i r[4];
// load
if (nrem < 16) {
__mmask16 mask_mrem_v = (((long long)1) << nrem) - 1;
__mmask16 mask_mrem_v = (1ULL << nrem) - 1;
r[0] = _mm512_maskz_loadu_epi32(mask_mrem_v, src);
r[1] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src);
r[2] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 2 * ld_src);
Expand Down Expand Up @@ -790,7 +790,7 @@ static inline void transpose_contiguous_4x16_block(
int erem = nrem * 4 - i * 16;
if (erem > 0) {
// mask store
__mmask16 mask_rem_v = (((long long)1) << erem) - 1;
__mmask16 mask_rem_v = (1ULL << erem) - 1;
_mm512_mask_storeu_epi32(dst + i * 16, mask_rem_v, r[i]);
}
}
Expand All @@ -803,7 +803,7 @@ static inline void transpose_contiguous_4x32_block(
__m512i r[4], d[4];
// load
if (nrem < 32) {
__mmask32 mask_mrem_v = (((long long)1) << nrem) - 1;
__mmask32 mask_mrem_v = (1ULL << nrem) - 1;
r[0] = _mm512_maskz_loadu_epi16(mask_mrem_v, src);
r[1] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src);
r[2] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 2 * ld_src);
Expand Down Expand Up @@ -844,7 +844,7 @@ static inline void transpose_contiguous_4x32_block(
int erem = nrem * 4 - i * 32;
if (erem > 0) {
// mask store
__mmask32 mask_rem_v = (((long long)1) << erem) - 1;
__mmask32 mask_rem_v = (1ULL << erem) - 1;
_mm512_mask_storeu_epi16(dst + i * 32, mask_rem_v, r[i]);
}
}
Expand All @@ -861,7 +861,7 @@ static inline void transpose_contiguous_16x4_block(
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16));
}
if (i * 16 < mrem * 4) {
__mmask16 mask_mrem_v = (((long long)1) << (mrem * 4 - i * 16)) - 1;
__mmask16 mask_mrem_v = (1ULL << (mrem * 4 - i * 16)) - 1;
r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16);
}

Expand Down Expand Up @@ -900,7 +900,7 @@ static inline void transpose_contiguous_16x4_block(

if (mrem < 16) {
// mask store
__mmask16 mask_rem_v = (((long long)1) << mrem) - 1;
__mmask16 mask_rem_v = (1ULL << mrem) - 1;
_mm512_mask_storeu_epi32(dst + 0 * ld_dst, mask_rem_v, d[0]);
_mm512_mask_storeu_epi32(dst + 1 * ld_dst, mask_rem_v, d[1]);
_mm512_mask_storeu_epi32(dst + 2 * ld_dst, mask_rem_v, d[2]);
Expand All @@ -926,7 +926,7 @@ static inline void transpose_contiguous_16x2_block(
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16));
}
if (i * 16 < mrem * 2) {
__mmask16 mask_mrem_v = (((long long)1) << (mrem * 2 - i * 16)) - 1;
__mmask16 mask_mrem_v = (1ULL << (mrem * 2 - i * 16)) - 1;
r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16);
}
// transpose
Expand Down Expand Up @@ -972,7 +972,7 @@ static inline void transpose_contiguous_16x2_block(

// store
if (mrem < 16) {
__mmask16 mask_rem_v = (((long long)1) << mrem) - 1;
__mmask16 mask_rem_v = (1ULL << mrem) - 1;
// mask store
_mm512_mask_storeu_epi32(dst, mask_rem_v, d[0]);
_mm512_mask_storeu_epi32(dst + ld_dst, mask_rem_v, d[1]);
Expand All @@ -996,7 +996,7 @@ static inline void transpose_contiguous_64x4_block(
}
int erem = mrem * 4 - i * 64;
if (erem > 0) {
__mmask64 mask_mrem_v = (((long long)1) << erem) - 1;
__mmask64 mask_mrem_v = (1ULL << erem) - 1;
r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64);
}

Expand Down Expand Up @@ -1043,7 +1043,7 @@ static inline void transpose_contiguous_64x4_block(

// store
if (mrem < 64) {
__mmask64 mask_rem_v = (((long long)1) << mrem) - 1;
__mmask64 mask_rem_v = (1ULL << mrem) - 1;
// mask store
_mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]);
_mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]);
Expand All @@ -1070,7 +1070,7 @@ static inline void transpose_contiguous_32x4_block(
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32));
}
if (i * 32 < mrem * 4) {
__mmask32 mask_mrem_v = (((long long)1) << (mrem * 4 - i * 32)) - 1;
__mmask32 mask_mrem_v = (1ULL << (mrem * 4 - i * 32)) - 1;
r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32);
}
// transpose
Expand Down Expand Up @@ -1109,7 +1109,7 @@ static inline void transpose_contiguous_32x4_block(

if (mrem < 32) {
// mask store
__mmask32 mask_rem_v = (((long long)1) << mrem) - 1;
__mmask32 mask_rem_v = (1ULL << mrem) - 1;
_mm512_mask_storeu_epi16(dst + 0 * ld_dst, mask_rem_v, d[0]);
_mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, d[1]);
_mm512_mask_storeu_epi16(dst + 2 * ld_dst, mask_rem_v, d[2]);
Expand All @@ -1131,7 +1131,7 @@ static inline void transpose_contiguous_2x16_block(
__m512i r0, r1;
// load
if (nrem < 16) {
__mmask16 mask_mrem_v = (((long long)1) << nrem) - 1;
__mmask16 mask_mrem_v = (1ULL << nrem) - 1;
r0 = _mm512_maskz_loadu_epi32(mask_mrem_v, src);
r1 = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src);
} else {
Expand Down Expand Up @@ -1181,10 +1181,10 @@ static inline void transpose_contiguous_2x16_block(
if (nrem < 16) {
// mask store
if (nrem < 8) {
__mmask16 mask_rem_v = (((long long)1) << (nrem * 2)) - 1;
__mmask16 mask_rem_v = (1ULL << (nrem * 2)) - 1;
_mm512_mask_storeu_epi32(dst, mask_rem_v, u0);
} else {
__mmask16 mask_rem_v = (((long long)1) << ((nrem - 8) * 2)) - 1;
__mmask16 mask_rem_v = (1ULL << ((nrem - 8) * 2)) - 1;
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0);
_mm512_mask_storeu_epi32(dst + 16, mask_rem_v, u1);
}
Expand All @@ -1208,7 +1208,7 @@ static inline void transpose_contiguous_64x2_block(
}
int erem = mrem * 2 - i * 64;
if (erem > 0) {
__mmask64 mask_mrem_v = (((long long)1) << erem) - 1;
__mmask64 mask_mrem_v = (1ULL << erem) - 1;
r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64);
}

Expand Down Expand Up @@ -1241,7 +1241,7 @@ static inline void transpose_contiguous_64x2_block(

// store
if (mrem < 64) {
__mmask64 mask_rem_v = (((long long)1) << mrem) - 1;
__mmask64 mask_rem_v = (1ULL << mrem) - 1;
// mask store
_mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]);
_mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]);
Expand All @@ -1260,7 +1260,7 @@ static inline void transpose_contiguous_4x64_block(
__m512i r[4], d[4];
// load
if (nrem < 64) {
__mmask64 mask_mrem_v = (((long long)1) << nrem) - 1;
__mmask64 mask_mrem_v = (1ULL << nrem) - 1;
r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src);
r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src);
r[2] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 2 * ld_src);
Expand Down Expand Up @@ -1324,7 +1324,7 @@ static inline void transpose_contiguous_4x64_block(
}
int erem = nrem * 4 - i * 64;
if (erem > 0) {
__mmask64 mask_rem_v = (((long long)1) << erem) - 1;
__mmask64 mask_rem_v = (1ULL << erem) - 1;
_mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]);
}
}
Expand All @@ -1338,7 +1338,7 @@ static inline void transpose_contiguous_2x64_block(
__m512i d[2];
// load
if (nrem < 64) {
__mmask64 mask_mrem_v = (((long long)1) << nrem) - 1;
__mmask64 mask_mrem_v = (1ULL << nrem) - 1;
r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src);
r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src);
} else {
Expand Down Expand Up @@ -1381,7 +1381,7 @@ static inline void transpose_contiguous_2x64_block(
}
int erem = nrem * 2 - i * 64;
if (erem > 0) {
__mmask64 mask_rem_v = (((long long)1) << erem) - 1;
__mmask64 mask_rem_v = (1ULL << erem) - 1;
_mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]);
}
}
Expand All @@ -1395,7 +1395,7 @@ static inline void transpose_contiguous_2x32_block(
__m512i d0, d1;
// load
if (nrem < 32) {
__mmask32 mask_mrem_v = (((long long)1) << nrem) - 1;
__mmask32 mask_mrem_v = (1ULL << nrem) - 1;
r0 = _mm512_maskz_loadu_epi16(mask_mrem_v, src);
r1 = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src);
} else {
Expand All @@ -1412,12 +1412,12 @@ static inline void transpose_contiguous_2x32_block(

// store
if (nrem < 16) {
__mmask32 mask_rem_v = (((long long)1) << (nrem * 2)) - 1;
__mmask32 mask_rem_v = (1ULL << (nrem * 2)) - 1;
_mm512_mask_storeu_epi16(dst, mask_rem_v, d0);
} else if (nrem == 16) {
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
} else if (nrem < 32) {
__mmask32 mask_rem_v = (((long long)1) << (nrem * 2 - 32)) - 1;
__mmask32 mask_rem_v = (1ULL << (nrem * 2 - 32)) - 1;
_mm512_mask_storeu_epi16(dst, mask_rem_v, d0);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
_mm512_mask_storeu_epi16(
Expand All @@ -1442,7 +1442,7 @@ static inline void transpose_contiguous_32x2_block(
}
int erem = mrem * 2 - i * 32;
if (erem > 0) {
__mmask32 mask_mrem_v = (((long long)1) << erem) - 1;
__mmask32 mask_mrem_v = (1ULL << erem) - 1;
r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32);
}
// transpose
Expand Down Expand Up @@ -1470,7 +1470,7 @@ static inline void transpose_contiguous_32x2_block(

// store
if (mrem < 32) {
__mmask32 mask_rem_v = (((long long)1) << mrem) - 1;
__mmask32 mask_rem_v = (1ULL << mrem) - 1;
// mask store
_mm512_mask_storeu_epi16(dst, mask_rem_v, r[0]);
_mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, r[1]);
Expand Down

0 comments on commit 7792908

Please sign in to comment.