Skip to content

Commit

Permalink
BFloat16 support for QuickGeluGrad (microsoft#18336)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

Registers the BFloat16 data type as a valid input type for the CUDA QuickGeluGrad
kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Enables `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime
training.

---------

Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
  • Loading branch information
2 people authored and kleiti committed Mar 22, 2024
1 parent a671e9e commit 7bac1fc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ namespace cuda {
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double)

// Registers an activation-gradient kernel for the extended float set:
// everything ACTIVATION_GRAD_OP_HFD covers (float and double are visible in
// this hunk; presumably half as the "H" — its line is outside this diff view)
// plus BFloat16.  The "X" suffix marks the BFloat16 extension.
#define ACTIVATION_GRAD_OP_HFDX(name, ver, domain) \
ACTIVATION_GRAD_OP_HFD(name, ver, domain) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16)

ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFDX(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
// Emits an explicit template instantiation of the Impl_<name> binary
// elementwise launcher for element type T, so the definition compiled in this
// translation unit is available to the kernel-registration code.
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \
template void Impl_##name<T>(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count);

#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

#define ACTIVATION_GRAD_OP_NAME(name) \
BINARY_ELEMENTWISE_IMPL(name); \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(name)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(name)

ACTIVATION_GRAD_OPS()
#undef ACTIVATION_GRAD_OP_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
Expand Down Expand Up @@ -378,6 +379,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
Expand Down

0 comments on commit 7bac1fc

Please sign in to comment.