Skip to content

Commit

Permalink
Merge pull request #283 from cmpfeil/gt_bfloat16_t
Browse files Browse the repository at this point in the history
Types gt::bfloat16_t and gt::complex_bfloat16_t
  • Loading branch information
germasch authored Nov 29, 2023
2 parents 81d59f7 + d0cfe6a commit 99f3fa4
Show file tree
Hide file tree
Showing 8 changed files with 1,498 additions and 5 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ option(GTENSOR_BOUNDS_CHECK "Enable per access bounds checking" OFF)
option(GTENSOR_ADDRESS_CHECK "Enable address checking for device spans" OFF)
option(GTENSOR_SYNC_KERNELS "Enable host sync after assign and launch kernels" OFF)
option(GTENSOR_ENABLE_FP16 "Enable 16-bit floating point type gt::float16_t" OFF)
option(GTENSOR_ENABLE_BF16 "Enable 16-bit floating point type gt::bfloat16_t" OFF)

if (GTENSOR_ENABLE_FORTRAN)
# do this early (here) since later the `enable_language(Fortran)` gives me trouble
Expand Down Expand Up @@ -343,6 +344,13 @@ if (GTENSOR_ENABLE_FP16)
INTERFACE GTENSOR_ENABLE_FP16)
endif()

# When bfloat16 support is requested, announce it and propagate the
# GTENSOR_ENABLE_BF16 compile definition to all consumers of the gtensor
# interface target, so that gtensor headers enable gt::bfloat16_t and
# gt::complex_bfloat16_t.
if (GTENSOR_ENABLE_BF16)
message(STATUS "${PROJECT_NAME}: gt::bfloat16_t is ENABLED")
message(STATUS "${PROJECT_NAME}: gt::complex_bfloat16_t is ENABLED")
target_compile_definitions(gtensor_${GTENSOR_DEVICE}
INTERFACE GTENSOR_ENABLE_BF16)
endif()


target_compile_definitions(gtensor_${GTENSOR_DEVICE} INTERFACE
GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT=${GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT})
Expand Down
227 changes: 227 additions & 0 deletions include/gtensor/bfloat16_t.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#ifndef GTENSOR_BFLOAT16T_H
#define GTENSOR_BFLOAT16T_H

#include <cmath>
#include <iostream>

#if __has_include(<cuda_bf16.h>)
#include <cuda_bf16.h>
#define GTENSOR_BF16_CUDA_HEADER
#elif 0 // TODO check if other bf16 type available
#else
#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!"
#endif

#include "macros.h"

namespace gt
{

// ======================================================================
// bfloat16_t

// 16-bit "brain floating point" scalar wrapper.
//
// Storage is __nv_bfloat16 (from <cuda_bf16.h>). Arithmetic is performed in
// compute_type: native bfloat16 on CUDA devices with compute capability >=
// 8.0, and float everywhere else (operands are widened, the op is done in
// float, and the result is rounded back on store).
class bfloat16_t
{

#if defined(GTENSOR_BF16_CUDA_HEADER)
  using storage_type = __nv_bfloat16;
#else
#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!"
#endif

#if defined(GTENSOR_BF16_CUDA_HEADER) && defined(__CUDA_ARCH__) &&             \
  (__CUDA_ARCH__ >= 800)
  using compute_type = __nv_bfloat16;
#define BFLOAT16T_ON_CUDA_DEVICE
#else
  using compute_type = float;
#endif

public:
  bfloat16_t() = default;
  // implicit conversions from float and from the raw storage type are
  // intentional, so mixed expressions and literals work transparently
  GT_INLINE bfloat16_t(float x) : x(x) {}
  GT_INLINE bfloat16_t(storage_type x) : x(x) {}

  // assignment from float; returns *this by (non-const) reference, per the
  // standard convention for assignment operators
  GT_INLINE bfloat16_t& operator=(const float f)
  {
    x = f;
    return *this;
  }

  // value widened to the arithmetic type (bfloat16 on sm_80+ device code,
  // float otherwise)
  GT_INLINE compute_type Get() const { return static_cast<compute_type>(x); }

  // update operators [+=, -=, *=, /=]; each returns *this by reference
  // (previously these returned a copy by value, which both broke the
  // conventional `(a += b)` aliasing semantics and forced an extra copy)
  GT_INLINE bfloat16_t& operator+=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x += y.Get();
#else
    x = this->Get() + y.Get();
#endif
    return *this;
  }
  GT_INLINE bfloat16_t& operator-=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x -= y.Get();
#else
    x = this->Get() - y.Get();
#endif
    return *this;
  }
  GT_INLINE bfloat16_t& operator*=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x *= y.Get();
