Skip to content

Commit

Permalink
Fix bf16 support issues (#2238)
Browse files Browse the repository at this point in the history
Summary:

For BF16-related CUDA code, we use the following preprocessor guard to distinguish V100 from A100 (pre-A100 NVIDIA GPUs do not support BF16):
```
#if !(                                                  \
    ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
```
This guard wraps the BF16-specific code paths.

On AMD GPUs (ROCm), this condition always evaluates to false. However, the MI250 / MI300 GPUs we have in house do support BF16, so we re-enable BF16 for ROCm builds by additionally checking `defined(USE_ROCM)`.

Reviewed By: houseroad, jiawenliu64

Differential Revision: D52438898
  • Loading branch information
jianyuh authored and facebook-github-bot committed Dec 28, 2023
1 parent 857242b commit 90ecc97
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
{% for params in emb_weight_type.template_params %}

{% if output_type == 'at::BFloat16' %}
#if !( \
#if defined(USE_ROCM) || !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
{% endif %}
Expand Down
48 changes: 24 additions & 24 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1651,9 +1651,9 @@ struct __align__(4) __nv_bfloat162 {
};
#endif

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
struct __align__(8) bfloat16_4 {
__host__ __device__ bfloat16_4() {}
__nv_bfloat162 vals[2];
Expand Down Expand Up @@ -1771,9 +1771,9 @@ static DEVICE_INLINE void quantize_float_store(
*output = input;
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// Converts a single float to __nv_bfloat16 by forwarding to the
// __float2bfloat16 intrinsic. Only compiled when the enclosing BF16
// preprocessor guard holds.
DEVICE_INLINE __nv_bfloat16 to_bfloat16(float v) {
return __float2bfloat16(v);
}
Expand Down Expand Up @@ -2347,9 +2347,9 @@ struct VecNT<1, PrimitiveType::FP> {
*reinterpret_cast<__half*>(output_ptr) = val;
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
DEVICE_INLINE void store(
at::BFloat16* output_ptr,
const int num_valid_outputs = 1) {
Expand Down Expand Up @@ -2440,9 +2440,9 @@ struct VecNT<2, PrimitiveType::FP> {
}
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
DEVICE_INLINE void store(
at::BFloat16* output_ptr,
const int num_valid_outputs = 2) {
Expand Down Expand Up @@ -2578,9 +2578,9 @@ struct VecNT<4, PrimitiveType::FP> {
}
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
DEVICE_INLINE void store(
at::BFloat16* output_ptr,
const int num_valid_outputs = 4) {
Expand Down Expand Up @@ -2733,9 +2733,9 @@ struct VecNT<4, PrimitiveType::INT> {
}
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
DEVICE_INLINE void store(
at::BFloat16* output_ptr,
const int num_valid_outputs = 4) {
Expand Down Expand Up @@ -2903,9 +2903,9 @@ struct VecNT<8, PrimitiveType::INT> {
}
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
DEVICE_INLINE void store(
at::BFloat16* output_ptr,
const int num_valid_outputs = 8) {
Expand Down Expand Up @@ -3090,9 +3090,9 @@ struct VecNT<16, PrimitiveType::INT> {
}
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
DEVICE_INLINE void store(
at::BFloat16* output_ptr,
const int num_valid_outputs = 16) {
Expand Down

0 comments on commit 90ecc97

Please sign in to comment.