diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc
index 7fde69d758ca9..98e3b878c9e0e 100644
--- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc
+++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc
@@ -43,11 +43,15 @@ namespace cuda {
   ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \
   ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double)
 
+#define ACTIVATION_GRAD_OP_HFDX(name, ver, domain) \
+  ACTIVATION_GRAD_OP_HFD(name, ver, domain)        \
+  ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16)
+
 ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
 ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
 ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
 ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
-ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
+ACTIVATION_GRAD_OP_HFDX(QuickGeluGrad, 1, kMSDomain);
 ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
 ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);
 
diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu
index 164aba866722e..dd6a44b9e3b56 100644
--- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu
+++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu
@@ -83,14 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
 #define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \
   template void Impl_##name(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count);
 
-#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
+#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \
   SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half)      \
   SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float)     \
-  SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
+  SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)    \
+  SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
 
 #define ACTIVATION_GRAD_OP_NAME(name) \
   BINARY_ELEMENTWISE_IMPL(name);      \
-  SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(name)
+  SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(name)
 
 ACTIVATION_GRAD_OPS()
 #undef ACTIVATION_GRAD_OP_NAME
diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index ae4f48b6b49a2..eeaa51c4dc1d8 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad);
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
@@ -378,6 +379,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,