[ROCm] UTs and code clean up #10511

Merged (9 commits) on Feb 11, 2022
42 changes: 7 additions & 35 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -92,45 +92,17 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i
return CUDA_CALL(cudaPeekAtLastError());
}

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <unsigned TPB>
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c, int input_length,
int bias_length, const nv_bfloat162* input, const nv_bfloat162* bias,
nv_bfloat162* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < input_length) {
const nv_bfloat162 x = input[idx];
const nv_bfloat162 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
const nv_bfloat162 cdf = a + a * _Tanh(in * (c * in * in + b));
output[idx] = in * cdf;
}
}
#endif

template <>
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
if (0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
const int gridSize = (n + blockSize - 1) / blockSize;
const nv_bfloat162 A2 = __floats2bfloat162_rn(A, A);
const nv_bfloat162 B2 = __floats2bfloat162_rn(B, B);
const nv_bfloat162 C2 = __floats2bfloat162_rn(C, C);
const nv_bfloat162* input2 = reinterpret_cast<const nv_bfloat162*>(input);
const nv_bfloat162* bias2 = reinterpret_cast<const nv_bfloat162*>(bias);
nv_bfloat162* output2 = reinterpret_cast<nv_bfloat162*>(output);
FastGeluKernel2<blockSize>
<<<gridSize, blockSize, 0, stream>>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2);
} else {
#endif
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
}
#endif

// remove the nv_bfloat162 implementation for now to fix a build issue;
// we can decide whether to add it back if there is a perf concern
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);

return CUDA_CALL(cudaPeekAtLastError());
}

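For reference, the scalar FastGeluKernel path that remains implements the tanh approximation of GELU. Below is a minimal host-side sketch of the same math; the constant values are my assumption of what A, B, and C hold (0.5, sqrt(2/pi), and 0.044715 * sqrt(2/pi)), inferred from the kernel's cdf = a + a * tanh(in * (c * in * in + b)) form, not copied from the file:

#include <cmath>
#include <cstdio>

// Host-side reference: out = in * (A + A * tanh(in * (C*in*in + B))),
// matching the cdf form used by the kernels above. Constants are assumed.
float FastGeluRef(float x, float bias) {
  const float kA = 0.5f;
  const float kB = 0.7978845608f;  // sqrt(2/pi)
  const float kC = 0.0356774081f;  // 0.044715 * sqrt(2/pi)
  const float in = x + bias;
  const float cdf = kA + kA * tanhf(in * (kC * in * in + kB));
  return in * cdf;
}

int main() {
  // Exact GELU(1.0) is about 0.8413; the tanh approximation lands close.
  printf("FastGelu(1.0) = %f\n", FastGeluRef(1.0f, 0.0f));
  return 0;
}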
10 changes: 0 additions & 10 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -351,16 +351,6 @@ __device__ __inline__ BFloat16 _Log(BFloat16 a) { return logf(static_cast<float>
template <>
__device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast<float>(a)); }

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__device__ __inline__ nv_bfloat162 _Tanh(nv_bfloat162 a) {
float2 tmp = (__bfloat1622float2(a));
tmp.x = tanhf(tmp.x);
tmp.y = tanhf(tmp.y);
return __float22bfloat162_rn(tmp);
}
#endif

template <>
__device__ __inline__ BFloat16 _Normcdf(BFloat16 a) { return normcdff(static_cast<float>(a)); }

25 changes: 19 additions & 6 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -513,7 +513,9 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
IAllocatorUniquePtr<float> temp_X;
cudnnDataType_t cudnn_type_X = CudnnTensor::GetDataType<CudaT>();

if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) {
// ReduceSum with BFloat16 is not supported by cuDNN, so convert the input to fp32, then call cuDNN
if ((ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) ||
(ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES && std::is_same<T, BFloat16>::value)) {
// ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn
temp_X = cuda_ep.GetScratchBuffer<float>(input_count);
cudnn_type_X = CUDNN_DATA_FLOAT;
@@ -652,11 +654,22 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.template MutableData<T>(), input.template Data<T>(), input_count * sizeof(T), cudaMemcpyDeviceToDevice, stream));
}
} else {
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
workspace_cuda.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<CudaT*>(output.template MutableData<T>())));
if (temp_X) {
auto temp_output = cuda_ep.GetScratchBuffer<float>(output_count);
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
workspace_cuda.get(), workspace_bytes,
&one, input_tensor, temp_X.get(),
&zero, output_tensor, temp_output.get()));

Impl_Cast<float, CudaT>(stream, temp_output.get(), reinterpret_cast<CudaT*>(output.template MutableData<T>()), output_count);
} else {
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
workspace_cuda.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<CudaT*>(output.template MutableData<T>())));
}
}
}
} else {
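The new temp_X branch is a cast-reduce-cast fallback: the input has already been widened to fp32 (where temp_X is allocated earlier in this function), cuDNN reduces in fp32 into a temporary output, and Impl_Cast narrows the result back to T. A self-contained CPU sketch of the same pattern; the truncation-based bf16 helpers are illustrative stand-ins, not ORT's BFloat16 type:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Toy bfloat16 <-> float conversions (truncation; real code rounds to nearest).
static uint16_t FloatToBf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}
static float Bf16ToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Cast-reduce-cast: widen each element to fp32, accumulate in fp32,
// and narrow only the final result back to bf16.
static uint16_t ReduceSumBf16ViaFp32(const std::vector<uint16_t>& in) {
  float acc = 0.0f;
  for (uint16_t v : in) acc += Bf16ToFloat(v);
  return FloatToBf16(acc);
}

int main() {
  std::vector<uint16_t> data;
  for (int i = 1; i <= 12; ++i) data.push_back(FloatToBf16(static_cast<float>(i)));
  printf("sum = %f\n", Bf16ToFloat(ReduceSumBf16ViaFp32(data)));  // 78, exact in bf16
  return 0;
}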
26 changes: 20 additions & 6 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
@@ -511,7 +511,10 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
IAllocatorUniquePtr<float> temp_X;
miopenDataType_t miopen_type_X = MiopenTensor::GetDataType<HipT>();

if (ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) {
// Unlike cuDNN, the MIOpen call below succeeds for BFloat16; however, the UT shows a data error,
// so for now follow the same logic as cuDNN and convert the input to fp32, then call MIOpen
if ((ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) ||
(ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_NO_INDICES && std::is_same<T, BFloat16>::value)) {
// ArgMax/ArgMin with FP16 are not supported by miopen, so convert input to fp32 then call miopen
temp_X = rocm_ep.GetScratchBuffer<float>(input_count);
miopen_type_X = miopenFloat;
@@ -651,11 +654,22 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.template MutableData<T>(), input.template Data<T>(), input_count * sizeof(T), hipMemcpyDeviceToDevice, stream));
}
} else {
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(
rocm_ep.PerThreadMiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const HipT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<HipT*>(output.template MutableData<T>())));
if (temp_X) {
auto temp_output = rocm_ep.GetScratchBuffer<float>(output_count);
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(
rocm_ep.PerThreadMiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes,
&one, input_tensor, temp_X.get(),
&zero, output_tensor, temp_output.get()));

Impl_Cast<float, HipT>(stream, temp_output.get(), reinterpret_cast<HipT*>(output.template MutableData<T>()), output_count);
} else {
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(
rocm_ep.PerThreadMiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const HipT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<HipT*>(output.template MutableData<T>())));
}
}
}
} else {
3 changes: 1 addition & 2 deletions onnxruntime/test/contrib_ops/element_wise_ops_test.cc
@@ -146,9 +146,8 @@ TEST(BiasGeluTest, Two_One_Dim_fp16) {
}
#endif

// failed test for CUDA (therefore ROCM as well) to be investigated
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(BiasGeluTest, DISABLED_Two_One_Dim_bfloat16) {
TEST(BiasGeluTest, Two_One_Dim_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
4 changes: 1 addition & 3 deletions onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -197,11 +197,9 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
}


// failed with device error, disabled for now
// CUDA only, ROCM has not been supported yet
#ifdef USE_CUDA
TEST(FastGeluTest, DISABLED_FastGeluWithBias_BFloat16) {
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
46 changes: 46 additions & 0 deletions onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -1450,6 +1450,31 @@ TEST(ReductionOpTest, ReduceSumHalfHalf) {
test.Run();
}

TEST(ReductionOpTest, ReduceSumHalfHalf_2) {
OpTester test("ReduceSum");
test.AddAttribute("keepdims", (int64_t)0);
test.AddAttribute("axes", std::vector<int64_t>{0, 2});

std::vector<float> data = {1.0f, 2.0f,
3.0f, 4.0f,

5.0f, 6.0f,
7.0f, 8.0f,

9.0f, 10.0f,
11.0f, 12.0f};
std::vector<MLFloat16> data_half(12);
ConvertFloatToMLFloat16(data.data(), data_half.data(), 12);

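// Reducing over axes {0,2} of the {3,2,2} input keeps only axis 1:
// {1+2+5+6+9+10, 3+4+7+8+11+12} = {33, 45}.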
std::vector<float> result = {33.0f, 45.0f};
std::vector<MLFloat16> result_half(2);
ConvertFloatToMLFloat16(result.data(), result_half.data(), 2);

test.AddInput<MLFloat16>("data", {3, 2, 2}, data_half);
test.AddOutput<MLFloat16>("reduced", {2}, result_half);
test.Run();
}

void test_half_reduce_sum(
int64_t m, int64_t n) {
OpTester test("ReduceSum");
@@ -1509,6 +1534,27 @@ TEST(ReductionOpTest, ReduceSumBFloat16) {
}
#endif

// On CUDA, this UT (axes {0,2}) goes through the cuDNN path only if ATenOp is not initialized.
// On ROCM, the MIOpen call succeeds but returns wrong data, so follow the cuDNN logic for now.
// TODO: try ROCm 4.5.2 and/or double-check the source code for BFloat16 support.
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(ReductionOpTest, ReduceSumBFloat16_2) {
OpTester test("ReduceSum", 14);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<BFloat16>("data", {3, 2, 2},
MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}));
test.AddInput<int64_t>("axes", {2}, std::vector<int64_t>{0, 2});
test.AddOutput<BFloat16>("reduced", {2}, MakeBFloat16({33.0f, 45.0f}));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

TEST(ReductionOpTest, ReduceSum_apex_reduction) {
OpTester test("ReduceSum");
test.AddAttribute("keepdims", (int64_t)0);
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/provider_test_utils.cc
@@ -330,7 +330,7 @@ struct TensorCheck<BFloat16> {
/// XXX: May need to adjust threshold as BFloat is coarse
float threshold = 0.001f;
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.008f;
threshold = 0.05f;  // BFloat16 is coarse: expect results to be at least 95% close
#endif
for (int i = 0; i < size; ++i) {
if (std::isnan(f_expected[i])) {
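A rough rationale for loosening the threshold (my reasoning, not stated in the PR): BFloat16 carries only 8 mantissa bits, so the representable spacing near 1.0 is 2^-8, roughly 0.0039, and a few rounded operations can already exceed the old 0.008 absolute threshold:

#include <cstdio>

int main() {
  // bf16 has 8 mantissa bits, so values near 1.0 are spaced 2^-8 apart.
  const float bf16_eps = 1.0f / 256.0f;
  printf("bf16 spacing near 1.0: %f\n", bf16_eps);           // ~0.0039
  printf("three worst-case roundings: %f\n", 3 * bf16_eps);  // ~0.0117 > 0.008
  return 0;
}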
@@ -22,13 +22,13 @@ struct MixedPrecisionScaleInputOutput {

input1_bf16.resize(input1.size());
output1_bf16.resize(output1.size());
std::vector<BFloat16> input1_bf16 = FloatsToBFloat16s(input1);
std::vector<BFloat16> output1_bf16 = FloatsToBFloat16s(output1);
input1_bf16 = FloatsToBFloat16s(input1);
output1_bf16 = FloatsToBFloat16s(output1);

input2_bf16.resize(input2.size());
output2_bf16.resize(output2.size());
std::vector<BFloat16> input2_bf16 = FloatsToBFloat16s(input2);
std::vector<BFloat16> output2_bf16 = FloatsToBFloat16s(output2);
input2_bf16 = FloatsToBFloat16s(input2);
output2_bf16 = FloatsToBFloat16s(output2);
}

// Fp32 Inputs/Output
@@ -172,8 +172,7 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// failed with data error, disabled for now
TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
TEST(CudaKernelTest, MixedPrecisionScale_float_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -197,7 +196,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_float) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -221,7 +220,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
TEST(CudaKernelTest, MixedPrecisionScale_half_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -245,7 +244,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_half) {
TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_half) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
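The fix at the top of this file addresses variable shadowing: the old lines declared new std::vector<BFloat16> locals with the same names as the struct members, so the members were only resized and stayed zero-filled, which plausibly explains the data errors that kept these tests disabled. A minimal sketch of the bug pattern, using a hypothetical Holder type:

#include <cstdio>
#include <vector>

struct Holder {
  std::vector<int> data;
  Holder() {
    data.resize(3);
    // Bug (old code): declares a new local that shadows the member,
    // so the member keeps its three zeros.
    std::vector<int> data = {1, 2, 3};
    (void)data;
    // Fix (new code): assign to the member instead, e.g. this->data = {1, 2, 3};
  }
};

int main() {
  Holder h;
  printf("h.data[0] = %d\n", h.data[0]);  // prints 0, not 1
  return 0;
}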
21 changes: 0 additions & 21 deletions orttraining/orttraining/training_ops/rocm/math/softmax_grad.cc
@@ -62,27 +62,6 @@ Status SoftMaxGradComputeHelper(
return Status::OK();
}

#define SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(is_log_softmax) \
template <> \
Status SoftMaxGradComputeHelper<BFloat16, is_log_softmax>(hipStream_t stream, const BFloat16* dY, \
const TensorShape& input_shape, const BFloat16* Y, \
BFloat16* dX, miopenHandle_t, int64_t axis) { \
typedef typename ToHipType<BFloat16>::MappedType HipT; \
const int64_t normalized_axis = HandleNegativeAxis(axis, input_shape.NumDimensions()); \
int64_t N = input_shape.SizeToDimension(normalized_axis); \
int64_t D = input_shape.SizeFromDimension(normalized_axis); \
auto dY_data = reinterpret_cast<const HipT*>(dY); \
auto Y_data = reinterpret_cast<const HipT*>(Y); \
auto dX_data = reinterpret_cast<HipT*>(dX); \
dispatch_softmax_backward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>( \
stream, dX_data, dY_data, Y_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), \
gsl::narrow_cast<int>(N)); \
return Status::OK(); \
}

SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(true)
SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(false)

#define REGISTER_GRADIENT_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SoftmaxGrad, \