diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu index 61879024c26f8..e7cb6c7ec3c30 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu @@ -92,7 +92,7 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i return CUDA_CALL(cudaPeekAtLastError()); } -#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800 template __global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c, int input_length, int bias_length, const nv_bfloat162* input, const nv_bfloat162* bias, @@ -111,7 +111,7 @@ template <> bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; -#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800 if (0 == (bias_length & 1) && prop.major >= 7) { const int n = input_length / 2; const int gridSize = (n + blockSize - 1) / blockSize; @@ -128,7 +128,7 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i const int gridSize = (input_length + blockSize - 1) / blockSize; FastGeluKernel <<>>(A, B, C, input_length, bias_length, input, bias, output); -#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800 } #endif return CUDA_CALL(cudaPeekAtLastError()); diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index ea9ddf0450908..152f0d86e50af 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -351,7 +351,7 @@ __device__ __inline__ BFloat16 _Log(BFloat16 a) { return logf(static_cast template <> __device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast(a)); } -#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800 template <> __device__ __inline__ nv_bfloat162 _Tanh(nv_bfloat162 a) { float2 tmp = (__bfloat1622float2(a)); diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index ed1d792a5259e..0b2543ec153e1 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -136,6 +136,11 @@ cudnnDataType_t CudnnTensor::GetDataType() { return CUDNN_DATA_HALF; } +template <> +cudnnDataType_t CudnnTensor::GetDataType() { + return CUDNN_DATA_BFLOAT16; +} + template <> cudnnDataType_t CudnnTensor::GetDataType() { return CUDNN_DATA_INT8; diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index 0d41b83a48f07..b2a7d4dc6a51b 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -43,23 +43,7 @@ Status SoftMaxComputeHelper( SPECIALIZED_SOFTMAX_HELPER_IMPL(float) SPECIALIZED_SOFTMAX_HELPER_IMPL(double) SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16) - -#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax) \ - template <> \ - Status SoftMaxComputeHelper(cudaStream_t stream, const BFloat16* X, \ - const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \ - typedef typename ToCudaType::MappedType CudaT; \ - int64_t N = input_shape.SizeToDimension(axis); \ - int64_t D = input_shape.SizeFromDimension(axis); \ - auto Y_data = reinterpret_cast(Y); \ - auto X_data = reinterpret_cast(X); \ - dispatch_warpwise_softmax_forward, is_log_softmax>( \ - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N)); \ - return Status::OK(); \ - } - -SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true) -SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false) +SPECIALIZED_SOFTMAX_HELPER_IMPL(BFloat16) #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 6ea034530f228..8dbeb56467367 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -188,7 +188,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) { // failed with invalid device error, disabled for now // CUDA only, ROCM has not been supported yet #ifdef USE_CUDA -TEST(FastGeluTest, DISABLED_FastGeluWithBias_BFloat16) { +TEST(FastGeluTest, FastGeluWithBias_BFloat16) { int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; diff --git a/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc b/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc index 8cff32aa23fe7..72fc564f88244 100644 --- a/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc @@ -172,7 +172,7 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) { +TEST(CudaKernelTest, MixedPrecisionScale_float_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { @@ -196,7 +196,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) { +TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_float) { #ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { @@ -220,7 +220,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) { +TEST(CudaKernelTest, MixedPrecisionScale_half_bfloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { @@ -244,7 +244,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_half) { +TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_half) { #ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { diff --git a/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc b/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc index bce245b882185..1a8af5045b50a 100644 --- a/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc @@ -62,27 +62,6 @@ Status SoftMaxGradComputeHelper( return Status::OK(); } -#define SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(is_log_softmax) \ - template <> \ - Status SoftMaxGradComputeHelper(cudaStream_t stream, const BFloat16* dY, \ - const TensorShape& input_shape, const BFloat16* Y, \ - BFloat16* dX, cudnnHandle_t, int64_t axis) { \ - typedef typename ToCudaType::MappedType CudaT; \ - const int64_t normalized_axis = HandleNegativeAxis(axis, input_shape.NumDimensions()); \ - int64_t N = input_shape.SizeToDimension(normalized_axis); \ - int64_t D = input_shape.SizeFromDimension(normalized_axis); \ - auto dY_data = reinterpret_cast(dY); \ - auto Y_data = reinterpret_cast(Y); \ - auto dX_data = reinterpret_cast(dX); \ - dispatch_softmax_backward, is_log_softmax>( \ - stream, dX_data, dY_data, Y_data, gsl::narrow_cast(D), gsl::narrow_cast(D), \ - gsl::narrow_cast(N)); \ - return Status::OK(); \ - } - -SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(true) -SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(false) - #define REGISTER_GRADIENT_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ SoftmaxGrad, \