CUDA BFloat16 Refactor (#10085)
centwang authored Jan 14, 2022
1 parent e38e51e commit 44e2db9
Showing 80 changed files with 707 additions and 764 deletions.
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -486,7 +486,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>

2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -729,7 +729,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|ComplexMul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
87 changes: 48 additions & 39 deletions include/onnxruntime/core/framework/float16.h
@@ -3,9 +3,19 @@
#pragma once

#include "endian.h"
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#include "cuda_bf16.h"
#endif

namespace onnxruntime
{
#include "core/common/common.h"

namespace onnxruntime {

#if defined(__CUDACC__) || defined(__HIPCC__)
#define ORT_HOST_DEVICE __host__ __device__
#else
#define ORT_HOST_DEVICE
#endif

// MLFloat16
struct MLFloat16 {
@@ -17,53 +27,64 @@ struct MLFloat16 {

float ToFloat() const;

operator float() const {
return ToFloat();
}
operator float() const { return ToFloat(); }
};

inline bool operator==(const MLFloat16& left, const MLFloat16& right) {
return left.val == right.val;
}
inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }

inline bool operator!=(const MLFloat16& left, const MLFloat16& right) {
return left.val != right.val;
}

inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
return left.val < right.val;
}

//BFloat16
// BFloat16
struct BFloat16 {
uint16_t val{0};
explicit BFloat16() = default;
explicit BFloat16(uint16_t v) : val(v) {}
explicit BFloat16(float v) {
#if defined(USE_ROCM)
ORT_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
#endif

struct FromBitsT {};
static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits){};

inline ORT_HOST_DEVICE BFloat16(float v) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
val = __bfloat16_as_ushort(__float2bfloat16(v));
#else
ORT_IF_CONSTEXPR(endian::native == endian::little) {
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
} else {
}
else {
std::memcpy(&val, &v, sizeof(uint16_t));
}
#endif
}

float ToFloat() const {
inline ORT_HOST_DEVICE float ToFloat() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
#else
float result;
char* const first = reinterpret_cast<char*>(&result);
char* const second = first + sizeof(uint16_t);
ORT_IF_CONSTEXPR(endian::native == endian::little) {
std::memset(first, 0, sizeof(uint16_t));
std::memcpy(second, &val, sizeof(uint16_t));
} else {
}
else {
std::memcpy(first, &val, sizeof(uint16_t));
std::memset(second, 0, sizeof(uint16_t));
}
return result;
#endif
}

operator float() const {
return ToFloat();
}
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
#endif
};

inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
@@ -82,16 +103,4 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
}
}

inline bool operator==(const BFloat16& left, const BFloat16& right) {
return left.val == right.val;
}

inline bool operator!=(const BFloat16& left, const BFloat16& right) {
return left.val != right.val;
}

inline bool operator<(const BFloat16& left, const BFloat16& right) {
return left.val < right.val;
}

}
} // namespace onnxruntime
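
For reference, a minimal host-side sketch of the truncating conversion that the refactored float16.h performs on the little-endian path; ToBFloat16Bits and FromBFloat16Bits are illustrative helpers, not part of the ORT API.

// Sketch only, not ORT code: bfloat16 is the top 16 bits of an IEEE-754 float.
#include <cstdint>
#include <cstring>

// Keep the sign bit, the 8 exponent bits, and the top 7 mantissa bits.
inline uint16_t ToBFloat16Bits(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Widen by zero-filling the 16 low-order mantissa bits that were dropped.
inline float FromBFloat16Bits(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float v;
  std::memcpy(&v, &bits, sizeof(v));
  return v;
}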
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -25,9 +25,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
REGISTER_KERNEL_TYPED(BFloat16)
#endif

using namespace ONNX_NAMESPACE;

24 changes: 14 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -94,23 +94,24 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <unsigned TPB>
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c,
int input_length, int bias_length,
const nv_bfloat162* input, const nv_bfloat162* bias, nv_bfloat162* output) {
__global__ void FastGeluKernel2(const nv_bfloat162 a, const nv_bfloat162 b, const nv_bfloat162 c, int input_length,
int bias_length, const nv_bfloat162* input, const nv_bfloat162* bias,
nv_bfloat162* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;

if (idx < input_length) {
const nv_bfloat162 x = input[idx];
const nv_bfloat162 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
const nv_bfloat162 cdf = a + a * _Tanh(in * (c * in * in + b));
output[idx] = in * cdf;
}
}
#endif

template <>
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const nv_bfloat16* input, const nv_bfloat16* bias, nv_bfloat16* output, bool /*use_half2*/) {
bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
if (0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
const int gridSize = (n + blockSize - 1) / blockSize;
@@ -120,15 +121,18 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i
const nv_bfloat162* input2 = reinterpret_cast<const nv_bfloat162*>(input);
const nv_bfloat162* bias2 = reinterpret_cast<const nv_bfloat162*>(bias);
nv_bfloat162* output2 = reinterpret_cast<nv_bfloat162*>(output);
FastGeluKernel2<blockSize><<<gridSize, blockSize, 0, stream>>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2);
FastGeluKernel2<blockSize>
<<<gridSize, blockSize, 0, stream>>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2);
} else {
#endif
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<nv_bfloat16, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
FastGeluKernel<BFloat16, blockSize>
<<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length, input, bias, output);
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
}

#endif
return CUDA_CALL(cudaPeekAtLastError());
}
#endif

} // namespace cuda
} // namespace contrib
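
For context, a host-side reference of the approximation these kernels evaluate as in * (a + a * tanh(in * (c * in * in + b))). The kernels receive a, b, c as arguments; the constants below are the usual tanh-GELU values and are an assumption here, as is the FastGeluRef helper itself.

// Sketch only, not ORT code: scalar reference for the tanh-based fast GELU.
#include <cmath>

inline float FastGeluRef(float x, float bias = 0.0f) {
  const float a = 0.5f;
  const float b = 0.7978845608028654f;    // sqrt(2 / pi)
  const float c = 0.035677408136300125f;  // 0.044715 * sqrt(2 / pi)
  const float in = x + bias;
  const float cdf = a + a * std::tanh(in * (c * in * in + b));
  return in * cdf;
}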
8 changes: 2 additions & 6 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
@@ -86,13 +87,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization);
#endif

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -112,6 +110,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
@@ -180,14 +179,11 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu)>,

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
};

2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/layer_norm.cc
@@ -35,9 +35,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float, float)
REGISTER_KERNEL_TYPED(double, double)
REGISTER_KERNEL_TYPED(MLFloat16, float)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
REGISTER_KERNEL_TYPED(BFloat16, float)
#endif

template <typename T, typename U, bool simplified>
LayerNorm<T, U, simplified>::LayerNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
6 changes: 2 additions & 4 deletions onnxruntime/contrib_ops/cuda/layer_norm_impl.cu
@@ -394,10 +394,8 @@ LAYERNORM_LINEAR_IMPL(half, float, false)
LAYERNORM_LINEAR_IMPL(double, double, false)

//LAYERNORM_LINEAR_IMPL(half, half)
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
LAYERNORM_LINEAR_IMPL(nv_bfloat16, float, true)
LAYERNORM_LINEAR_IMPL(nv_bfloat16, float, false)
#endif
LAYERNORM_LINEAR_IMPL(BFloat16, float, true)
LAYERNORM_LINEAR_IMPL(BFloat16, float, false)

} // namespace cuda
} // namespace contrib
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/math/bias_dropout.cc
@@ -17,8 +17,8 @@ ONNX_OPERATOR_KERNEL_EX(
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", ALL_IEEE_FLOAT_TENSOR_TYPES)
.TypeConstraint("T1", ALL_IEEE_FLOAT_TENSOR_TYPES)
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>())
.InputMemoryType(OrtMemTypeCPUInput, 3)
.InputMemoryType(OrtMemTypeCPUInput, 4),
@@ -96,7 +96,7 @@ Status BiasDropout::ComputeInternal(OpKernelContext* context) const {
float ratio_data = default_ratio_;
auto ratio = context->Input<Tensor>(3);
if (ratio) {
utils::MLTypeCallDispatcher<ALL_IEEE_FLOAT_DATA_TYPES> t_disp(ratio->GetElementType());
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(ratio->GetElementType());
t_disp.Invoke<GetRatioDataImpl>(ratio, ratio_data);
}

@@ -117,7 +117,7 @@ Status BiasDropout::ComputeInternal(OpKernelContext* context) const {
const fast_divmod fdm_dim(gsl::narrow_cast<int>(dim));
PhiloxGenerator& generator = generator_ ? *generator_ : PhiloxGenerator::Default();

utils::MLTypeCallDispatcher<ALL_IEEE_FLOAT_DATA_TYPES> t_disp(X->GetElementType());
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(X->GetElementType());
return t_disp.InvokeRet<Status, BiasDropoutComputeImpl>(
GetDeviceProp(), Stream(), N, fdm_dim, ratio_data, generator, *X, *bias, residual, *Y, mask_data, has_same_shape_bias);
}
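
As a rough illustration of the dispatch pattern used above, a simplified stand-in for MLTypeCallDispatcher; this is not ORT's actual implementation, and the ElemType enum, Dispatch helper, and GetRatio functor are assumptions for illustration only.

// Sketch only, not ORT code: pick a functor instantiation from a runtime type tag.
#include <stdexcept>
#include <utility>

enum class ElemType { kFloat, kDouble /*, kFloat16, kBFloat16, ... */ };

// Instantiate Fn<T> with the T that matches the runtime element-type tag.
template <template <typename> class Fn, typename... Args>
void Dispatch(ElemType t, Args&&... args) {
  switch (t) {
    case ElemType::kFloat:  Fn<float>{}(std::forward<Args>(args)...); break;
    case ElemType::kDouble: Fn<double>{}(std::forward<Args>(args)...); break;
    default: throw std::runtime_error("unsupported element type");
  }
}

// Example functor: read the scalar as T and widen it to float, mirroring what
// the ratio-reading call above does for the 'ratio' input.
template <typename T>
struct GetRatio {
  void operator()(const void* data, float& out) const {
    out = static_cast<float>(*static_cast<const T*>(data));
  }
};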
4 changes: 1 addition & 3 deletions onnxruntime/contrib_ops/cuda/math/bias_dropout_impl.cu
@@ -238,9 +238,7 @@ void BiasDropoutKernelImpl(
SPECIALIZED_BIAS_DROPOUT_IMPL(float)
SPECIALIZED_BIAS_DROPOUT_IMPL(double)
SPECIALIZED_BIAS_DROPOUT_IMPL(half)
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
SPECIALIZED_BIAS_DROPOUT_IMPL(nv_bfloat16)
#endif
SPECIALIZED_BIAS_DROPOUT_IMPL(BFloat16)

} // namespace cuda
} // namespace contrib {
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops.cc
@@ -61,7 +61,8 @@ namespace cuda {
#define CONTRIB_BINARY_OP_HFD(name, ver) \
CONTRIB_BINARY_OP_TYPED(name, ver, MLFloat16) \
CONTRIB_BINARY_OP_TYPED(name, ver, float) \
CONTRIB_BINARY_OP_TYPED(name, ver, double)
CONTRIB_BINARY_OP_TYPED(name, ver, double) \
CONTRIB_BINARY_OP_TYPED(name, ver, BFloat16)

CONTRIB_BINARY_OP_HFD(BiasGelu, 1)

@@ -63,7 +63,8 @@ namespace cuda {
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

// create declarations for op and impl
#define CONTRIB_BINARY_OP_NAME_EXPR(name, expr) \
5 changes: 1 addition & 4 deletions onnxruntime/contrib_ops/cuda/math/fused_matmul.cc
@@ -22,15 +22,12 @@ namespace cuda {
REGISTER_KERNEL_TYPED(TransposeMatMul, float)
REGISTER_KERNEL_TYPED(TransposeMatMul, double)
REGISTER_KERNEL_TYPED(TransposeMatMul, MLFloat16)
REGISTER_KERNEL_TYPED(TransposeMatMul, BFloat16)

REGISTER_KERNEL_TYPED(FusedMatMul, float)
REGISTER_KERNEL_TYPED(FusedMatMul, double)
REGISTER_KERNEL_TYPED(FusedMatMul, MLFloat16)

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
REGISTER_KERNEL_TYPED(TransposeMatMul, BFloat16)
REGISTER_KERNEL_TYPED(FusedMatMul, BFloat16)
#endif

} // namespace cuda
} // namespace contrib
18 changes: 6 additions & 12 deletions onnxruntime/contrib_ops/cuda/math/isfinite.cuh
@@ -5,10 +5,6 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/math/isfinite.h"

#if CUDA_VERSION >= 11000
#include "cuda_bf16.h"
#endif

namespace onnxruntime {
namespace cuda {

@@ -54,22 +50,20 @@ __device__ __forceinline__ bool IsNaNScalar(const half value) {
#endif
}

#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__device__ __forceinline__ bool IsFiniteScalar(const nv_bfloat16 value) {
return !__hisinf(value) && !__hisnan(value);
__device__ __forceinline__ bool IsFiniteScalar(const BFloat16 value) {
return isfinite(static_cast<float>(value));
}

template <>
__device__ __forceinline__ bool IsInfScalar(const nv_bfloat16 value) {
return __hisinf(value);
__device__ __forceinline__ bool IsInfScalar(const BFloat16 value) {
return isinf(static_cast<float>(value));
}

template <>
__device__ __forceinline__ bool IsNaNScalar(const nv_bfloat16 value) {
return __hisnan(value);
__device__ __forceinline__ bool IsNaNScalar(const BFloat16 value) {
return isnan(static_cast<float>(value));
}
#endif

} // namespace cuda
} // namespace onnxruntime
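
A host-side analogue of the BFloat16 overloads above, working directly on the bit pattern (a sketch, not part of the ORT sources): a bfloat16 value is non-finite exactly when its 8 exponent bits are all ones.

// Sketch only, not ORT code: bfloat16 layout is 1 sign, 8 exponent, 7 mantissa bits.
#include <cstdint>

inline bool BF16IsFinite(uint16_t bits) { return (bits & 0x7F80u) != 0x7F80u; }
inline bool BF16IsInf(uint16_t bits) { return (bits & 0x7FFFu) == 0x7F80u; }  // exponent all ones, mantissa zero
inline bool BF16IsNaN(uint16_t bits) { return (bits & 0x7F80u) == 0x7F80u && (bits & 0x007Fu) != 0u; }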