CUDA BFloat16 Refactor #10085

Merged: 10 commits, Jan 14, 2022
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -486,7 +486,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>

2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -729,7 +729,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|ComplexMul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
87 changes: 48 additions & 39 deletions include/onnxruntime/core/framework/float16.h
@@ -3,9 +3,19 @@
#pragma once

#include "endian.h"
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#include "cuda_bf16.h"
#endif

namespace onnxruntime
{
#include "core/common/common.h"

namespace onnxruntime {

#if defined(__CUDACC__) || defined(__HIPCC__)
#define ORT_HOST_DEVICE __host__ __device__
#else
#define ORT_HOST_DEVICE
#endif

// MLFloat16
struct MLFloat16 {
@@ -17,53 +27,64 @@ struct MLFloat16 {

float ToFloat() const;

operator float() const {
return ToFloat();
}
operator float() const { return ToFloat(); }
};

inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
return left.val == right.val;
}
inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }

inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
return left.val != right.val;
}

inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
return left.val < right.val;
}

//BFloat16
// BFloat16
struct BFloat16 {
uint16_t val{0};
explicit BFloat16() = default;
explicit BFloat16(uint16_t v) : val(v) {}
explicit BFloat16(float v) {
#if defined(USE_ROCM)
ORT_HOST_DEVICE BFloat16() = default;
Contributor: Out of curiosity, why is the line above specific to ROCM?

Contributor (Author): I saw that PyTorch does it this way for all of its default constructors, so I followed the same pattern. Maybe hipcc requires it, but I couldn't find any documentation to confirm that.

Contributor: OK, let's leave it as is and revisit this when we add BF16 support on AMD GPUs.
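
A minimal, self-contained sketch of the guarded-constructor pattern discussed here, under the unconfirmed assumption from this thread that hipcc, unlike nvcc, may not treat an explicitly-defaulted special member as device-callable; Bf16Sketch and UsesDefaultCtor are hypothetical names, not ORT code:

```cuda
#include <cstdint>

#if defined(__CUDACC__) || defined(__HIPCC__)
#define ORT_HOST_DEVICE __host__ __device__
#else
#define ORT_HOST_DEVICE
#endif

struct Bf16Sketch {
  uint16_t val{0};  // the member initializer makes the defaulted constructor non-trivial
#if defined(USE_ROCM)
  // Assumption: hipcc may need the explicit annotation for device-side construction.
  ORT_HOST_DEVICE Bf16Sketch() = default;
#else
  Bf16Sketch() = default;  // nvcc infers execution spaces for defaulted special members
#endif
};

__global__ void UsesDefaultCtor(uint16_t* out) {
  Bf16Sketch x;              // requires a default constructor callable from device code
  out[threadIdx.x] = x.val;
}

int main() {
  uint16_t* out = nullptr;
  cudaMalloc(&out, 32 * sizeof(uint16_t));
  UsesDefaultCtor<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
```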

#else
BFloat16() = default;
#endif

struct FromBitsT {};
Contributor: What is the reason to introduce struct FromBitsT?

Contributor (Author): The idea comes from PyTorch. When a BFloat16 is constructed with the FromBitsT tag, the bits are assigned to val directly, so the numeric value of the resulting BFloat16 is generally not equal to the bit pattern. Without the tag, a constructor such as BFloat16(unsigned short value) would be expected to produce a BFloat16 numerically equal to value, with val holding the encoding rather than value itself. This matters for casts: for BFloat16(1), which converts an int to BFloat16, the compiler would report an ambiguous-constructor error without FromBitsT, because it cannot choose between BFloat16(unsigned short) and BFloat16(float). Even without the ambiguity, if the compiler picked BFloat16(unsigned short) and assigned 1 to the val member directly, we would get a wrong BFloat16 value. MLFloat16 actually has the same bug, but we don't have any code like MLFloat16(1), so we haven't hit the compiler error yet.
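
To make the tag-dispatch idea concrete, here is a small standalone sketch; Bf16Demo is a hypothetical type that only implements the little-endian truncation path, not the ORT class:

```cpp
#include <cstdint>
#include <cstring>

struct Bf16Demo {
  uint16_t val{0};

  struct FromBitsT {};  // tag: "interpret the argument as the raw bit pattern"
  constexpr Bf16Demo(uint16_t bits, FromBitsT) : val(bits) {}

  // Numeric constructor: keep the upper 16 bits of the float encoding
  // (little-endian, truncation without rounding, like the non-CUDA path above).
  explicit Bf16Demo(float v) {
    std::memcpy(&val, reinterpret_cast<const char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
  }
};

int main() {
  Bf16Demo one(1.0f);                           // numeric value 1.0, val == 0x3F80
  Bf16Demo raw(0x3F80, Bf16Demo::FromBitsT{});  // same encoding, built from raw bits
  // Without the tag, an untagged Bf16Demo(uint16_t) overload would make
  // Bf16Demo(1) ambiguous against Bf16Demo(float), or silently store the
  // integer 1 as the bit pattern if the integral overload were chosen.
  return one.val == raw.val ? 0 : 1;
}
```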

static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits){};

inline ORT_HOST_DEVICE BFloat16(float v) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
val = __bfloat16_as_ushort(__float2bfloat16(v));
#else
ORT_IF_CONSTEXPR(endian::native == endian::little) {
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
} else {
}
else {
std::memcpy(&val, &v, sizeof(uint16_t));
}
#endif
}

float ToFloat() const {
inline ORT_HOST_DEVICE float ToFloat() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
#else
float result;
char* const first = reinterpret_cast<char*>(&result);
char* const second = first + sizeof(uint16_t);
ORT_IF_CONSTEXPR(endian::native == endian::little) {
std::memset(first, 0, sizeof(uint16_t));
std::memcpy(second, &val, sizeof(uint16_t));
} else {
}
else {
std::memcpy(first, &val, sizeof(uint16_t));
std::memset(second, 0, sizeof(uint16_t));
}
return result;
#endif
}

operator float() const {
return ToFloat();
}
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
#endif
};

inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
@@ -82,16 +103,4 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
}
}

inline bool operator==(const BFloat16& left, const BFloat16& right) {
return left.val == right.val;
}

inline bool operator!=(const BFloat16& left, const BFloat16& right) {
return left.val != right.val;
}

inline bool operator<(const BFloat16& left, const BFloat16& right) {
return left.val < right.val;
}

}
} // namespace onnxruntime
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -25,9 +25,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
REGISTER_KERNEL_TYPED(BFloat16)
#endif

using namespace ONNX_NAMESPACE;

24 changes: 14 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -94,23 +94,24 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <unsigned TPB>
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c,
int input_length, int bias_length,
const nv_bfloat162* input, const nv_bfloat162* bias, nv_bfloat162* output) {
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c, int input_length,
int bias_length, const nv_bfloat162* input, const nv_bfloat162* bias,
nv_bfloat162* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;

if (idx < input_length) {
const nv_bfloat162 x = input[idx];
const nv_bfloat162 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
const nv_bfloat162 cdf = a + a * _Tanh(in * (c * in * in + b));
output[idx] = in * cdf;
}
}
#endif

template <>
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const nv_bfloat16* input, const nv_bfloat16* bias, nv_bfloat16* output, bool /*use_half2*/) {
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
if (0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
const int gridSize = (n + blockSize - 1) / blockSize;
@@ -120,15 +121,18 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i
const nv_bfloat162* input2 = reinterpret_cast<const nv_bfloat162*>(input);
const nv_bfloat162* bias2 = reinterpret_cast<const nv_bfloat162*>(bias);
nv_bfloat162* output2 = reinterpret_cast<nv_bfloat162*>(output);
FastGeluKernel2<blockSize><<<gridSize, blockSize, 0, stream>>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2);
FastGeluKernel2<blockSize>
<<<gridSize, blockSize, 0, stream>>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2);
} else {
#endif
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<nv_bfloat16, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
}

#endif
return CUDA_CALL(cudaPeekAtLastError());
}
#endif

} // namespace cuda
} // namespace contrib
8 changes: 2 additions & 6 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -28,6 +28,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
@@ -96,13 +97,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization);
#endif

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -122,6 +120,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
@@ -190,14 +189,11 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu)>,

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
};

2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/layer_norm.cc
@@ -35,9 +35,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float, float)
REGISTER_KERNEL_TYPED(double, double)
REGISTER_KERNEL_TYPED(MLFloat16, float)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
REGISTER_KERNEL_TYPED(BFloat16, float)
#endif

template <typename T, typename U, bool simplified>
LayerNorm<T, U, simplified>::LayerNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
6 changes: 2 additions & 4 deletions onnxruntime/contrib_ops/cuda/layer_norm_impl.cu
@@ -394,10 +394,8 @@ LAYERNORM_LINEAR_IMPL(half, float, false)
LAYERNORM_LINEAR_IMPL(double, double, false)

//LAYERNORM_LINEAR_IMPL(half, half)
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
LAYERNORM_LINEAR_IMPL(nv_bfloat16, float, true)
LAYERNORM_LINEAR_IMPL(nv_bfloat16, float, false)
#endif
LAYERNORM_LINEAR_IMPL(BFloat16, float, true)
LAYERNORM_LINEAR_IMPL(BFloat16, float, false)

} // namespace cuda
} // namespace contrib
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/math/bias_dropout.cc
@@ -17,8 +17,8 @@ ONNX_OPERATOR_KERNEL_EX(
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", ALL_IEEE_FLOAT_TENSOR_TYPES)
.TypeConstraint("T1", ALL_IEEE_FLOAT_TENSOR_TYPES)
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>())
.InputMemoryType(OrtMemTypeCPUInput, 3)
.InputMemoryType(OrtMemTypeCPUInput, 4),
@@ -96,7 +96,7 @@ Status BiasDropout::ComputeInternal(OpKernelContext* context) const {
float ratio_data = default_ratio_;
auto ratio = context->Input<Tensor>(3);
if (ratio) {
utils::MLTypeCallDispatcher<ALL_IEEE_FLOAT_DATA_TYPES> t_disp(ratio->GetElementType());
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(ratio->GetElementType());
t_disp.Invoke<GetRatioDataImpl>(ratio, ratio_data);
}

@@ -117,7 +117,7 @@ Status BiasDropout::ComputeInternal(OpKernelContext* context) const {
const fast_divmod fdm_dim(gsl::narrow_cast<int>(dim));
PhiloxGenerator& generator = generator_ ? *generator_ : PhiloxGenerator::Default();

utils::MLTypeCallDispatcher<ALL_IEEE_FLOAT_DATA_TYPES> t_disp(X->GetElementType());
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(X->GetElementType());
return t_disp.InvokeRet<Status, BiasDropoutComputeImpl>(
GetDeviceProp(), Stream(), N, fdm_dim, ratio_data, generator, *X, *bias, residual, *Y, mask_data, has_same_shape_bias);
}
4 changes: 1 addition & 3 deletions onnxruntime/contrib_ops/cuda/math/bias_dropout_impl.cu
@@ -238,9 +238,7 @@ void BiasDropoutKernelImpl(
SPECIALIZED_BIAS_DROPOUT_IMPL(float)
SPECIALIZED_BIAS_DROPOUT_IMPL(double)
SPECIALIZED_BIAS_DROPOUT_IMPL(half)
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
SPECIALIZED_BIAS_DROPOUT_IMPL(nv_bfloat16)
#endif
SPECIALIZED_BIAS_DROPOUT_IMPL(BFloat16)

} // namespace cuda
} // namespace contrib {
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops.cc
@@ -61,7 +61,8 @@ namespace cuda {
#define CONTRIB_BINARY_OP_HFD(name, ver) \
CONTRIB_BINARY_OP_TYPED(name, ver, MLFloat16) \
CONTRIB_BINARY_OP_TYPED(name, ver, float) \
CONTRIB_BINARY_OP_TYPED(name, ver, double)
CONTRIB_BINARY_OP_TYPED(name, ver, double) \
CONTRIB_BINARY_OP_TYPED(name, ver, BFloat16)

CONTRIB_BINARY_OP_HFD(BiasGelu, 1)

@@ -63,7 +63,8 @@ namespace cuda {
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

// create declarations for op and impl
#define CONTRIB_BINARY_OP_NAME_EXPR(name, expr) \
5 changes: 1 addition & 4 deletions onnxruntime/contrib_ops/cuda/math/fused_matmul.cc
@@ -22,15 +22,12 @@ namespace cuda {
REGISTER_KERNEL_TYPED(TransposeMatMul, float)
REGISTER_KERNEL_TYPED(TransposeMatMul, double)
REGISTER_KERNEL_TYPED(TransposeMatMul, MLFloat16)
REGISTER_KERNEL_TYPED(TransposeMatMul, BFloat16)

REGISTER_KERNEL_TYPED(FusedMatMul, float)
REGISTER_KERNEL_TYPED(FusedMatMul, double)
REGISTER_KERNEL_TYPED(FusedMatMul, MLFloat16)

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
REGISTER_KERNEL_TYPED(TransposeMatMul, BFloat16)
REGISTER_KERNEL_TYPED(FusedMatMul, BFloat16)
#endif

} // namespace cuda
} // namespace contrib
18 changes: 6 additions & 12 deletions onnxruntime/contrib_ops/cuda/math/isfinite.cuh
@@ -5,10 +5,6 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/math/isfinite.h"

#if CUDA_VERSION >= 11000
#include "cuda_bf16.h"
#endif

namespace onnxruntime {
namespace cuda {

@@ -54,22 +50,20 @@ __device__ __forceinline__ bool IsNaNScalar(const half value) {
#endif
}

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__device__ __forceinline__ bool IsFiniteScalar(const nv_bfloat16 value) {
return !__hisinf(value) && !__hisnan(value);
__device__ __forceinline__ bool IsFiniteScalar(const BFloat16 value) {
return isfinite(static_cast<float>(value));
}

template <>
__device__ __forceinline__ bool IsInfScalar(const nv_bfloat16 value) {
return __hisinf(value);
__device__ __forceinline__ bool IsInfScalar(const BFloat16 value) {
return isinf(static_cast<float>(value));
}

template <>
__device__ __forceinline__ bool IsNaNScalar(const nv_bfloat16 value) {
return __hisnan(value);
__device__ __forceinline__ bool IsNaNScalar(const BFloat16 value) {
return isnan(static_cast<float>(value));
}
#endif

} // namespace cuda
} // namespace onnxruntime