[ROCm] UTs and code clean up (#10511)
* Fix UT

* UT

* UTs

* enable ROCm UT

* fix build attempt

* minor

* fix UT

* fix UT

* fix UTs

Co-authored-by: Ethan Tao <[email protected]@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <[email protected]>
3 people authored Feb 11, 2022
1 parent 2002a96 commit 4e2a974
Showing 10 changed files with 103 additions and 93 deletions.
42 changes: 7 additions & 35 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -92,45 +92,17 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i
return CUDA_CALL(cudaPeekAtLastError());
}

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <unsigned TPB>
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c, int input_length,
int bias_length, const nv_bfloat162* input, const nv_bfloat162* bias,
nv_bfloat162* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < input_length) {
const nv_bfloat162 x = input[idx];
const nv_bfloat162 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
const nv_bfloat162 cdf = a + a * _Tanh(in * (c * in * in + b));
output[idx] = in * cdf;
}
}
#endif

template <>
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
if (0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
const int gridSize = (n + blockSize - 1) / blockSize;
const nv_bfloat162 A2 = __floats2bfloat162_rn(A, A);
const nv_bfloat162 B2 = __floats2bfloat162_rn(B, B);
const nv_bfloat162 C2 = __floats2bfloat162_rn(C, C);
const nv_bfloat162* input2 = reinterpret_cast<const nv_bfloat162*>(input);
const nv_bfloat162* bias2 = reinterpret_cast<const nv_bfloat162*>(bias);
nv_bfloat162* output2 = reinterpret_cast<nv_bfloat162*>(output);
FastGeluKernel2<blockSize>
<<<gridSize, blockSize, 0, stream>>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2);
} else {
#endif
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
}
#endif

// Remove the nv_bfloat162 implementation for now to fix a build issue;
// we can decide whether to add it back if there's a perf concern.
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);

return CUDA_CALL(cudaPeekAtLastError());
}
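For reference, the BFloat16 launcher above now always takes the scalar FastGeluKernel path, which computes the usual tanh approximation of GELU. A minimal host-side sketch of that formula; the kA/kB/kC constants are assumed to match the A/B/C values defined earlier in fast_gelu_impl.cu (0.5, sqrt(2/pi), 0.044715*sqrt(2/pi)) and are illustrative, not authoritative:

#include <cmath>

// Host-side reference for the FastGelu computation done by the kernels above.
inline float FastGeluReference(float x, float bias) {
  constexpr float kA = 0.5f;                   // assumed A
  constexpr float kB = 0.7978845608028654f;    // assumed B = sqrt(2/pi)
  constexpr float kC = 0.035677408136300125f;  // assumed C = 0.044715 * sqrt(2/pi)
  const float in = x + bias;                                        // mirrors (x + bias[idx % bias_length])
  const float cdf = kA + kA * std::tanh(in * (kC * in * in + kB));  // mirrors a + a * _Tanh(in * (c*in*in + b))
  return in * cdf;                                                  // mirrors output[idx] = in * cdf
}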

10 changes: 0 additions & 10 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -351,16 +351,6 @@ __device__ __inline__ BFloat16 _Log(BFloat16 a) { return logf(static_cast<float>
template <>
__device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast<float>(a)); }

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__device__ __inline__ nv_bfloat162 _Tanh(nv_bfloat162 a) {
float2 tmp = (__bfloat1622float2(a));
tmp.x = tanhf(tmp.x);
tmp.y = tanhf(tmp.y);
return __float22bfloat162_rn(tmp);
}
#endif

template <>
__device__ __inline__ BFloat16 _Normcdf(BFloat16 a) { return normcdff(static_cast<float>(a)); }

25 changes: 19 additions & 6 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -513,7 +513,9 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
IAllocatorUniquePtr<float> temp_X;
cudnnDataType_t cudnn_type_X = CudnnTensor::GetDataType<CudaT>();

if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) {
// ReduceSum with BFloat16 is not supported by cuDNN, so convert the input to fp32 and then call cuDNN
if ((ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) ||
(ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES && std::is_same<T, BFloat16>::value)) {
// ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn
temp_X = cuda_ep.GetScratchBuffer<float>(input_count);
cudnn_type_X = CUDNN_DATA_FLOAT;
@@ -652,11 +654,22 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.template MutableData<T>(), input.template Data<T>(), input_count * sizeof(T), cudaMemcpyDeviceToDevice, stream));
}
} else {
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
workspace_cuda.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<CudaT*>(output.template MutableData<T>())));
if (temp_X) {
auto temp_output = cuda_ep.GetScratchBuffer<float>(output_count);
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
workspace_cuda.get(), workspace_bytes,
&one, input_tensor, temp_X.get(),
&zero, output_tensor, temp_output.get()));

Impl_Cast<float, CudaT>(stream, temp_output.get(), reinterpret_cast<CudaT*>(output.template MutableData<T>()), output_count);
} else {
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
workspace_cuda.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<CudaT*>(output.template MutableData<T>())));
}
}
}
} else {
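Taken together with the temp_X allocation in the earlier hunk, the new branch implements a cast-up / reduce / cast-down fallback for BFloat16. A CPU-side sketch of that pattern; bfloat16_t, ToFloat and FromFloat are hypothetical stand-ins for BFloat16 and Impl_Cast, and the plain loop stands in for cudnnReduceTensor:

#include <cstdint>
#include <cstring>
#include <vector>

struct bfloat16_t { uint16_t bits; };  // hypothetical stand-in for BFloat16

inline float ToFloat(bfloat16_t v) {
  uint32_t u = static_cast<uint32_t>(v.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

inline bfloat16_t FromFloat(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return bfloat16_t{static_cast<uint16_t>(u >> 16)};  // truncation; real conversions may round
}

inline bfloat16_t ReduceSumBF16ViaFloat(const std::vector<bfloat16_t>& input) {
  float acc = 0.0f;                              // fp32 "temp_output"
  for (bfloat16_t v : input) acc += ToFloat(v);  // reduction carried out in fp32
  return FromFloat(acc);                         // Impl_Cast<float, CudaT> equivalent
}

The ROCm hunk below applies the same pattern around miopenReduceTensor.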
26 changes: 20 additions & 6 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
@@ -511,7 +511,10 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
IAllocatorUniquePtr<float> temp_X;
miopenDataType_t miopen_type_X = MiopenTensor::GetDataType<HipT>();

if (ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) {
// Unlike cuDNN, which does not support BFloat16, the MIOpen call below succeeds for BFloat16, but the UT shows a data error,
// so for now follow the same logic as cuDNN and convert the input to fp32 before calling MIOpen
if ((ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) ||
(ReduceTensorIndices == MIOPEN_REDUCE_TENSOR_NO_INDICES && std::is_same<T, BFloat16>::value)) {
// ArgMax/ArgMin with FP16 are not supported by miopen, so convert input to fp32 then call miopen
temp_X = rocm_ep.GetScratchBuffer<float>(input_count);
miopen_type_X = miopenFloat;
@@ -651,11 +654,22 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.template MutableData<T>(), input.template Data<T>(), input_count * sizeof(T), hipMemcpyDeviceToDevice, stream));
}
} else {
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(
rocm_ep.PerThreadMiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const HipT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<HipT*>(output.template MutableData<T>())));
if (temp_X) {
auto temp_output = rocm_ep.GetScratchBuffer<float>(output_count);
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(
rocm_ep.PerThreadMiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes,
&one, input_tensor, temp_X.get(),
&zero, output_tensor, temp_output.get()));

Impl_Cast<float, HipT>(stream, temp_output.get(), reinterpret_cast<HipT*>(output.template MutableData<T>()), output_count);
} else {
MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(
rocm_ep.PerThreadMiopenHandle(), reduce_desc, indices_rocm.get(), indices_bytes,
workspace_rocm.get(), workspace_bytes,
&one, input_tensor, reinterpret_cast<const HipT*>(input.template Data<T>()),
&zero, output_tensor, reinterpret_cast<HipT*>(output.template MutableData<T>())));
}
}
}
} else {
3 changes: 1 addition & 2 deletions onnxruntime/test/contrib_ops/element_wise_ops_test.cc
@@ -146,9 +146,8 @@ TEST(BiasGeluTest, Two_One_Dim_fp16) {
}
#endif

// failed test for CUDA (therefore ROCM as well) to be investigated
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(BiasGeluTest, DISABLED_Two_One_Dim_bfloat16) {
TEST(BiasGeluTest, Two_One_Dim_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
4 changes: 1 addition & 3 deletions onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -197,11 +197,9 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
}


// failed with device error, disabled for now
// CUDA only, ROCM has not been supported yet
#ifdef USE_CUDA
TEST(FastGeluTest, DISABLED_FastGeluWithBias_BFloat16) {
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
46 changes: 46 additions & 0 deletions onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -1450,6 +1450,31 @@ TEST(ReductionOpTest, ReduceSumHalfHalf) {
test.Run();
}

TEST(ReductionOpTest, ReduceSumHalfHalf_2) {
OpTester test("ReduceSum");
test.AddAttribute("keepdims", (int64_t)0);
test.AddAttribute("axes", std::vector<int64_t>{0, 2});

std::vector<float> data = {1.0f, 2.0f,
3.0f, 4.0f,

5.0f, 6.0f,
7.0f, 8.0f,

9.0f, 10.0f,
11.0f, 12.0f};
std::vector<MLFloat16> data_half(12);
ConvertFloatToMLFloat16(data.data(), data_half.data(), 12);

std::vector<float> result = {33.0f, 45.0f};
std::vector<MLFloat16> result_half(2);
ConvertFloatToMLFloat16(result.data(), result_half.data(), 2);

test.AddInput<MLFloat16>("data", {3, 2, 2}, data_half);
test.AddOutput<MLFloat16>("reduced", {2}, result_half);
test.Run();
}
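A quick sanity check of the expected values used here and in the BFloat16 variant further down: reducing the {3, 2, 2} tensor over axes {0, 2} keeps axis 1, so reduced[0] = 1+2+5+6+9+10 = 33 and reduced[1] = 3+4+7+8+11+12 = 45. A standalone C++ check of the same arithmetic:

#include <array>
#include <cstdio>

int main() {
  // Row-major {3, 2, 2} tensor; reduce over axes 0 and 2, keeping axis 1.
  const std::array<float, 12> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  float reduced[2] = {0.0f, 0.0f};
  for (int i = 0; i < 3; ++i)      // axis 0
    for (int j = 0; j < 2; ++j)    // axis 1 (kept)
      for (int k = 0; k < 2; ++k)  // axis 2
        reduced[j] += data[i * 4 + j * 2 + k];
  std::printf("%g %g\n", reduced[0], reduced[1]);  // prints: 33 45
  return 0;
}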

void test_half_reduce_sum(
int64_t m, int64_t n) {
OpTester test("ReduceSum");
@@ -1509,6 +1534,27 @@ TEST(ReductionOpTest, ReduceSumBFloat16) {
}
#endif

// On CUDA, this UT with axes {0, 2} goes through the cuDNN library only if ATenOp is not initialized.
// On ROCm, the MIOpen call succeeds but produces incorrect data, so for now it follows the same logic used for cuDNN.
// TODO: try ROCm 4.5.2 and/or double-check BFloat16 support in the source code.
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(ReductionOpTest, ReduceSumBFloat16_2) {
OpTester test("ReduceSum", 14);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<BFloat16>("data", {3, 2, 2},
MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}));
test.AddInput<int64_t>("axes", {2}, std::vector<int64_t>{0, 2});
test.AddOutput<BFloat16>("reduced", {2}, MakeBFloat16({33.0f, 45.0f}));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

TEST(ReductionOpTest, ReduceSum_apex_reduction) {
OpTester test("ReduceSum");
test.AddAttribute("keepdims", (int64_t)0);
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/provider_test_utils.cc
@@ -330,7 +330,7 @@ struct TensorCheck<BFloat16> {
/// XXX: May need to adjust threshold as BFloat is coarse
float threshold = 0.001f;
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.008f;
threshold = 0.05f;  // expect values to be at least 95% close
#endif
for (int i = 0; i < size; ++i) {
if (std::isnan(f_expected[i])) {
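For context on why the looser 0.05 tolerance is plausible: BFloat16 keeps only a 7-bit stored mantissa fraction, so adjacent representable values in, say, [32, 64) are 2^(5-7) = 0.25 apart — roughly a 0.8% step relative to the values these tests produce, before any error accumulation across chained ops. A tiny sketch of that spacing:

#include <cmath>
#include <cstdio>

int main() {
  // Spacing of BFloat16-representable values in [32, 64): 2^(5-7) = 0.25.
  const float spacing = std::ldexp(1.0f, 5 - 7);
  std::printf("BF16 spacing near 33: %g (%.2f%% of 33)\n", spacing, 100.0 * spacing / 33.0);
  return 0;
}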
@@ -22,13 +22,13 @@ struct MixedPrecisionScaleInputOutput {

input1_bf16.resize(input1.size());
output1_bf16.resize(output1.size());
std::vector<BFloat16> input1_bf16 = FloatsToBFloat16s(input1);
std::vector<BFloat16> output1_bf16 = FloatsToBFloat16s(output1);
input1_bf16 = FloatsToBFloat16s(input1);
output1_bf16 = FloatsToBFloat16s(output1);

input2_bf16.resize(input2.size());
output2_bf16.resize(output2.size());
std::vector<BFloat16> input2_bf16 = FloatsToBFloat16s(input2);
std::vector<BFloat16> output2_bf16 = FloatsToBFloat16s(output2);
input2_bf16 = FloatsToBFloat16s(input2);
output2_bf16 = FloatsToBFloat16s(output2);
}
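The change above fixes a classic shadowing bug: the removed lines declared new local vectors with the same names as the struct members, so the members kept only their resized, default-constructed contents. A minimal standalone illustration (hypothetical Widget struct, not ORT code; compilers flag the pattern with -Wshadow):

#include <cstdio>
#include <vector>

struct Widget {
  std::vector<float> values;
  Widget() {
    values.resize(3);
    // Bug: this declares a *new* local vector; the member `values` stays {0, 0, 0}.
    std::vector<float> values = {1.0f, 2.0f, 3.0f};
    (void)values;
  }
};

struct FixedWidget {
  std::vector<float> values;
  FixedWidget() {
    values.resize(3);
    values = {1.0f, 2.0f, 3.0f};  // Fix: assign to the member instead.
  }
};

int main() {
  std::printf("buggy: first = %g\n", Widget().values[0]);       // prints 0
  std::printf("fixed: first = %g\n", FixedWidget().values[0]);  // prints 1
  return 0;
}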

// Fp32 Inputs/Output
@@ -172,8 +172,7 @@ TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// failed with data error, disabled for now
TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
TEST(CudaKernelTest, MixedPrecisionScale_float_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -197,7 +196,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_float) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -221,7 +220,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
TEST(CudaKernelTest, MixedPrecisionScale_half_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -245,7 +244,7 @@ TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_half) {
TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_half) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
21 changes: 0 additions & 21 deletions orttraining/orttraining/training_ops/rocm/math/softmax_grad.cc
@@ -62,27 +62,6 @@ Status SoftMaxGradComputeHelper(
return Status::OK();
}

#define SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(is_log_softmax) \
template <> \
Status SoftMaxGradComputeHelper<BFloat16, is_log_softmax>(hipStream_t stream, const BFloat16* dY, \
const TensorShape& input_shape, const BFloat16* Y, \
BFloat16* dX, miopenHandle_t, int64_t axis) { \
typedef typename ToHipType<BFloat16>::MappedType HipT; \
const int64_t normalized_axis = HandleNegativeAxis(axis, input_shape.NumDimensions()); \
int64_t N = input_shape.SizeToDimension(normalized_axis); \
int64_t D = input_shape.SizeFromDimension(normalized_axis); \
auto dY_data = reinterpret_cast<const HipT*>(dY); \
auto Y_data = reinterpret_cast<const HipT*>(Y); \
auto dX_data = reinterpret_cast<HipT*>(dX); \
dispatch_softmax_backward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>( \
stream, dX_data, dY_data, Y_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), \
gsl::narrow_cast<int>(N)); \
return Status::OK(); \
}

SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(true)
SPECIALIZED_SOFTMAXGRAD_HELPER_IMPL_BFloat16(false)

#define REGISTER_GRADIENT_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SoftmaxGrad, \
