[ROCm] BFloat16 support (#10465)
* bf16 support

* minor clean up

* UTs

* fix build

* UTs

* UTs

* merge commit 6b5504c

* minor

* ROCm code cleanup

* fix build

* fix build

* minor

Co-authored-by: Ethan Tao <[email protected]@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <[email protected]>
3 people authored Feb 8, 2022
1 parent c696da3 commit 435e14d
Showing 11 changed files with 471 additions and 65 deletions.
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -136,6 +136,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
return CUDNN_DATA_HALF;
}

template <>
cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
return CUDNN_DATA_BFLOAT16;
}

template <>
cudnnDataType_t CudnnTensor::GetDataType<int8_t>() {
return CUDNN_DATA_INT8;
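Note: CUDNN_DATA_BFLOAT16 requires a cuDNN version that defines it (8.1 or later). The sketch below shows how a type-to-cudnnDataType_t mapping like the specialization above is typically consumed when filling a tensor descriptor. GetCudnnDataType and DescribeTensor are illustrative names, and CUDA's native __half / __nv_bfloat16 types stand in for onnxruntime's MLFloat16 / BFloat16 wrappers; this is a sketch, not the CudnnTensor API.

```cpp
// Minimal sketch, assuming cuDNN >= 8.1 (where CUDNN_DATA_BFLOAT16 exists).
// GetCudnnDataType / DescribeTensor are hypothetical helpers, not onnxruntime code.
#include <cudnn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T>
cudnnDataType_t GetCudnnDataType();  // specialized per supported element type

template <> cudnnDataType_t GetCudnnDataType<float>()         { return CUDNN_DATA_FLOAT; }
template <> cudnnDataType_t GetCudnnDataType<__half>()        { return CUDNN_DATA_HALF; }
template <> cudnnDataType_t GetCudnnDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; }

// Describe a 4-D NCHW tensor to cuDNN using the mapped element type.
template <typename T>
cudnnStatus_t DescribeTensor(cudnnTensorDescriptor_t desc, int n, int c, int h, int w) {
  return cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, GetCudnnDataType<T>(),
                                    n, c, h, w);
}
```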
23 changes: 3 additions & 20 deletions onnxruntime/core/providers/cuda/math/softmax.cc
@@ -43,24 +43,7 @@ Status SoftMaxComputeHelper(
SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
SPECIALIZED_SOFTMAX_HELPER_IMPL(double)
SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16)

// cudnnSoftmaxForward/Backward doesn't support BFloat16.
#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax) \
template <> \
Status SoftMaxComputeHelper<BFloat16, is_log_softmax>(cudaStream_t stream, const BFloat16* X, \
const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \
typedef typename ToCudaType<BFloat16>::MappedType CudaT; \
int64_t N = input_shape.SizeToDimension(axis); \
int64_t D = input_shape.SizeFromDimension(axis); \
auto Y_data = reinterpret_cast<CudaT*>(Y); \
auto X_data = reinterpret_cast<const CudaT*>(X); \
dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>( \
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N)); \
return Status::OK(); \
}

SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true)
SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
SPECIALIZED_SOFTMAX_HELPER_IMPL(BFloat16)

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
@@ -112,8 +95,8 @@ SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Softmax<T>);

template <typename T>
Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
template <typename T>
Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
const TensorShape& input_shape{X->Shape()};
size_t rank = input_shape.NumDimensions();
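Both the removed BFloat16-specific macro and the generic SPECIALIZED_SOFTMAX_HELPER_IMPL flatten the input around `axis` into N rows of D elements (N = SizeToDimension(axis), D = SizeFromDimension(axis)) before dispatching the warp-wise kernel. The CPU reference below illustrates that per-row computation only; the actual code runs dispatch_warpwise_softmax_forward on CudaT/HipT data, so the function here is a sketch, not the kernel.

```cpp
// Hedged CPU reference for the N x D flattening convention used by the helper.
#include <algorithm>
#include <cmath>
#include <cstdint>

void ReferenceSoftmax(const float* x, float* y, int64_t N, int64_t D, bool is_log_softmax) {
  for (int64_t n = 0; n < N; ++n) {
    const float* row_x = x + n * D;
    float* row_y = y + n * D;
    // Subtract the row max for numerical stability.
    float max_val = row_x[0];
    for (int64_t d = 1; d < D; ++d) max_val = std::max(max_val, row_x[d]);
    float sum = 0.0f;
    for (int64_t d = 0; d < D; ++d) sum += std::exp(row_x[d] - max_val);
    for (int64_t d = 0; d < D; ++d)
      row_y[d] = is_log_softmax ? (row_x[d] - max_val - std::log(sum))
                                : std::exp(row_x[d] - max_val) / sum;
  }
}
```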
20 changes: 1 addition & 19 deletions onnxruntime/core/providers/rocm/math/softmax.cc
@@ -44,25 +44,7 @@ SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
// MIOpen double data type not supported
// SPECIALIZED_SOFTMAX_HELPER_IMPL(double)
SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16)

// cudnnSoftmaxForward/Backward doesn't support BFloat16.
// apply the same for miopen for now
#define SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(is_log_softmax) \
template <> \
Status SoftMaxComputeHelper<BFloat16, is_log_softmax>(hipStream_t stream, const BFloat16* X, \
const TensorShape& input_shape, BFloat16* Y, int64_t axis) { \
typedef typename ToHipType<BFloat16>::MappedType HipT; \
int64_t N = input_shape.SizeToDimension(axis); \
int64_t D = input_shape.SizeFromDimension(axis); \
auto Y_data = reinterpret_cast<HipT*>(Y); \
auto X_data = reinterpret_cast<const HipT*>(X); \
dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>( \
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N)); \
return Status::OK(); \
}

SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(true)
SPECIALIZED_SOFTMAX_HELPER_IMPL_BFloat16(false)
SPECIALIZED_SOFTMAX_HELPER_IMPL(BFloat16)

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/rocm/miopen_common.cc
@@ -91,6 +91,11 @@ miopenDataType_t MiopenTensor::GetDataType<half>() {
return miopenHalf;
}

template <>
miopenDataType_t MiopenTensor::GetDataType<BFloat16>() {
return miopenBFloat16;
}

template <>
miopenDataType_t MiopenTensor::GetDataType<int32_t>() {
return miopenInt32;
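CUDNN_DATA_BFLOAT16 and miopenBFloat16 describe the same 16-bit format: the upper half of an IEEE-754 binary32 value (1 sign bit, 8 exponent bits, 7 mantissa bits). The sketch below shows a float-to-bfloat16 conversion of the kind the unit tests further down perform with FloatsToBFloat16s; round-to-nearest-even is an assumption about the exact rounding, and the helper names are illustrative, not onnxruntime's.

```cpp
// Sketch only: pack a float into bfloat16 bits by keeping the top 16 bits,
// with round-to-nearest-even on the discarded half (NaN handling omitted).
#include <cstdint>
#include <cstring>
#include <vector>

inline uint16_t FloatToBFloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);  // ties to even
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

inline std::vector<uint16_t> FloatsToBFloat16Bits(const std::vector<float>& src) {
  std::vector<uint16_t> out;
  out.reserve(src.size());
  for (float f : src) out.push_back(FloatToBFloat16Bits(f));
  return out;
}
```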
72 changes: 72 additions & 0 deletions onnxruntime/test/contrib_ops/element_wise_ops_test.cc
@@ -112,6 +112,78 @@ TEST(BiasGeluTest, Two_One_Dim) {
RunBiasGeluTest(input_a_data, input_b_data, {2, 4}, {4});
}

#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(BiasGeluTest, Two_One_Dim_fp16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
return;
}
#endif
OpTester tester("BiasGelu", 1, onnxruntime::kMSDomain);

std::vector<float> A = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};

std::vector<float> B = {
-0.5f, 0.6f, 1.2f, 2.1f};

std::vector<float> Y = ComputeGeluWithErf(Add_Simple(A, B));

std::vector<MLFloat16> f_A(8);
std::vector<MLFloat16> f_B(4);
std::vector<MLFloat16> f_Y(8);
ConvertFloatToMLFloat16(A.data(), f_A.data(), 8);
ConvertFloatToMLFloat16(B.data(), f_B.data(), 4);
ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 8);

tester.AddInput<MLFloat16>("A", {2, 4}, f_A);
tester.AddInput<MLFloat16>("B", {4}, f_B);
tester.AddOutput<MLFloat16>("Y", {2, 4}, f_Y);
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: fp16 is not supported
}
#endif

// failed test for CUDA (therefore ROCM as well) to be investigated
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(BiasGeluTest, DISABLED_Two_One_Dim_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
}
#endif
OpTester tester("BiasGelu", 1, onnxruntime::kMSDomain);

std::vector<float> A = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};

std::vector<float> B = {
-0.5f, 0.6f, 1.2f, 2.1f};

std::vector<float> Y = ComputeGeluWithErf(Add_Simple(A, B));

std::vector<BFloat16> f_A = FloatsToBFloat16s(A);
std::vector<BFloat16> f_B = FloatsToBFloat16s(B);
std::vector<BFloat16> f_Y = FloatsToBFloat16s(Y);

tester.AddInput<BFloat16>("A", {2, 4}, f_A);
tester.AddInput<BFloat16>("B", {4}, f_B);
tester.AddOutput<BFloat16>("Y", {2, 4}, f_Y);
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

TEST(MathOpTest, ComplexMul) {
if (DefaultCudaExecutionProvider() == nullptr) return;

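For context, the expected outputs in the BiasGelu tests above come from broadcasting the bias over the last dimension and applying the exact, erf-based GELU. The helpers below are a hedged reconstruction of what Add_Simple and ComputeGeluWithErf compute, not the actual test utilities.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Bias is broadcast over the last dimension (b.size() equals the last dim of a).
std::vector<float> AddBroadcastBias(const std::vector<float>& a, const std::vector<float>& b) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) out[i] = a[i] + b[i % b.size()];
  return out;
}

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
std::vector<float> GeluWithErf(const std::vector<float>& x) {
  const float inv_sqrt2 = 0.70710678118654752f;
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = 0.5f * x[i] * (1.0f + std::erf(x[i] * inv_sqrt2));
  return out;
}
```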
47 changes: 47 additions & 0 deletions onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -197,5 +197,52 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
}


// failed with device error, disabled for now
// CUDA only, ROCM has not been supported yet
#ifdef USE_CUDA
TEST(FastGeluTest, DISABLED_FastGeluWithBias_BFloat16) {
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
}
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);

int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;

std::vector<float> X = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};

std::vector<float> B = {
-0.5f, 0.6f, 1.2f, 2.1f};

std::vector<float> Y = {
0.1851806640625f, 0.054046630859375f, 1.0615234375f, 3.095703125f,
0, 0.63037109375f, 1.3984375f, 1.3984375f};

std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;

std::vector<BFloat16> f_X = FloatsToBFloat16s(X);
std::vector<BFloat16> f_B = FloatsToBFloat16s(B);
std::vector<BFloat16> f_Y = FloatsToBFloat16s(Y);

tester.AddInput<BFloat16>("X", input_dims, f_X);
tester.AddInput<BFloat16>("bias", bias_dims, f_B);
tester.AddOutput<BFloat16>("Y", output_dims, f_Y);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif



} // namespace test
} // namespace onnxruntime
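The hard-coded expected values in the FastGelu tests are consistent with the tanh approximation of GELU applied to x + bias, evaluated at reduced precision. The float reference below shows the formula itself; treating it as the kernel's exact arithmetic would be an assumption, and the function names are illustrative.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
float FastGeluApprox(float x) {
  const float k = 0.7978845608028654f;  // sqrt(2 / pi)
  const float c = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(k * (x + c * x * x * x)));
}

// Add the per-channel bias, then apply the approximation element-wise.
std::vector<float> FastGeluWithBias(const std::vector<float>& x, const std::vector<float>& bias) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = FastGeluApprox(x[i] + bias[i % bias.size()]);
  return y;
}
```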
95 changes: 95 additions & 0 deletions onnxruntime/test/contrib_ops/fused_matmul_op_test.cc
@@ -3,6 +3,7 @@

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/cuda_op_test_utils.h"

namespace onnxruntime {
namespace test {
@@ -269,6 +270,100 @@ TEST(FusedMatMulOpTest, FloatTypeTransposeBatch) {
RunFusedMatMulTest<float>("FusedMatMul", 1, true, true, true, true);
}

#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FusedMatMulOpTest, Float16_NoTranspose) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
return;
}
#endif
std::vector<float> common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
for (auto t : GenerateSimpleTestCases<float>()) {

OpTester test("FusedMatMul", 1, onnxruntime::kMSDomain);

std::vector<int64_t> input0_dims(t.input0_dims);
std::vector<float> input0_vals;
ProcessInputs(t.input0_dims, common_input_vals, false, false, input0_dims, input0_vals);

std::vector<int64_t> input1_dims(t.input1_dims);
std::vector<float> input1_vals;
ProcessInputs(t.input1_dims, common_input_vals, false, false, input1_dims, input1_vals);

std::vector<MLFloat16> f_A(input0_vals.size());
std::vector<MLFloat16> f_B(input1_vals.size());
std::vector<MLFloat16> f_Y(t.expected_vals.size());
ConvertFloatToMLFloat16(input0_vals.data(), f_A.data(), (int)input0_vals.size());
ConvertFloatToMLFloat16(input1_vals.data(), f_B.data(), (int)input1_vals.size());
ConvertFloatToMLFloat16(t.expected_vals.data(), f_Y.data(), (int)t.expected_vals.size());

test.AddInput<MLFloat16>("A", input0_dims, f_A);
test.AddInput<MLFloat16>("B", input1_dims, f_B, false);

test.AddAttribute("transA", (int64_t)0);
test.AddAttribute("transB", (int64_t)0);
test.AddAttribute("transBatchA", (int64_t)0);
test.AddAttribute("transBatchB", (int64_t)0);
test.AddAttribute("alpha", 1.0f);

test.AddOutput<MLFloat16>("Y", t.expected_dims, f_Y);

// Disable TensorRT because of unsupported data type
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
}
#endif

#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FusedMatMulOpTest, BFloat16_NoTranspose) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
return;
}
#endif
std::vector<float> common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
for (auto t : GenerateSimpleTestCases<float>()) {

OpTester test("FusedMatMul", 1, onnxruntime::kMSDomain);

std::vector<int64_t> input0_dims(t.input0_dims);
std::vector<float> input0_vals;
ProcessInputs(t.input0_dims, common_input_vals, false, false, input0_dims, input0_vals);

std::vector<int64_t> input1_dims(t.input1_dims);
std::vector<float> input1_vals;
ProcessInputs(t.input1_dims, common_input_vals, false, false, input1_dims, input1_vals);

std::vector<BFloat16> f_A = FloatsToBFloat16s(input0_vals);
std::vector<BFloat16> f_B = FloatsToBFloat16s(input1_vals);
std::vector<BFloat16> f_Y = FloatsToBFloat16s(t.expected_vals);

test.AddInput<BFloat16>("A", input0_dims, f_A);
test.AddInput<BFloat16>("B", input1_dims, f_B, false);

test.AddAttribute("transA", (int64_t)0);
test.AddAttribute("transB", (int64_t)0);
test.AddAttribute("transBatchA", (int64_t)0);
test.AddAttribute("transBatchB", (int64_t)0);
test.AddAttribute("alpha", 1.0f);

test.AddOutput<BFloat16>("Y", t.expected_dims, f_Y);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
#endif

} // namespace transpose_matmul
} // namespace test
} // namespace onnxruntime
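The FusedMatMul contrib op exercised by these tests computes Y = alpha * op(A) * op(B), where op() optionally transposes its argument (transA/transB) and transBatchA/transBatchB additionally permute batch dimensions. Below is a 2-D reference sketch that ignores the batch-transpose attributes; FusedMatMul2D is an illustrative name, not onnxruntime code.

```cpp
#include <cstddef>
#include <vector>

// 2-D reference: Y = alpha * op(A) * op(B), row-major storage.
std::vector<float> FusedMatMul2D(const std::vector<float>& A, size_t a_rows, size_t a_cols,
                                 const std::vector<float>& B, size_t b_rows, size_t b_cols,
                                 bool transA, bool transB, float alpha) {
  const size_t M = transA ? a_cols : a_rows;
  const size_t K = transA ? a_rows : a_cols;
  const size_t N = transB ? b_rows : b_cols;
  std::vector<float> Y(M * N, 0.0f);
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        const float a = transA ? A[k * a_cols + m] : A[m * a_cols + k];
        const float b = transB ? B[n * b_cols + k] : B[k * b_cols + n];
        acc += a * b;
      }
      Y[m * N + n] = alpha * acc;
    }
  }
  return Y;
}
```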