Backport of #16711, #16737, #16408 to 1.6 branch (#16763)
* support mixed-precision true_divide (#16711)

* [MKLDNN] use dim_t instead of int in slice/transpose operators (#16737)

* use dim_t instead of int

* fix same issue in pooling

* rebase code

* trigger CI

* Add MXNet Ops for fast multihead attention (#16408)

* add MXNet Ops for fast multihead attention

* add cutlass as 3rdparty dependency

* add cutlass to compilation flags

* remove all cutlass stuff

* add better error message and description and remove cutlass from compilation flags

* change credit for the approach since the code has changed

* fix typos

* correct another typo

* Add all the cuda/cublas helper functions

* remove tests using kAddTo

* only use cublasStridedBatchedGemm if CUDA >= 9.1

* add equivalent mxnet code in description of mha ops

* remove an incorrect copy-paste

* add the _contrib namespace prefix and note GPU-only support in the description

* add a warning to the bwd_ignore_zero_init description; also test with fp32

* add error return if bwd_ignore_zero_init is used without MXNET_EXEC_ENABLE_ADDTO

* remove std::move for clang

* remove bwd_ignore_zero_init flag

* remove bwd_ignore_zero_init in test_operator_gpu.py

* fix typo

* fix another typo

* remove unrelated test
ptrendx authored Nov 8, 2019
1 parent 1aa1b5a commit b1aba6a
Showing 20 changed files with 1,779 additions and 78 deletions.
74 changes: 74 additions & 0 deletions src/common/cuda_utils.h
@@ -187,6 +187,69 @@ namespace mxnet {
namespace common {
/*! \brief common utils for cuda */
namespace cuda {
/*!
* \brief Converts between C++ datatypes and enums/constants needed by cuBLAS.
*/
template<typename DType>
struct CublasType;

// With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own
// datatype cublasDataType_t. The older cublasDataType_t values could be
// included below, but since this class was introduced to support the cuBLAS v8
// call cublasGemmEx(), burdening the class with the legacy type values
// was not needed.

template<>
struct CublasType<float> {
  static const int kFlag = mshadow::kFloat32;
#if CUDA_VERSION >= 8000
  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
#endif
  typedef float ScaleType;
  static const float one;
  static const float zero;
};
template<>
struct CublasType<double> {
  static const int kFlag = mshadow::kFloat64;
#if CUDA_VERSION >= 8000
  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
#endif
  typedef double ScaleType;
  static const double one;
  static const double zero;
};
template<>
struct CublasType<mshadow::half::half_t> {
  static const int kFlag = mshadow::kFloat16;
#if CUDA_VERSION >= 8000
  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
#endif
  typedef float ScaleType;
  static const mshadow::half::half_t one;
  static const mshadow::half::half_t zero;
};
template<>
struct CublasType<uint8_t> {
  static const int kFlag = mshadow::kUint8;
#if CUDA_VERSION >= 8000
  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
#endif
  typedef uint8_t ScaleType;
  static const uint8_t one = 1;
  static const uint8_t zero = 0;
};
template<>
struct CublasType<int32_t> {
  static const int kFlag = mshadow::kInt32;
#if CUDA_VERSION >= 8000
  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
#endif
  typedef int32_t ScaleType;
  static const int32_t one = 1;
  static const int32_t zero = 0;
};
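
// Illustration only (not part of this diff): a sketch of how generic code can
// use the CublasType<DType> traits above to drive cublasGemmEx() without
// switching on the dtype by hand. ExampleGemm is a hypothetical helper; it
// assumes a valid cublasHandle_t, column-major device buffers, and CUDA >= 9
// for CUBLAS_GEMM_DEFAULT.
template<typename DType>
cublasStatus_t ExampleGemm(cublasHandle_t handle, int m, int n, int k,
                           const DType* a, const DType* b, DType* c) {
  typedef typename CublasType<DType>::ScaleType ScaleT;
  const ScaleT alpha = 1;
  const ScaleT beta = 0;
  // kCudaFlag supplies the cudaDataType_t for each buffer; ScaleType picks a
  // compute type wide enough for accumulation (float for half inputs).
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha,
                      a, CublasType<DType>::kCudaFlag, m,
                      b, CublasType<DType>::kCudaFlag, k,
                      &beta,
                      c, CublasType<DType>::kCudaFlag, m,
                      CublasType<ScaleT>::kCudaFlag, CUBLAS_GEMM_DEFAULT);
}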

/*!
* \brief Get string representation of cuBLAS errors.
* \param error The error.
@@ -218,6 +281,17 @@ inline const char* CublasGetErrorString(cublasStatus_t error) {
  return "Unknown cuBLAS status";
}

#if CUDA_VERSION >= 8000
/*!
* \brief Create the proper constant for indicating cuBLAS transposition, if desired.
* \param transpose Whether transposition should be performed.
* \return the yes/no transposition-indicating constant.
*/
inline cublasOperation_t CublasTransposeOp(bool transpose) {
  return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
}
#endif
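
// Illustration only (not part of this diff): CublasTransposeOp converts a
// boolean transpose flag into the enum that cuBLAS GEMM routines expect.
// ExampleSgemm is a hypothetical helper; a valid handle and device buffers
// are assumed.
inline cublasStatus_t ExampleSgemm(cublasHandle_t handle, bool trans_a, bool trans_b,
                                   int m, int n, int k, const float* a, int lda,
                                   const float* b, int ldb, float* c, int ldc) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return cublasSgemm(handle,
                     CublasTransposeOp(trans_a),  // CUBLAS_OP_T or CUBLAS_OP_N
                     CublasTransposeOp(trans_b),
                     m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
}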

/*!
* \brief Get string representation of cuSOLVER errors.
* \param error The error.
36 changes: 36 additions & 0 deletions src/common/utils.h
@@ -842,6 +842,42 @@ inline bool is_float(const int dtype) {
  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
  if (type1 == type2) return type1;
  if (is_float(type1) && is_float(type2)) {
    if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
      return mshadow::kFloat64;
    }
    if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
      return mshadow::kFloat32;
    }
    return mshadow::kFloat16;
  } else if (is_float(type1) || is_float(type2)) {
    return is_float(type1) ? type1 : type2;
  }
  if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
    return mshadow::kInt64;
  }
  if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
    return mshadow::kInt32;
  }
  CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
          (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)))
    << "A UInt8/Int8 operand mix has no common promoted type and should not reach here";
  if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
    return mshadow::kUint8;
  }
  return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
  if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
      (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
    return mshadow::kInt32;
  }
  return more_precise_type(type1, type2);
}
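
// Illustration only (not part of this diff): the promotion rules above spelled
// out on concrete dtype pairs. PromotionExamples is a hypothetical helper.
inline void PromotionExamples() {
  using namespace mshadow;
  CHECK_EQ(more_precise_type(kFloat32, kFloat64), kFloat64);  // wider float wins
  CHECK_EQ(more_precise_type(kFloat16, kInt64), kFloat16);    // any float beats any int
  CHECK_EQ(more_precise_type(kInt32, kInt64), kInt64);        // wider int wins
  // Int8/UInt8 mixes are rejected by more_precise_type; np_binary_out_type
  // widens them to Int32, the next dtype here that can hold both ranges.
  CHECK_EQ(np_binary_out_type(kInt8, kUint8), kInt32);
}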

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
9 changes: 9 additions & 0 deletions src/operator/contrib/transformer-inl.h
@@ -34,6 +34,15 @@
namespace mxnet {
namespace op {

struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> {
  int heads;
  bool bwd_ignore_zero_init;
  DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) {
    DMLC_DECLARE_FIELD(heads)
    .describe("Set number of heads");
  }
};
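
// Illustration only (not part of this diff): how a dmlc::Parameter struct like
// InterleavedMatMulParam is typically filled from string kwargs at operator
// setup time. ExampleInitParam is a hypothetical helper; <vector>, <string>,
// and <utility> are assumed available.
inline InterleavedMatMulParam ExampleInitParam() {
  InterleavedMatMulParam param;
  std::vector<std::pair<std::string, std::string>> kwargs = {{"heads", "8"}};
  param.Init(kwargs);  // parses the registered "heads" field into param.heads
  CHECK_EQ(param.heads, 8);
  return param;
}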

template<typename xpu>
static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
(Diffs for the remaining 17 changed files are not shown.)
