Commit

bugfix
centwang committed Feb 7, 2022
1 parent bad1aa2 commit 6b5504c
Showing 7 changed files with 15 additions and 47 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -92,7 +92,7 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i
return CUDA_CALL(cudaPeekAtLastError());
}

-#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800
template <unsigned TPB>
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c, int input_length,
int bias_length, const nv_bfloat162* input, const nv_bfloat162* bias,
@@ -111,7 +111,7 @@ template <>
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
-#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800
if (0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
const int gridSize = (n + blockSize - 1) / blockSize;
@@ -128,7 +128,7 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
-#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800
}
#endif
return CUDA_CALL(cudaPeekAtLastError());
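Note on the guard change above: __CUDA_ARCH__ is defined only during nvcc's device compilation passes, never in the host pass. With the !defined(__CUDA_ARCH__) clause the guarded region (including the host-side bfloat162 launch branch) survives host compilation; without it the condition reads as "0 >= 800" in host code and the region is compiled out, so BFloat16 inputs fall back to the scalar FastGeluKernel launch. A minimal, self-contained sketch of that preprocessor behavior (illustrative only, not ORT code; the kernel name is invented):

// arch_guard_sketch.cu: how an __CUDA_ARCH__ guard behaves when nvcc compiles a .cu file.
// In the host pass __CUDA_ARCH__ is undefined, so "__CUDA_ARCH__ >= 800" reads as "0 >= 800";
// in a device pass the guarded branch survives only when compiling for sm_80 or newer.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void which_path() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  printf("device: sm_80+ path\n");
#else
  printf("device: pre-sm_80 path\n");
#endif
}

int main() {
  // Host code never sees __CUDA_ARCH__, so a guard like
  //   #if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800
  // compiles a host-side launch branch out, while adding "|| !defined(__CUDA_ARCH__)" keeps it.
  which_path<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}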
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -351,7 +351,7 @@ __device__ __inline__ BFloat16 _Log(BFloat16 a) { return logf(static_cast<float>
template <>
__device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast<float>(a)); }

-#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800
template <>
__device__ __inline__ nv_bfloat162 _Tanh(nv_bfloat162 a) {
float2 tmp = (__bfloat1622float2(a));
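The nv_bfloat162 _Tanh specialization is truncated in the hunk above; the usual pattern for such packed elementwise ops is to widen the pair to float2, compute in float, and narrow back. A minimal sketch under that assumption (the helper name is invented; only the cuda_bf16.h conversion intrinsics are real API):

// bf162_tanh_sketch.cuh: assumes CUDA 11+ and cuda_bf16.h.
#include <cuda_bf16.h>

__device__ __forceinline__ nv_bfloat162 Bf162Tanh(nv_bfloat162 a) {
  float2 tmp = __bfloat1622float2(a);  // unpack the two bf16 lanes to float
  tmp.x = tanhf(tmp.x);
  tmp.y = tanhf(tmp.y);
  return __float22bfloat162_rn(tmp);   // repack, rounding to nearest even
}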
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -136,6 +136,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
return CUDNN_DATA_HALF;
}

+template <>
+cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
+return CUDNN_DATA_BFLOAT16;
+}

template <>
cudnnDataType_t CudnnTensor::GetDataType<int8_t>() {
return CUDNN_DATA_INT8;
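The new specialization maps ORT's BFloat16 element type onto cuDNN's CUDNN_DATA_BFLOAT16 enum (available since cuDNN 8.1), so cuDNN-backed code paths can describe bfloat16 tensors the same way they describe float and half ones. A hypothetical sketch of how such a type-to-enum mapping is typically consumed when filling a tensor descriptor (the helper names here are invented; cudnnSetTensor4dDescriptor and the enum values are real cuDNN API):

// cudnn_bf16_descriptor_sketch.cc: requires cuDNN 8.1+ for CUDNN_DATA_BFLOAT16.
#include <cudnn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T>
cudnnDataType_t ToCudnnType();  // primary template intentionally left undefined

template <> cudnnDataType_t ToCudnnType<float>()         { return CUDNN_DATA_FLOAT; }
template <> cudnnDataType_t ToCudnnType<__half>()        { return CUDNN_DATA_HALF; }
template <> cudnnDataType_t ToCudnnType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; }

// Describe an NCHW tensor of element type T, roughly what a descriptor setter does with
// a GetDataType<T>()-style mapping.
template <typename T>
cudnnStatus_t MakeNchwDescriptor(cudnnTensorDescriptor_t desc, int n, int c, int h, int w) {
  return cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, ToCudnnType<T>(), n, c, h, w);
}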
18 changes: 1 addition & 17 deletions onnxruntime/core/providers/cuda/math/softmax.cc
@@ -43,23 +43,7 @@ Status SoftMaxComputeHelper(
SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
SPECIALIZED_SOFTMAX_HELPER_IMPL(double)
SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16)

-#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax) \
-template <> \
-Status SoftMaxComputeHelper<BFloat16, is_log_softmax>(cudaStream_t stream, const BFloat16* X, \
-const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \
-typedef typename ToCudaType<BFloat16>::MappedType CudaT; \
-int64_t N = input_shape.SizeToDimension(axis); \
-int64_t D = input_shape.SizeFromDimension(axis); \
-auto Y_data = reinterpret_cast<CudaT*>(Y); \
-auto X_data = reinterpret_cast<const CudaT*>(X); \
-dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>( \
-stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N)); \
-return Status::OK(); \
-}
-
-SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true)
-SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
+SPECIALIZED_SOFTMAX_HELPER_IMPL(BFloat16)

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
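The deleted macro above was a BFloat16-only copy of the softmax helper that always dispatched the warp-wise CUDA kernel, passing D twice (presumably element count and stride) and N as the batch count. With this change BFloat16 reuses the same SPECIALIZED_SOFTMAX_HELPER_IMPL as float, double, and MLFloat16, which presumably also gives it whatever fallback path the generic helper has (hence the CUDNN_DATA_BFLOAT16 mapping added above). For reference, a CPU sketch of what the helper computes, including the N x D view derived from axis via SizeToDimension / SizeFromDimension (illustrative only, not ORT code):

// softmax_reference_sketch.cc: CPU reference for an N x D (log-)softmax along axis,
// where N is the product of dimensions before axis and D the product from axis onward.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void SoftmaxReference(const std::vector<int64_t>& dims, size_t axis, bool is_log_softmax,
                      const std::vector<float>& x, std::vector<float>& y) {
  int64_t N = 1, D = 1;
  for (size_t i = 0; i < dims.size(); ++i) (i < axis ? N : D) *= dims[i];
  y.resize(x.size());
  for (int64_t r = 0; r < N; ++r) {
    const float* xr = x.data() + r * D;
    float* yr = y.data() + r * D;
    float mx = xr[0];
    for (int64_t i = 1; i < D; ++i) mx = std::max(mx, xr[i]);  // subtract the row max for stability
    float sum = 0.0f;
    for (int64_t i = 0; i < D; ++i) sum += std::exp(xr[i] - mx);
    for (int64_t i = 0; i < D; ++i)
      yr[i] = is_log_softmax ? (xr[i] - mx) - std::log(sum) : std::exp(xr[i] - mx) / sum;
  }
}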
2 changes: 1 addition & 1 deletion onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -188,7 +188,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
// failed with invalid device error, disabled for now
// CUDA only, ROCM has not been supported yet
#ifdef USE_CUDA
-TEST(FastGeluTest, DISABLED_FastGeluWithBias_BFloat16) {
+TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
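Removing the DISABLED_ prefix makes GoogleTest run the BFloat16 FastGelu test again (DISABLED_ tests are compiled but skipped unless --gtest_also_run_disabled_tests is passed); the runtime gate on HasCudaEnvironment(530) remains. A standalone sketch of what such a compute-capability gate typically looks like (hypothetical helper, not ORT's HasCudaEnvironment; it assumes the common major*100 + minor*10 encoding, so 530 means sm_53):

// cuda_capability_gate_sketch.cc: hypothetical stand-in for a HasCudaEnvironment-style check.
#include <cuda_runtime.h>

bool HasMinComputeCapability(int min_cuda_architecture) {
  int device_count = 0;
  if (cudaGetDeviceCount(&device_count) != cudaSuccess || device_count == 0) return false;
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) return false;
  return prop.major * 100 + prop.minor * 10 >= min_cuda_architecture;
}

The four MixedPrecisionScale bfloat16 tests in the hunks below are re-enabled the same way.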
@@ -172,7 +172,7 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

-TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
+TEST(CudaKernelTest, MixedPrecisionScale_float_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -196,7 +196,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

-TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
+TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_float) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -220,7 +220,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

-TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
+TEST(CudaKernelTest, MixedPrecisionScale_half_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -244,7 +244,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

-TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_half) {
+TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_half) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
21 changes: 0 additions & 21 deletions orttraining/orttraining/training_ops/cuda/math/softmax_grad.cc
@@ -62,27 +62,6 @@ Status SoftMaxGradComputeHelper(
return Status::OK();
}

-#define SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(is_log_softmax) \
-template <> \
-Status SoftMaxGradComputeHelper<BFloat16, is_log_softmax>(cudaStream_t stream, const BFloat16* dY, \
-const TensorShape& input_shape, const BFloat16* Y, \
-BFloat16* dX, cudnnHandle_t, int64_t axis) { \
-typedef typename ToCudaType<BFloat16>::MappedType CudaT; \
-const int64_t normalized_axis = HandleNegativeAxis(axis, input_shape.NumDimensions()); \
-int64_t N = input_shape.SizeToDimension(normalized_axis); \
-int64_t D = input_shape.SizeFromDimension(normalized_axis); \
-auto dY_data = reinterpret_cast<const CudaT*>(dY); \
-auto Y_data = reinterpret_cast<const CudaT*>(Y); \
-auto dX_data = reinterpret_cast<CudaT*>(dX); \
-dispatch_softmax_backward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>( \
-stream, dX_data, dY_data, Y_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), \
-gsl::narrow_cast<int>(N)); \
-return Status::OK(); \
-}
-
-SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(true)
-SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(false)

#define REGISTER_GRADIENT_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SoftmaxGrad, \
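This mirrors the forward-softmax change: the BFloat16-specific backward helper macro is removed, and BFloat16 presumably flows through the same generic SoftMaxGradComputeHelper and kernel registration as the other types. For reference, a CPU sketch of the standard (log-)softmax backward formulas that the deleted code dispatched to the GPU, applied per row of D elements (illustrative only, not ORT code):

// softmax_grad_reference_sketch.cc: CPU reference for the standard backward formulas,
//   softmax:     dX_i = Y_i * (dY_i - sum_j dY_j * Y_j)
//   log-softmax: dX_i = dY_i - exp(Y_i) * sum_j dY_j
#include <cmath>
#include <cstdint>
#include <vector>

void SoftmaxGradReference(int64_t N, int64_t D, bool is_log_softmax,
                          const std::vector<float>& dY, const std::vector<float>& Y,
                          std::vector<float>& dX) {
  dX.resize(dY.size());
  for (int64_t r = 0; r < N; ++r) {
    const float* dy = dY.data() + r * D;
    const float* y = Y.data() + r * D;
    float* dx = dX.data() + r * D;
    float sum = 0.0f;
    for (int64_t i = 0; i < D; ++i) sum += is_log_softmax ? dy[i] : dy[i] * y[i];
    for (int64_t i = 0; i < D; ++i)
      dx[i] = is_log_softmax ? dy[i] - std::exp(y[i]) * sum : y[i] * (dy[i] - sum);
  }
}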