#else
    x = this->Get() * y.Get();
#endif
    return *this;
  }
  GT_INLINE bfloat16_t& operator/=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x /= y.Get();
#else
    x = this->Get() / y.Get();
#endif
    return *this;
  }

private:
  storage_type x;
};

// op is unary [+, -]
// Unary arithmetic on bfloat16_t: the operand is widened to compute_type via
// Get() (native bfloat16 on sm_80+ device code, float on host), the operator
// is applied, and the result is converted back to bfloat16_t.
#define PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(op) \
GT_INLINE bfloat16_t operator op(const bfloat16_t& rhs) \
{ \
return bfloat16_t(op rhs.Get()); \
}

PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(+);
PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(-);

// op is binary [+, -, *, /]
// Homogeneous binary arithmetic: both operands are widened to compute_type,
// the operator is applied, and the result is rounded back into a bfloat16_t
// return value.
#define PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op) \
GT_INLINE bfloat16_t operator op(const bfloat16_t& lhs, \
const bfloat16_t& rhs) \
{ \
return bfloat16_t(lhs.Get() op rhs.Get()); \
}

PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/);

// op is binary [+, -, *, /]
// fp_type is [float, double]
// Mixed-type binary arithmetic: the bfloat16_t operand is converted to
// fp_type and the result is returned in fp_type (the wider type wins, as
// with builtin float/double promotion).
#define PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op, fp_type) \
\
GT_INLINE fp_type operator op(const bfloat16_t& lhs, const fp_type& rhs) \
{ \
return static_cast<fp_type>(lhs.Get()) op rhs; \
} \
\
GT_INLINE fp_type operator op(const fp_type& lhs, const bfloat16_t& rhs) \
{ \
return lhs op static_cast<fp_type>(rhs.Get()); \
}

PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, float);

PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, double);

// op is binary [==, !=, <, <=, >, >=]
// Homogeneous comparisons: both operands are widened to compute_type and the
// builtin comparison is applied.
#define PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(op) \
GT_INLINE bool operator op(const bfloat16_t& lhs, const bfloat16_t& rhs) \
{ \
return lhs.Get() op rhs.Get(); \
}

// op is binary [==, !=, <, <=, >, >=]
// fp_type is [float, double]
// Mixed comparisons: the bfloat16_t operand is converted to fp_type before
// comparing, so the comparison happens at fp_type precision.
#define PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(op, fp_type) \
\
GT_INLINE bool operator op(const bfloat16_t& lhs, const fp_type& rhs) \
{ \
return static_cast<fp_type>(lhs.Get()) op rhs; \
} \
\
GT_INLINE bool operator op(const fp_type& lhs, const bfloat16_t& rhs) \
{ \
return lhs op static_cast<fp_type>(rhs.Get()); \
}

PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(==);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(!=);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<=);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>=);

PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, float);

PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, double);

// op is [==, !=]
// int_type is [int]
// Integer comparisons: the integer operand is converted to float and the
// mixed bfloat16_t/float comparison above is then used. Only equality and
// inequality are provided for integers.
#define PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(op, int_type) \
\
GT_INLINE bool operator op(const bfloat16_t& lhs, const int_type& rhs) \
{ \
return lhs op static_cast<float>(rhs); \
} \
\
GT_INLINE bool operator op(const int_type& lhs, const bfloat16_t& rhs) \
{ \
return static_cast<float>(lhs) op rhs; \
}

PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(==, int);
PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(!=, int);

// function is sqrt
// Square root of a bfloat16_t. On sm_80+ CUDA device code this uses the
// native hsqrt intrinsic from <cuda_bf16.h>; otherwise the value is widened
// to float, std::sqrt is applied, and the result is converted back to
// bfloat16_t on return.
GT_INLINE bfloat16_t sqrt(const bfloat16_t& x)
{
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
return hsqrt(x.Get());
#else
return std::sqrt(x.Get());
#endif
}

// Stream output: prints the value converted to float.
// `inline` is required: this is a function *definition* in a header, and
// without it every translation unit that includes this header emits its own
// external definition, causing multiple-definition linker errors (ODR
// violation). (GT_INLINE is not used because host/device annotations do not
// apply to iostream code.)
inline std::ostream& operator<<(std::ostream& s, const bfloat16_t& h)
{
  s << static_cast<float>(h.Get());
  return s;
}

} // namespace gt

#undef GTENSOR_BF16_CUDA_HEADER
#undef BFLOAT16T_ON_CUDA_DEVICE
#undef PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR
#undef PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_BFLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR

#endif // GTENSOR_BFLOAT16T_H
Loading

0 comments on commit 99f3fa4

Please sign in to comment.