diff --git a/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc b/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc
index f2b515a01741f..8cff32aa23fe7 100644
--- a/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/mixed_precision_scale_test.cc
@@ -3,6 +3,7 @@
 
 #include "test/common/tensor_op_test_utils.h"
 #include "test/providers/provider_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -18,6 +19,14 @@ struct MixedPrecisionScaleInputOutput {
     output2_half.resize(output2.size());
     ConvertFloatToMLFloat16(input2.data(), input2_half.data(), int(input2.size()));
     ConvertFloatToMLFloat16(output2.data(), output2_half.data(), int(output2.size()));
+
+    // Mirror the FP32 tensors into BF16. Assign straight to the members:
+    // a local "std::vector<BFloat16> input1_bf16 = ..." here would only
+    // shadow the member and leave the member zero-filled (vacuous tests).
+    input1_bf16 = FloatsToBFloat16s(input1);
+    output1_bf16 = FloatsToBFloat16s(output1);
+    input2_bf16 = FloatsToBFloat16s(input2);
+    output2_bf16 = FloatsToBFloat16s(output2);
   }
 
   // Fp32 Inputs/Output
@@ -32,6 +41,12 @@ struct MixedPrecisionScaleInputOutput {
   std::vector<MLFloat16> input2_half;
   std::vector<MLFloat16> output1_half;
   std::vector<MLFloat16> output2_half;
+
+  // BF16 Inputs/Output
+  std::vector<BFloat16> input1_bf16;
+  std::vector<BFloat16> input2_bf16;
+  std::vector<BFloat16> output1_bf16;
+  std::vector<BFloat16> output2_bf16;
 };
 
 TEST(CudaKernelTest, MixedPrecisionScaleF2F) {
@@ -130,5 +145,127 @@ TEST(CudaKernelTest, MixedPrecisionScaleH2H) {
   test.Run();
 }
 
+#if defined(USE_CUDA) || defined(USE_ROCM)
+TEST(CudaKernelTest, MixedPrecisionScale_bfloat16_bfloat16) {
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+#endif
+  MixedPrecisionScaleInputOutput data;
+  OpTester test("MixedPrecisionScale", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("to", int64_t(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16));
+  test.AddInput<float>("scale", {1}, data.scale);
+  test.AddInput<BFloat16>("input1", {3}, data.input1_bf16);
+  test.AddOutput<BFloat16>("output1", {3}, data.output1_bf16);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_float_bfloat16) {
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+#endif
+  MixedPrecisionScaleInputOutput data;
+  OpTester test("MixedPrecisionScale", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("to", int64_t(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16));
+  test.AddInput<float>("scale", {1}, data.scale);
+  test.AddInput<float>("input1", {3}, data.input1);
+  test.AddOutput<BFloat16>("output1", {3}, data.output1_bf16);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_float) {
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+#endif
+  MixedPrecisionScaleInputOutput data;
+  OpTester test("MixedPrecisionScale", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("to", int64_t(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
+  test.AddInput<float>("scale", {1}, data.scale);
+  test.AddInput<BFloat16>("input1", {3}, data.input1_bf16);
+  test.AddOutput<float>("output1", {3}, data.output1);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_half_bfloat16) {
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+#endif
+  MixedPrecisionScaleInputOutput data;
+  OpTester test("MixedPrecisionScale", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("to", int64_t(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16));
+  test.AddInput<float>("scale", {1}, data.scale);
+  test.AddInput<MLFloat16>("input1", {3}, data.input1_half);
+  test.AddOutput<BFloat16>("output1", {3}, data.output1_bf16);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(CudaKernelTest, DISABLED_MixedPrecisionScale_bfloat16_half) {
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+#endif
+  MixedPrecisionScaleInputOutput data;
+  OpTester test("MixedPrecisionScale", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("to", int64_t(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16));
+  test.AddInput<float>("scale", {1}, data.scale);
+  test.AddInput<BFloat16>("input1", {3}, data.input1_bf16);
+  test.AddOutput<MLFloat16>("output1", {3}, data.output1_half);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
\ No newline at end of file