diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
index 64e0e2c9fc..3acb2386d3 100644
--- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py
+++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -397,7 +397,7 @@ def rowwise_adagrad() -> None:
         momentum1[idx] = new_sum_square_grads;
         multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
       }
-      multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
+      multiplier = shfl_sync(multiplier, 0);
     """
     split_weight_update_cpu = """
       at::acc_type g_local_sum_square = 0.0;
@@ -474,8 +474,8 @@ def rowwise_weighted_adagrad() -> None:
         multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps);
         correction = 1.0 - multiplier * weight_decay;
       }
-      multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
-      correction = __shfl_sync(0xFFFFFFFF, correction, 0);
+      multiplier = shfl_sync(multiplier, 0);
+      correction = shfl_sync(correction, 0);
     """
     split_weight_update_cpu = """
       // weight_decay not supported for cpu version
@@ -636,7 +636,7 @@ def partial_rowwise_lamb() -> None:
         m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square;
         momentum2[idx] = m2;
       }
-      m2 = __shfl_sync(0xFFFFFFFF, m2, 0);
+      m2 = shfl_sync(m2, 0);
       at::acc_type m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps);
 
       at::acc_type weight_sum_sq = 0.0;
@@ -772,7 +772,7 @@ def partial_rowwise_adam() -> None:
         momentum2[idx] = v_t;
         v_hat_t = v_t / (1.0 - powf(beta2, iter));
       }
-      v_hat_t = __shfl_sync(0xFFFFFFFF, v_hat_t, 0);
+      v_hat_t = shfl_sync(v_hat_t, 0);
     """
 
     split_weight_update = """
diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
index 147456c791..151f3c0248 100644
--- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
+++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -111,7 +111,11 @@ class SplitLookupFunction_Dense_Op
     ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
     ctx->saved_data["pooling_mode"] = pooling_mode;
 
+#ifdef __HIP_PLATFORM_HCC__
+    constexpr int32_t BT_block_size = 64;
+#else
     constexpr int32_t BT_block_size = 32;
+#endif
     if (!indice_weights.has_value()) {
       return {dense_embedding_codegen_forward_unweighted_cuda(
           dev_weights,
@@ -159,8 +163,13 @@ class SplitLookupFunction_Dense_Op
     TORCH_CHECK(grad_outputs.size() == 1);
 
+#ifdef __HIP_PLATFORM_HCC__
+    constexpr int32_t BT_block_size = 64;
+    constexpr int32_t max_segment_length_per_warp = 64;
+#else
     constexpr int32_t BT_block_size = 32;
     constexpr int32_t max_segment_length_per_warp = 32;
+#endif
     using torch::autograd::Variable;
 
     auto grad_output = grad_outputs[0];
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
index 8d78e23ab1..805936ac67 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -189,7 +189,11 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
     {% endfor %}
 
     {% if not nobag %}
+#ifdef __HIP_PLATFORM_HCC__
+    constexpr int32_t BT_block_size = 64;
+#else
     constexpr int32_t BT_block_size = 32;
+#endif
     if (!indice_weights) {
       return {split_embedding_codegen_forward_unweighted_cuda(
           dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
@@ -257,8 +261,13 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
     TORCH_CHECK(grad_outputs.size() == 1);
 
+#ifdef __HIP_PLATFORM_HCC__
+    constexpr int32_t BT_block_size = 64;
+    constexpr int32_t max_segment_length_per_warp = 64;
+#else
     constexpr int32_t BT_block_size = 32;
     constexpr int32_t max_segment_length_per_warp = 32;
+#endif
     using torch::autograd::Variable;
 
     auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu
index dc9ccddc4c..dbe82a0896 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu
@@ -103,9 +103,9 @@ __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_e
     int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
     {% endif %}
     for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
-      int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j);
+      int64_t idx_j = shfl_sync(idx, j);
       {% if not dense %}
-      int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j);
+      int32_t cache_idx_j = shfl_sync(cache_idx, j);
       {% endif %}
 
       at::acc_type grad_indice_weight = 0.0;
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu
index df0d4d53bb..a754f2d55b 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -189,13 +189,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
       {% endif %}
       for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) {
         {% if not nobag %}
-        int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j);
-        int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j);
+        int32_t b_j = shfl_sync(b, j);
+        int32_t D_start_j = shfl_sync(D_start, j);
         {% else %}
-        int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j);
+        int32_t l_j = shfl_sync(l, j);
         {% endif %}
+
         {% if weighted %}
-        at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j);
+        at::acc_type idx_weight_j = shfl_sync(idx_weight, j);
         {% endif %}
 
 #pragma unroll kMaxVecsPerThread
@@ -546,13 +547,13 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
 
       for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) {
         {% if not nobag %}
-        int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j);
-        int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j);
+        int32_t b_j = shfl_sync(b, j);
+        int32_t D_start_j = shfl_sync(D_start, j);
         {% else %}
-        int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j);
+        int32_t l_j = shfl_sync(l, j);
         {% endif %}
         {% if weighted %}
-        at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j);
+        at::acc_type idx_weight_j = shfl_sync(idx_weight, j);
         {% endif %}
 
 #pragma unroll kMaxVecsPerThread
@@ -757,7 +758,12 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
     // V100: 96 KB; A100: 160 KB.
     int max_shared_bytes = 0;
+#ifndef __HIP_PLATFORM_HCC__
     cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device());
+#else
+    // MI100 has 64 KB local memory (shared memory) per workgroup
+    max_shared_bytes = 64 << 10;
+#endif
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     int shared_kb = max_shared_bytes >> 10;
     // V100: 64 KB; A100: 96 KB.
@@ -958,6 +964,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
             // over 48 KB per block are architecture-specific, as such they
             // must use dynamic shared memory (rather than statically sized
             // arrays) and require an explicit opt-in using cudaFuncSetAttribute()".
+
+#ifndef __HIP_PLATFORM_HCC__
             cudaFuncSetAttribute(
                 split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1<
                     {% if not dense %}
@@ -970,7 +978,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
                     {{ kMaxVecsPerThread }}>,
                 cudaFuncAttributeMaxDynamicSharedMemorySize,
                 used_shared_bytes); // V100: 64 KB; A100: 96 KB.
+#endif
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+            // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0]
             split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1<
                 {% if not dense %}
                 emb_t,
@@ -980,7 +990,6 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
                 scalar_t,
                 {% endif %}
                 {{ kMaxVecsPerThread }}>
-                // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0]
                 <<) * 4 * kWarpSize *
@@ -1031,6 +1040,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
                 {% endif %}
                 {{ args.split_kernel_arg_constructors | join(", ") }});
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+#ifndef __HIP_PLATFORM_HCC__
             cudaFuncSetAttribute(
                 split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1<
                     {% if not dense %}
@@ -1043,6 +1053,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
                     {{ kMaxVecsPerThread }}>,
                 cudaFuncAttributeMaxDynamicSharedMemorySize,
                 used_shared_bytes); // V100: 64 KB; A100: 96 KB.
+#endif
             C10_CUDA_KERNEL_LAUNCH_CHECK();
             split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1<
                 {% if not dense %}
diff --git a/fbgemm_gpu/codegen/embedding_bounds_check.cu b/fbgemm_gpu/codegen/embedding_bounds_check.cu
index e45356f01f..dc114dded4 100644
--- a/fbgemm_gpu/codegen/embedding_bounds_check.cu
+++ b/fbgemm_gpu/codegen/embedding_bounds_check.cu
@@ -69,7 +69,8 @@ __global__ void bounds_check_indices_kernel(
   }
 
   auto L = indices_end - indices_start;
-  for (auto i = threadIdx.x; i < L; i += fbgemm_gpu::kWarpSize) {
+  for (index_t i = (index_t)threadIdx.x; i < L;
+       i += (index_t)fbgemm_gpu::kWarpSize) {
     auto idx = indices[indices_start + i];
     if (idx == -1) {
       // -1 indicates pruned rows.
diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu
index 3edb340f6a..85a08e153d 100644
--- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu
@@ -309,7 +309,11 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i
     }
     // equivalent to fence + wait.
     cp_async_wait<0>();
+#ifdef __HIP_PLATFORM_HCC__
+    __syncthreads();
+#else
     __syncwarp();
+#endif
     for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) {
 #pragma unroll OutputRowsPerThread
       for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
@@ -507,9 +511,21 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_{
         found = true;
         dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
       }
+#ifdef __HIP_PLATFORM_HCC__
+      // FIXME: __any_sync with mask isn't supported by HIP yet.
+      // See https://fburl.com/fvy7j0lq for the similar context.
+      // assert false here with https://fburl.com/pfm7enw2
+      assert(false);
+      if (__any(found)) {
+#else
       if (__any_sync(subwarp_mask, found)) {
+#endif
        break;
+#ifdef __HIP_PLATFORM_HCC__
+      } else if (__any(empty)) {
+#else
      } else if (__any_sync(subwarp_mask, empty)) {
+#endif
        dense_indices[indices_start + l_start + subwarp_id] = -1;
        break;
      }
    }
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu
index fac7f4d081..ff5f47bac4 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -127,16 +127,16 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
     at::acc_type idx_weight = l < L ? indice_weights[indices_start + l] : 0;
     {% endif %}
     for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
-      int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j);
+      int64_t idx_j = shfl_sync(idx, j);
       {% if nobag %}
       int64_t output_j = indices_start + l_start + j;
       {% endif %}
       {% if not dense %}
-      int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j);
+      int32_t cache_idx_j = shfl_sync(cache_idx, j);
       {% endif %}
 
       {% if weighted %}
-      at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j);
+      at::acc_type idx_weight_j = shfl_sync(idx_weight, j);
       {% endif %}
 
       {% if not dense %}
diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
index ac43148926..9129439a7c 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -19,7 +19,11 @@ namespace fbgemm_gpu {
 #define DEVICE_INLINE __device__ inline __attribute__((always_inline))
 
 // Warp size
+#ifdef __HIP_PLATFORM_HCC__
+static constexpr int32_t kWarpSize = 64;
+#else
 static constexpr int32_t kWarpSize = 32;
+#endif
 // Max thread num in one thread block
 static constexpr int32_t kMaxThreads = 1024;
 static constexpr float kQParamEps = 1e-8f;
@@ -36,7 +40,12 @@ struct Half4 {
   half2 b;
 
   __device__ inline void store(at::Half* p) {
-#if CUDA_VERSION >= 9000
+#ifdef __HIP_PLATFORM_HCC__
+    p[0] = __low2half(a);
+    p[1] = __high2half(a);
+    p[2] = __low2half(b);
+    p[3] = __high2half(b);
+#elif CUDA_VERSION >= 9000
 
 #ifndef __HALF2_TO_UI
 // cuda_fp16.hpp doesn't export this
@@ -79,6 +88,12 @@ struct Vec4T {
   }
 
   DEVICE_INLINE Vec4T(const at::Half* p) {
+#ifdef __HIP_PLATFORM_HCC__
+    acc.x = __half2float(p[0]);
+    acc.y = __half2float(p[1]);
+    acc.z = __half2float(p[2]);
+    acc.w = __half2float(p[3]);
+#else
     Half4 out;
 #if CUDA_VERSION >= 9000
     asm("ld.global.v2.u32 {%0, %1}, [%2];"
@@ -97,6 +112,7 @@ struct Vec4T {
     acc.y = a.y;
     acc.z = b.x;
     acc.w = b.y;
+#endif
   }
 
   DEVICE_INLINE void store(float* p) {
@@ -173,6 +189,12 @@ struct Vec4T {
   }
 
   DEVICE_INLINE Vec4T(const at::Half* p) {
+#ifdef __HIP_PLATFORM_HCC__
+    acc.x = __half2float(p[0]);
+    acc.y = __half2float(p[1]);
+    acc.z = __half2float(p[2]);
+    acc.w = __half2float(p[3]);
+#else
     Half4 out;
 #if CUDA_VERSION >= 9000
     asm("ld.global.v2.u32 {%0, %1}, [%2];"
@@ -191,6 +213,7 @@ struct Vec4T {
     acc.y = a.y;
     acc.z = b.x;
     acc.w = b.y;
+#endif
   }
 
   DEVICE_INLINE Vec4T(const float* p) {
@@ -235,6 +258,12 @@ struct Vec4T {
   }
 
   DEVICE_INLINE static void copy(const at::Half* src, at::Half* dst) {
+#ifdef __HIP_PLATFORM_HCC__
+    dst[0] = src[0];
+    dst[1] = src[1];
+    dst[2] = src[2];
+    dst[3] = src[3];
+#else
     Half4 out;
 #if CUDA_VERSION >= 9000
     asm("ld.global.v2.u32 {%0, %1}, [%2];"
@@ -251,6 +280,7 @@ struct Vec4T {
         : "l"(dst), "r"(__HALF2_TO_UI(out.a)), "r"(__HALF2_TO_UI(out.b)));
 #else
     asm("st.v2.u32 [%0], {%1, %2};" : : "l"(dst), "r"(out.a.x), "r"(out.b.x));
+#endif
 #endif
   }
 
@@ -305,6 +335,12 @@ struct Vec4T {
   }
 
   DEVICE_INLINE Vec4T(const at::Half* p) {
+#ifdef __HIP_PLATFORM_HCC__
+    acc.x = __half2float(p[0]);
+    acc.y = __half2float(p[1]);
+    acc.z = __half2float(p[2]);
+    acc.w = __half2float(p[3]);
+#else
     Half4 out;
 #if CUDA_VERSION >= 9000
     asm("ld.global.v2.u32 {%0, %1}, [%2];"
@@ -323,6 +359,7 @@ struct Vec4T {
     acc.y = a.y;
     acc.z = b.x;
     acc.w = b.y;
+#endif
   }
 
   DEVICE_INLINE Vec4T(const float* p) {
@@ -406,10 +443,19 @@ DEVICE_INLINE Vec4T vec4_acc(
 template <typename T>
 DEVICE_INLINE T shfl_xor(const T val, int laneMask, int width = kWarpSize) {
-#if CUDA_VERSION >= 9000
+#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+  return __shfl_xor(val, laneMask, width);
+#else
   return __shfl_xor_sync(0xffffffff, val, laneMask, width);
+#endif
+}
+
+template <typename T>
+DEVICE_INLINE T shfl_sync(const T val, int srcLane = 0, int width = kWarpSize) {
+#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+  return __shfl(val, srcLane, width);
 #else
-  return __shfl_xor(val, laneMask, width);
+  return __shfl_sync(0xffffffff, val, srcLane, width);
 #endif
 }
 
@@ -446,7 +492,7 @@ stochastic_rounding_scalar_uint8(float x, uint32_t random_bits) {
   // noise.F in [1, 2]
   noise.F = noise.F - 1.5;
   // noise.F in [-0.5, 0.5]
-  return std::lrintf(x + noise.F);
+  return lrintf(x + noise.F);
 }
 
 // This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state
@@ -517,10 +563,12 @@ DEVICE_INLINE void stochastic_rounding_vector(
     float2 /* not used */) {
   uint4 random_bits = stochastic_rounding_rand4(&state);
   Half4 v;
-  v.a.x = stochastic_rounding_scalar(value.acc.x, random_bits.x);
-  v.a.y = stochastic_rounding_scalar(value.acc.y, random_bits.y);
-  v.b.x = stochastic_rounding_scalar(value.acc.z, random_bits.z);
-  v.b.y = stochastic_rounding_scalar(value.acc.w, random_bits.w);
+  v.a = __halves2half2(
+      stochastic_rounding_scalar(value.acc.x, random_bits.x),
+      stochastic_rounding_scalar(value.acc.y, random_bits.y));
+  v.b = __halves2half2(
+      stochastic_rounding_scalar(value.acc.z, random_bits.z),
+      stochastic_rounding_scalar(value.acc.w, random_bits.w));
   v.store(output);
 }
 
@@ -532,10 +580,12 @@ DEVICE_INLINE void stochastic_rounding_vector(
     float2 /* not used */) {
   uint4 random_bits = stochastic_rounding_rand4(&state);
   Half4 v;
-  v.a.x = stochastic_rounding_scalar(value.acc.x, random_bits.x);
-  v.a.y = stochastic_rounding_scalar(value.acc.y, random_bits.y);
-  v.b.x = stochastic_rounding_scalar(value.acc.z, random_bits.z);
-  v.b.y = stochastic_rounding_scalar(value.acc.w, random_bits.w);
+  v.a = __halves2half2(
+      stochastic_rounding_scalar(value.acc.x, random_bits.x),
+      stochastic_rounding_scalar(value.acc.y, random_bits.y));
+  v.b = __halves2half2(
+      stochastic_rounding_scalar(value.acc.z, random_bits.z),
+      stochastic_rounding_scalar(value.acc.w, random_bits.w));
   v.store(output);
 }
 
@@ -588,10 +638,10 @@ template <>
 DEVICE_INLINE void nearest_rounding_vector(uint8_t* output, Vec4T value, float2 qparams) {
   float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
-  output[0] = std::lrintf((value.acc.x - qparams.y) * inv_scale);
-  output[1] = std::lrintf((value.acc.y - qparams.y) * inv_scale);
-  output[2] = std::lrintf((value.acc.z - qparams.y) * inv_scale);
-  output[3] = std::lrintf((value.acc.w - qparams.y) * inv_scale);
+  output[0] = lrintf((value.acc.x - qparams.y) * inv_scale);
+  output[1] = lrintf((value.acc.y - qparams.y) * inv_scale);
+  output[2] = lrintf((value.acc.z - qparams.y) * inv_scale);
+  output[3] = lrintf((value.acc.w - qparams.y) * inv_scale);
 }
 
 template <>
@@ -600,10 +650,10 @@ DEVICE_INLINE void nearest_rounding_vector(
     uint8_t* output,
     Vec4T value,
     float2 qparams) {
   float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
-  output[0] = std::lrintf((value.acc.x - qparams.y) * inv_scale);
-  output[1] = std::lrintf((value.acc.y - qparams.y) * inv_scale);
-  output[2] = std::lrintf((value.acc.z - qparams.y) * inv_scale);
-  output[3] = std::lrintf((value.acc.w - qparams.y) * inv_scale);
+  output[0] = lrintf((value.acc.x - qparams.y) * inv_scale);
+  output[1] = lrintf((value.acc.y - qparams.y) * inv_scale);
+  output[2] = lrintf((value.acc.z - qparams.y) * inv_scale);
+  output[3] = lrintf((value.acc.w - qparams.y) * inv_scale);
 }
 
 template <>
@@ -791,7 +841,7 @@ struct SharedMemory>> {
 // Return if the address is aligned to the type (mainly for Vec4T).
 template 
 DEVICE_INLINE bool is_aligned(const void* ptr) {
-  auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
+  auto iptr = reinterpret_cast<uintptr_t>(ptr);
   return !(iptr % alignof(T));
 }
 
@@ -879,8 +929,8 @@ __device__ float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) {
     qparams.x = (local_max - local_min) / 255.0f;
     qparams.y = local_min;
   }
-  qparams.x = __shfl_sync(0xFFFFFFFF, qparams.x, 0);
-  qparams.y = __shfl_sync(0xFFFFFFFF, qparams.y, 0);
+  qparams.x = shfl_sync(qparams.x, 0);
+  qparams.y = shfl_sync(qparams.y, 0);
   return qparams;
 }
 
@@ -1213,7 +1263,7 @@ struct VecNT<1> {
 
   DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) {
     float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
-    output_ptr[0] = std::lrintf((acc - qparams.y) * inv_scale);
+    output_ptr[0] = lrintf((acc - qparams.y) * inv_scale);
   }
 
   DEVICE_INLINE void store(float* output_ptr, float2 qparams) {
@@ -1267,8 +1317,8 @@ struct VecNT<2> {
 
   DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) {
     float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
-    output_ptr[0] = std::lrintf((acc.x - qparams.y) * inv_scale);
-    output_ptr[1] = std::lrintf((acc.y - qparams.y) * inv_scale);
+    output_ptr[0] = lrintf((acc.x - qparams.y) * inv_scale);
+    output_ptr[1] = lrintf((acc.y - qparams.y) * inv_scale);
   }
 
   DEVICE_INLINE void store(float* output_ptr, float2 qparams) {
@@ -1339,10 +1389,10 @@ struct VecNT<4> {
       *reinterpret_cast(output_ptr + 0) = v.x;
       *reinterpret_cast(output_ptr + 2) = v.y;
     } else {
-      *(output_ptr + 0) = val.vals[0].x;
-      *(output_ptr + 1) = val.vals[0].y;
-      *(output_ptr + 2) = val.vals[1].x;
-      *(output_ptr + 3) = val.vals[1].y;
+      *(output_ptr + 0) = __low2half(val.vals[0]);
+      *(output_ptr + 1) = __high2half(val.vals[0]);
+      *(output_ptr + 2) = __low2half(val.vals[1]);
+      *(output_ptr + 3) = __high2half(val.vals[1]);
     }
   }
 
@@ -1352,10 +1402,10 @@ struct VecNT<4> {
 
   DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) {
     float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
-    output_ptr[0] = std::lrintf((acc.x - qparams.y) * inv_scale);
-    output_ptr[1] = std::lrintf((acc.y - qparams.y) * inv_scale);
-    output_ptr[2] = std::lrintf((acc.z - qparams.y) * inv_scale);
-    output_ptr[3] = std::lrintf((acc.w - qparams.y) * inv_scale);
+    output_ptr[0] = lrintf((acc.x - qparams.y) * inv_scale);
+    output_ptr[1] = lrintf((acc.y - qparams.y) * inv_scale);
+    output_ptr[2] = lrintf((acc.z - qparams.y) * inv_scale);
+    output_ptr[3] = lrintf((acc.w - qparams.y) * inv_scale);
   }
 
   DEVICE_INLINE void store(float* output_ptr, float2 qparams) {
@@ -1444,14 +1494,14 @@ struct VecNT<8> {
       *reinterpret_cast(output_ptr + 4) = v.z;
       *reinterpret_cast(output_ptr + 6) = v.w;
     } else {
-      *(output_ptr + 0) = val.vals[0].x;
-      *(output_ptr + 1) = val.vals[0].y;
-      *(output_ptr + 2) = val.vals[1].x;
-      *(output_ptr + 3) = val.vals[1].y;
-      *(output_ptr + 4) = val.vals[2].x;
-      *(output_ptr + 5) = val.vals[2].y;
-      *(output_ptr + 6) = val.vals[3].x;
-      *(output_ptr + 7) = val.vals[3].y;
+      *(output_ptr + 0) = __low2half(val.vals[0]);
+      *(output_ptr + 1) = __high2half(val.vals[0]);
+      *(output_ptr + 2) = __low2half(val.vals[1]);
+      *(output_ptr + 3) = __high2half(val.vals[1]);
+      *(output_ptr + 4) = __low2half(val.vals[2]);
+      *(output_ptr + 5) = __high2half(val.vals[2]);
+      *(output_ptr + 6) = __low2half(val.vals[3]);
+      *(output_ptr + 7) = __high2half(val.vals[3]);
     }
   }
 
@@ -1461,14 +1511,14 @@ struct VecNT<8> {
 
   DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) {
     float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
-    output_ptr[0] = std::lrintf((acc.vals[0].x - qparams.y) * inv_scale);
-    output_ptr[1] = std::lrintf((acc.vals[0].y - qparams.y) * inv_scale);
-    output_ptr[2] = std::lrintf((acc.vals[0].z - qparams.y) * inv_scale);
-    output_ptr[3] = std::lrintf((acc.vals[0].w - qparams.y) * inv_scale);
-    output_ptr[4] = std::lrintf((acc.vals[1].x - qparams.y) * inv_scale);
-    output_ptr[5] = std::lrintf((acc.vals[1].y - qparams.y) * inv_scale);
-    output_ptr[6] = std::lrintf((acc.vals[1].z - qparams.y) * inv_scale);
-    output_ptr[7] = std::lrintf((acc.vals[1].w - qparams.y) * inv_scale);
+    output_ptr[0] = lrintf((acc.vals[0].x - qparams.y) * inv_scale);
+    output_ptr[1] = lrintf((acc.vals[0].y - qparams.y) * inv_scale);
+    output_ptr[2] = lrintf((acc.vals[0].z - qparams.y) * inv_scale);
+    output_ptr[3] = lrintf((acc.vals[0].w - qparams.y) * inv_scale);
+    output_ptr[4] = lrintf((acc.vals[1].x - qparams.y) * inv_scale);
+    output_ptr[5] = lrintf((acc.vals[1].y - qparams.y) * inv_scale);
+    output_ptr[6] = lrintf((acc.vals[1].z - qparams.y) * inv_scale);
+    output_ptr[7] = lrintf((acc.vals[1].w - qparams.y) * inv_scale);
   }
 
   DEVICE_INLINE void store(float* output_ptr, float2 qparams) {
diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh
index 610a3fe069..3c0608919b 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh
@@ -6,6 +6,10 @@
  */
 #pragma once
 
+#ifdef __HIP_PLATFORM_HCC__
+#define HIPCUB_ARCH 1
+#endif
+
 #include 
 
 // clang-format off
diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh
index 668989d2f2..99f5686641 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh
@@ -41,8 +41,8 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
     // Reverse the first comparison stage.
     // For example, merging a list of size 8 has the exchanges:
     // 0 <-> 15, 1 <-> 14, ...
-    K otherK = shfl_xor(k, 2 * L - 1);
-    V otherV = shfl_xor(v, 2 * L - 1);
+    K otherK = fbgemm_gpu::shfl_xor(k, 2 * L - 1);
+    V otherV = fbgemm_gpu::shfl_xor(v, 2 * L - 1);
 
     // Whether we are the lesser thread in the exchange
     bool small = !(laneId & L);
@@ -64,8 +64,8 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
 #pragma unroll
   for (int32_t stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
-    K otherK = shfl_xor(k, stride);
-    V otherV = shfl_xor(v, stride);
+    K otherK = fbgemm_gpu::shfl_xor(k, stride);
+    V otherV = fbgemm_gpu::shfl_xor(v, stride);
 
     // Whether we are the lesser thread in the exchange
     bool small = !(laneId & stride);
@@ -86,7 +86,11 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
 template 
 struct BitonicSort {
   static inline __device__ void sort(K k[1], V v[1]) {
+#ifdef __HIP_PLATFORM_HCC__
+    static_assert(fbgemm_gpu::kWarpSize == 64, "unexpected warp size");
+#else
     static_assert(fbgemm_gpu::kWarpSize == 32, "unexpected warp size");
+#endif
     warpBitonicMergeLE16(k[0], v[0]);
     warpBitonicMergeLE16(k[0], v[0]);
     warpBitonicMergeLE16(k[0], v[0]);
diff --git a/fbgemm_gpu/src/cumem_utils.cu b/fbgemm_gpu/src/cumem_utils.cu
index f839825dad..3a0a1096f4 100644
--- a/fbgemm_gpu/src/cumem_utils.cu
+++ b/fbgemm_gpu/src/cumem_utils.cu
@@ -279,11 +279,14 @@ void uvm_cuda_mem_advise(Tensor t, int64_t cudaMemoryAdvise) {
 
   device_guard.set_index(cuda_device_index);
 
+#ifndef __HIP_PLATFORM_HCC__
+  // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP.
   AT_CUDA_CHECK(cudaMemAdvise(
       ptr,
       size_bytes,
       static_cast(cudaMemoryAdvise),
       hint_device));
+#endif
   return;
 }
 
@@ -348,6 +351,8 @@ Tensor uvm_to_cpu_clone(Tensor t) {
   return cpu_clone;
 }
 
+#ifndef __HIP_PLATFORM_HCC__
+// FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP.
 FBGEMM_GPU_ENUM_GLOGAL(uvm)
 
 FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){
@@ -358,5 +363,6 @@ FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){
     FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetAccessedBy),
     FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetAccessedBy),
 } FBGEMM_GPU_ENUM_REGISTER_END
+#endif
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/cumem_utils_host.cpp b/fbgemm_gpu/src/cumem_utils_host.cpp
index 1cb0c9df8f..739b6f990d 100644
--- a/fbgemm_gpu/src/cumem_utils_host.cpp
+++ b/fbgemm_gpu/src/cumem_utils_host.cpp
@@ -39,7 +39,10 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
 
   m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone));
 
+#ifndef __HIP_PLATFORM_HCC__
+  // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP.
   m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query));
+#endif
 }
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -64,7 +67,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "uvm_mem_advice_dont_fork(Tensor t) -> ()",
       TORCH_FN(uvm_mem_advice_dont_fork));
 
+#ifndef __HIP_PLATFORM_HCC__
+  // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP.
   m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query));
+#endif
 }
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp
index 0a1a42d5a9..04ec601ea1 100644
--- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp
+++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp
@@ -15,6 +15,9 @@
 #include 
 #include 
 
+// FIXME: Enable merge_pooled_embeddings for HIP.
+// AMD GPUs don't seem to have nvml equivalent library support.
+#ifndef __HIP_PLATFORM_HCC__
 #include 
 
 #include 
@@ -406,3 +409,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]");
   DISPATCH_TO_CUDA("all_to_one_device", fbgemm_gpu::all_to_one_device);
 }
+#endif
diff --git a/fbgemm_gpu/src/quantize_ops.cu b/fbgemm_gpu/src/quantize_ops.cu
index 33946a7c5f..333a7912ef 100644
--- a/fbgemm_gpu/src/quantize_ops.cu
+++ b/fbgemm_gpu/src/quantize_ops.cu
@@ -6,7 +6,9 @@
  */
 #include 
 #include 
+#ifndef __HIP_PLATFORM_HCC__
 #include 
+#endif
 
 #include "fbgemm_gpu/quantize_ops.cuh"
 #include "fbgemm_gpu/sparse_ops_utils.h"
@@ -44,7 +46,7 @@ __global__ inline void _float_to_fused8bitrowwise_cuda_kernel(
     const auto inverse_scale = 255.0f / (range + kEpsilon);
     for (std::size_t col = 0; col < ncols; ++col) {
       output_row[col] =
-          std::lrintf((input_row[col] - minimum_element) * inverse_scale);
+          lrintf((input_row[col] - minimum_element) * inverse_scale);
     }
   }
 }
@@ -71,8 +73,15 @@ __global__ inline void _get_8bit_qparam_cuda_kernel(
   const int output_columns = ncols_aligned + 2 * sizeof(float);
 
   // starting values for future reductions
+#ifdef __HIP_PLATFORM_HCC__
+#define HIPRT_INF_F __int_as_float(0x7f800000)
+  float minimum_element = HIPRT_INF_F;
+  float maximum_element = -HIPRT_INF_F;
+#undef HIPRT_INF_F
+#else
   float minimum_element = CUDART_INF_F;
   float maximum_element = -CUDART_INF_F;
+#endif
 
   // always a power of 2 up to size 32. Multiple rows can share the same warp
   // when smaller than 32.
@@ -145,7 +154,7 @@ __global__ inline void _compute_8bit_quantize_cuda_kernel(
       // TODO: lift range_list into shared memory. However, when nrows is large,
       // it might exceed the size of shared memory.
       const auto inverse_scale = 255.0f / (range_list[row] + kEpsilon);
-      output_addr[0] = std::lrintf((input[input_idx] - bias) * inverse_scale);
+      output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale);
     }
   }
 }
@@ -222,8 +231,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
       std::uint8_t quantized = QUANTIZE_OPS_MAX(
           0,
           QUANTIZE_OPS_MIN(
-              static_cast(
-                  std::lrintf((X - minimum_element) * inverse_scale)),
+              static_cast(lrintf((X - minimum_element) * inverse_scale)),
              static_cast((1 << bit_rate) - 1)));
 
       if (col % num_elem_per_byte == 0) {
diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu
index 8770b32594..45ae079ef6 100644
--- a/fbgemm_gpu/src/sparse_ops.cu
+++ b/fbgemm_gpu/src/sparse_ops.cu
@@ -882,7 +882,7 @@ __global__ void reorder_batched_ad_indices_kernel(
   const int32_t output_segment_start =
       reordered_cat_ad_offsets[output_segment_offset_start];
 
-  for (auto i = threadIdx.x; i < input_segment_end - input_segment_start;
+  for (int32_t i = threadIdx.x; i < input_segment_end - input_segment_start;
        i += blockDim.x) {
     reordered_cat_ad_indices[output_segment_start + i] =
         cat_ad_indices[input_segment_start + i];
diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
index 62185a0c80..34a55d5989 100644
--- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -264,9 +264,9 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel(
 
   // hash_offset < 0 for non-caching tables
   for (int32_t j = 0; j < kWarpSize; ++j) {
-    auto indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j);
-    int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j);
-    int64_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j);
+    auto indices_start_warp = shfl_sync(indices_start, j);
+    int32_t L_warp = shfl_sync(L, j);
+    int64_t hash_offset_warp = shfl_sync(hash_offset, j);
     if (hash_offset_warp >= 0) {
       for (int32_t i = lane_id; i < L_warp; i += kWarpSize) {
         auto idx = __ldg(&indices[indices_start_warp + i]);
@@ -465,7 +465,15 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel(
       lru_state[cache_set][slot] = time_stamp;
     }
 
+#ifdef __HIP_PLATFORM_HCC__
+    // FIXME: __any_sync with mask isn't supported by HIP yet.
+    // See https://fburl.com/fvy7j0lq for the similar context.
+    // assert false here with https://fburl.com/pfm7enw2
+    assert(false);
+    if (!__any(found)) {
+#else
     if (!__any_sync(0xFFFFFFFF, found)) {
+#endif
       if (threadIdx.x == 0) {
         cache_sets[n] = cache_set;
       }
@@ -605,9 +613,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
   int64_t sorted_lru_cost = costs[0];
 
   for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-    int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l);
-    int64_t insert_current_lru_cost =
-        __shfl_sync(0xFFFFFFFF, sorted_lru_cost, l);
+    int32_t insert_slot = shfl_sync(sorted_slot, l);
+    int64_t insert_current_lru_cost = shfl_sync(sorted_lru_cost, l);
     if (insert_current_lru_cost == time_stamp) {
       return;
     }
@@ -623,7 +630,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
     // lxu_cache_state
     int64_t current_idx =
         threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0;
-    current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0);
+    current_idx = shfl_sync(current_idx, 0);
 
     // not empty
     if (current_idx != static_cast(kCacheStateInvalid)) {
@@ -919,9 +926,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel(
   int64_t sorted_lru_cost = costs[0];
 
   for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-    int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l);
-    int64_t insert_current_lru_cost =
-        __shfl_sync(0xFFFFFFFF, sorted_lru_cost, l);
+    int32_t insert_slot = shfl_sync(sorted_slot, l);
+    int64_t insert_current_lru_cost = shfl_sync(sorted_lru_cost, l);
     if (insert_current_lru_cost == time_stamp) {
       return;
     }
@@ -942,7 +948,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel(
     // lxu_cache_state
     int64_t current_idx =
         threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0;
-    current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0);
+    current_idx = shfl_sync(current_idx, 0);
 
     // not empty
     if (current_idx != static_cast(kCacheStateInvalid)) {
@@ -1213,7 +1219,15 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel(
           << kLFUCounterBits); // invalid index, used as sentinel
     }
 
+#ifdef __HIP_PLATFORM_HCC__
+    // FIXME: __any_sync with mask isn't supported by HIP yet.
+    // See https://fburl.com/fvy7j0lq for the similar context.
+    // assert false here with https://fburl.com/pfm7enw2
+    assert(false);
+    if (!__any(found)) {
+#else
     if (!__any_sync(0xFFFFFFFF, found)) {
+#endif
       if (threadIdx.x == 0) {
         // sort so the highest LFUs come first in the segment.
         // assume lfu_state[idx] <= 2^40 - 1 and cache_set < 2^24 -1
@@ -1360,9 +1374,8 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
   int64_t sorted_lfu_cost = costs[0];
 
   for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-    int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l);
-    int64_t insert_current_lfu_cost =
-        __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l);
+    int32_t insert_slot = shfl_sync(sorted_slot, l);
+    int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l);
     int64_t insert_idx = cache_set_sorted_indices[n + l];
     int64_t insert_lfu_cost = lfu_state[insert_idx];
 
@@ -1386,7 +1399,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
     // lxu_cache_state
     int64_t current_idx =
         threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0;
-    current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0);
+    current_idx = shfl_sync(current_idx, 0);
     int32_t t_current = cache_index_table_map[current_idx];
     int64_t idx_current = current_idx - cache_hash_size_cumsum[t_current];
     int64_t weights_offset_current = weights_offsets[t_current];
@@ -1698,9 +1711,8 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
   int64_t sorted_lfu_cost = costs[0];
 
   for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-    int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l);
-    int64_t insert_current_lfu_cost =
-        __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l);
+    int32_t insert_slot = shfl_sync(sorted_slot, l);
+    int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l);
     index_t insert_idx = cache_set_sorted_indices[n + l];
     int64_t insert_lfu_cost = lfu_state[insert_idx];
 
@@ -1729,7 +1741,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
     // lxu_cache_state
     int64_t current_idx =
         threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0;
-    current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0);
+    current_idx = shfl_sync(current_idx, 0);
     int32_t t_current = cache_index_table_map[current_idx];
     SparseType weight_ty_current =
         static_cast(weights_tys[t_current]);
@@ -1919,7 +1931,15 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
   if (found) {
     lxu_cache_locations[n] = cache_set * kWarpSize + slot;
   }
+#ifdef __HIP_PLATFORM_HCC__
+  // FIXME: __any_sync with mask isn't supported by HIP yet.
+  // See https://fburl.com/fvy7j0lq for the similar context.
+  // assert false here with https://fburl.com/pfm7enw2
+  assert(false);
+  if (!__any(found)) {
+#else
   if (!__any_sync(0xFFFFFFFF, found)) {
+#endif
     if (threadIdx.x == 0) {
       lxu_cache_locations[n] = kCacheLocationMissing;
     }
diff --git a/fbgemm_gpu/src/split_embeddings_utils.cu b/fbgemm_gpu/src/split_embeddings_utils.cu
index ca8c8135bd..46f14b6453 100644
--- a/fbgemm_gpu/src/split_embeddings_utils.cu
+++ b/fbgemm_gpu/src/split_embeddings_utils.cu
@@ -36,10 +36,10 @@ __global__ void linearize_index_kernel(
   int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;
 
   for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
-    index_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j);
-    int32_t b_t_warp = __shfl_sync(0xFFFFFFFF, b_t, j);
-    int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j);
-    index_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j);
+    index_t indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
+    int32_t b_t_warp = fbgemm_gpu::shfl_sync(b_t, j);
+    int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
+    index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
     for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
       index_t idx = __ldg(&indices[indices_start_warp + i]);
       infos[indices_start_warp + i] = b_t_warp;
@@ -70,10 +70,10 @@ __global__ void nobag_linearize_index_kernel(
   int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;
 
   for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
-    index_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j);
-    int32_t t_warp = __shfl_sync(0xFFFFFFFF, t, j);
-    int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j);
-    index_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j);
+    index_t indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
+    int32_t t_warp = fbgemm_gpu::shfl_sync(t, j);
+    int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
+    index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
     for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
       index_t idx = __ldg(&indices[indices_start_warp + i]);
       int64_t l_t = (indices_start_warp + i) * T + t_warp;
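
Illustrative note (not part of the patch): the recurring pattern above replaces the mask-based CUDA warp intrinsics (__shfl_sync(0xFFFFFFFF, ...), __shfl_xor_sync, __any_sync) with the shfl_sync/shfl_xor wrappers and the platform-dependent kWarpSize from fbgemm_cuda_utils.cuh, so the same device code builds for 32-lane CUDA warps and 64-lane HIP wavefronts. A minimal sketch of how kernel code is expected to use these wrappers follows; warp_reduce_sum and broadcast_lane0 are hypothetical helper names used only for illustration, not functions added by this change.

    #include "fbgemm_gpu/fbgemm_cuda_utils.cuh"

    using namespace fbgemm_gpu;

    // Butterfly reduction across the warp/wavefront: each XOR stage halves the
    // number of distinct partial sums, and kWarpSize (32 or 64) sets how many
    // stages run.
    template <typename T>
    DEVICE_INLINE T warp_reduce_sum(T val) {
      for (int mask = kWarpSize / 2; mask > 0; mask >>= 1) {
        val += shfl_xor(val, mask); // __shfl_xor_sync on CUDA >= 9.0, __shfl_xor on HIP
      }
      return val;
    }

    // Broadcast lane 0's value to all lanes, as the optimizer kernels do with
    // `multiplier` after the row-wise reduction. A hard-coded 0xFFFFFFFF mask
    // would cover only half of a 64-lane wavefront, which is why the wrapper
    // hides the mask entirely.
    template <typename T>
    DEVICE_INLINE T broadcast_lane0(T val) {
      return shfl_sync(val, 0);
    }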