diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index 2f7254040475..ccf0931f2480 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -187,6 +187,69 @@ namespace mxnet { namespace common { /*! \brief common utils for cuda */ namespace cuda { +/*! + * \brief Converts between C++ datatypes and enums/constants needed by cuBLAS. + */ +template +struct CublasType; + +// With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own +// datatype cublasDataType_t. The older cudaDataType_t values could be +// included below, but since this class was introduced to support the cuBLAS v8 +// call cublasGemmEx(), burdening the class with the legacy type values +// was not needed. + +template<> +struct CublasType { + static const int kFlag = mshadow::kFloat32; +#if CUDA_VERSION >= 8000 + static const cudaDataType_t kCudaFlag = CUDA_R_32F; +#endif + typedef float ScaleType; + static const float one; + static const float zero; +}; +template<> +struct CublasType { + static const int kFlag = mshadow::kFloat64; +#if CUDA_VERSION >= 8000 + static const cudaDataType_t kCudaFlag = CUDA_R_64F; +#endif + typedef double ScaleType; + static const double one; + static const double zero; +}; +template<> +struct CublasType { + static const int kFlag = mshadow::kFloat16; +#if CUDA_VERSION >= 8000 + static const cudaDataType_t kCudaFlag = CUDA_R_16F; +#endif + typedef float ScaleType; + static const mshadow::half::half_t one; + static const mshadow::half::half_t zero; +}; +template<> +struct CublasType { + static const int kFlag = mshadow::kUint8; +#if CUDA_VERSION >= 8000 + static const cudaDataType_t kCudaFlag = CUDA_R_8I; +#endif + typedef uint8_t ScaleType; + static const uint8_t one = 1; + static const uint8_t zero = 0; +}; +template<> +struct CublasType { + static const int kFlag = mshadow::kInt32; +#if CUDA_VERSION >= 8000 + static const cudaDataType_t kCudaFlag = CUDA_R_32I; +#endif + typedef int32_t ScaleType; + static const int32_t one = 1; + static const int32_t zero = 0; +}; + /*! * \brief Get string representation of cuBLAS errors. * \param error The error. @@ -218,6 +281,17 @@ inline const char* CublasGetErrorString(cublasStatus_t error) { return "Unknown cuBLAS status"; } +#if CUDA_VERSION >= 8000 +/*! + * \brief Create the proper constant for indicating cuBLAS transposition, if desired. + * \param transpose Whether transposition should be performed. + * \return the yes/no transposition-indicating constant. + */ +inline cublasOperation_t CublasTransposeOp(bool transpose) { + return transpose ? CUBLAS_OP_T : CUBLAS_OP_N; +} +#endif + /*! * \brief Get string representation of cuSOLVER errors. * \param error The error. diff --git a/src/common/utils.h b/src/common/utils.h index 2b4b821a1835..b919cb301dff 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -842,6 +842,42 @@ inline bool is_float(const int dtype) { return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16; } +inline int more_precise_type(const int type1, const int type2) { + if (type1 == type2) return type1; + if (is_float(type1) && is_float(type2)) { + if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) { + return mshadow::kFloat64; + } + if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) { + return mshadow::kFloat32; + } + return mshadow::kFloat16; + } else if (is_float(type1) || is_float(type2)) { + return is_float(type1) ? 
type1 : type2; + } + if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) { + return mshadow::kInt64; + } + if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) { + return mshadow::kInt32; + } + CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || + (type1 == mshadow::kInt8 && type2 == mshadow::kUint8))) + << "1 is UInt8 and 1 is Int8 should not get here"; + if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) { + return mshadow::kUint8; + } + return mshadow::kInt8; +} + +inline int np_binary_out_type(const int type1, const int type2) { + if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || + (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) { + return mshadow::kInt32; + } + return more_precise_type(type1, type2); +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/operator/contrib/transformer-inl.h b/src/operator/contrib/transformer-inl.h index da3d14e33cf4..da48ffa52dca 100644 --- a/src/operator/contrib/transformer-inl.h +++ b/src/operator/contrib/transformer-inl.h @@ -34,6 +34,15 @@ namespace mxnet { namespace op { +struct InterleavedMatMulParam : public dmlc::Parameter { + int heads; + bool bwd_ignore_zero_init; + DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) { + DMLC_DECLARE_FIELD(heads) + .describe("Set number of heads"); + } +}; + template static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc index 00085c0dc7aa..2ca6f8c71093 100644 --- a/src/operator/contrib/transformer.cc +++ b/src/operator/contrib/transformer.cc @@ -29,6 +29,276 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(InterleavedMatMulParam); + +static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently have, " + << in_shape->size() << " inputs"; + auto qkv_shape = in_shape->at(0); + CHECK_EQ(qkv_shape.ndim(), 3U) + << "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, " + << "currently is: " << qkv_shape.ndim() << "D"; + out_shape->resize(1); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]})); + return true; +} + +static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently have, " + << in_shape->size() << " inputs"; + auto qkv_shape = in_shape->at(0); + auto att_shape = in_shape->at(1); + CHECK_EQ(qkv_shape.ndim(), 3U) + << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, " + << "currently is: " << qkv_shape.ndim() << "D"; + CHECK_EQ(att_shape.ndim(), 3U) + << "Input attention should be 3D in batch-seq_length-seq_length, " + << "currently is: " << att_shape.ndim() << "D"; + CHECK_EQ(qkv_shape[0], att_shape[1]) + << "queries_keys_values.shape[0] and attention.shape[1] should be the same, " + << "currently are " << qkv_shape[0] << " and " << att_shape[1]; + CHECK_EQ(qkv_shape[0], att_shape[2]) + << "queries_keys_values.shape[0] and attention.shape[2] should be the same, " + << "currently are " << qkv_shape[0] << " and " << att_shape[2]; + CHECK_EQ(qkv_shape[2] % 3, 0) + << "queries_keys_values.shape[2] should be a multiple of 3, " + << "currently is " << 
qkv_shape[2];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
+  return true;
+}
+
+static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs,
+                                           mxnet::ShapeVector* in_shape,
+                                           mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [queries, keys_values], currently have "
+                                 << in_shape->size() << " inputs";
+  auto q_shape = in_shape->at(0);
+  auto kv_shape = in_shape->at(1);
+  CHECK_EQ(q_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-proj_dim, "
+                               << "currently is " << q_shape.ndim() << "D";
+  CHECK_EQ(kv_shape.ndim(), 3U) << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+                                << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(q_shape[2] * 2, kv_shape[2])
+    << "keys_values.shape[2] should be equal to queries.shape[2] * 2, "
+    << "currently are: " << kv_shape[2] << " and " << q_shape[2];
+  CHECK_EQ(q_shape[1], kv_shape[1])
+    << "queries.shape[1] should be equal to keys_values.shape[1], "
+    << "currently are: " << q_shape[1] << " and " << kv_shape[1];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]}));
+  return true;
+}
+
+static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
+                                               mxnet::ShapeVector* in_shape,
+                                               mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], currently have "
+                                 << in_shape->size() << " inputs";
+  auto kv_shape = in_shape->at(0);
+  auto att_shape = in_shape->at(1);
+  CHECK_EQ(kv_shape.ndim(), 3U)
+    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+    << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(att_shape.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is " << att_shape.ndim() << "D";
+  CHECK_EQ(kv_shape[0], att_shape[2])
+    << "keys_values.shape[0] should be equal to attention.shape[2], currently are "
+    << kv_shape[0] << " and " << att_shape[2];
+  CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]) << "attention.shape[0] "
+    << "should be equal to keys_values.shape[1] * heads, currently are: "
+    << att_shape[0] << " and " << kv_shape[1] * params.heads;
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2}));
+  return true;
+}
+
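As a quick illustration of what these shape functions compute, the inferred output shape of the self-attention QK op can be checked through the symbolic API, which should not require a GPU since nothing is executed; the sizes below (seq_length 10, batch 2, 8 heads of dimension 64) are illustrative and not part of the patch:

```python
import mxnet as mx

# Hypothetical sizes: seq_length=10, batch=2, heads=8, head_dim=64.
qkv = mx.sym.Variable('qkv')
att = mx.sym.contrib.interleaved_matmul_selfatt_qk(qkv, heads=8)

# Feed the interleaved QKV layout (seq_length, batch, heads * head_dim * 3).
_, out_shapes, _ = att.infer_shape(qkv=(10, 2, 8 * 64 * 3))
print(out_shapes)  # [(16, 10, 10)] i.e. (batch * heads, seq_length, seq_length)
```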
+NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
+.describe(R"code(Compute the matrix multiplication between the projections of
+queries and keys in multihead attention used as self-attention.
+
+The input must be a single tensor of interleaved projections
+of queries, keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 3)
+
+the equivalent code would be:
+tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
+q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
+q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
+q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
+k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
+k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
+output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
+
+This Op is GPU only
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"queries_keys_values"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttQKShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+                           ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"})
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
+.add_arguments(InterleavedMatMulParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+
+NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt)
+.describe(R"code(Compute the matrix multiplication between the projections of
+values and the attention weights in multihead attention used as self-attention.
+
+The inputs must be a tensor of interleaved projections
+of queries, keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 3)
+
+and the attention weights following the layout:
+(batch_size, seq_length, seq_length)
+
+the equivalent code would be:
+tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
+v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
+v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
+output = mx.nd.batch_dot(attention, v_proj, transpose_b=True)
+output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
+output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
+output = mx.nd.reshape(output, shape=(0, 0, -1))
+
+This Op is GPU only
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"queries_keys_values", "attention"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttValAttShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+                           ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"})
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(InterleavedMatMulParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+
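For context, a minimal end-to-end sketch of the two fused self-attention ops registered above, mirroring the mx.nd reference code in their doc strings; the sizes, the mx.gpu(0) context and the softmax step are illustrative assumptions rather than part of this patch (the ops themselves are GPU only):

```python
import mxnet as mx

seq_len, batch, heads, head_dim = 10, 2, 8, 64   # illustrative sizes
ctx = mx.gpu(0)                                  # the fused ops are GPU only

# Interleaved Q/K/V projections: (seq_length, batch_size, heads * head_dim * 3)
qkv = mx.nd.random.uniform(shape=(seq_len, batch, heads * head_dim * 3), ctx=ctx)

# Scaled QK^T scores: (batch * heads, seq_length, seq_length)
scores = mx.nd.contrib.interleaved_matmul_selfatt_qk(qkv, heads=heads)
weights = mx.nd.softmax(scores, axis=-1)

# Attention-weighted values: (seq_length, batch_size, heads * head_dim)
out = mx.nd.contrib.interleaved_matmul_selfatt_valatt(qkv, weights, heads=heads)
print(scores.shape, out.shape)
```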
+NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk)
+.describe(R"code(Compute the matrix multiplication between the projections of
+queries and keys in multihead attention used as encoder-decoder attention.
+
+The inputs must be a tensor of projections of queries following the layout:
+(seq_length, batch_size, num_heads * head_dim)
+
+and a tensor of interleaved projections of keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 2)
+
+the equivalent code would be:
+q_proj = mx.nd.transpose(queries, axes=(1, 2, 0, 3))
+q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
+q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
+tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
+k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
+k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
+output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
+
+This Op is GPU only
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"queries", "keys_values"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecQKShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+                           ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"})
+.add_argument("queries", "NDArray-or-Symbol", "Queries")
+.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved")
+.add_arguments(InterleavedMatMulParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+
+NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt)
+.describe(R"code(Compute the matrix multiplication between the projections of
+values and the attention weights in multihead attention used as encoder-decoder attention.
+ +the inputs must be a tensor of interleaved projections of +keys and values following the layout: +(seq_length, batch_size, num_heads * head_dim * 2) + +and the attention weights following the layout: +(batch_size, seq_length, seq_length) + +the equivalent code would be: + +tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1)) +v_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3)) +v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True) +output = mx.nd.batch_dot(attention, v_proj, transpose_b=True) +output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True) +output = mx.nd.transpose(output, axes=(0, 2, 1, 3)) +output = mx.nd.reshape(output, shape=(0, 0, -1)) + +This Op is GPU only +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"keys_values", "attention"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FInferShape", InterleavedMatMulEncDecValAttShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FGradient", + ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"}) +.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved") +.add_argument("attention", "NDArray-or-Symbol", "Attention maps") +.add_arguments(InterleavedMatMulParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser); + + // relu MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim) .describe(R"code(Rescale the input by the square root of the channel dimension. diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu index 6ed073db6011..e152669478dd 100644 --- a/src/operator/contrib/transformer.cu +++ b/src/operator/contrib/transformer.cu @@ -22,12 +22,572 @@ * \file transformer.cu * \brief GPU implementation of the operators used in Transformer */ + +#include +#include +#include +#include + #include #include "./transformer-inl.h" +#include "../../common/cuda_utils.h" namespace mxnet { namespace op { +// Approach in gemm_switch_fp32accum is coming from MLPerf v0.6 submission repository from NVIDIA +// by https://github.com/kevinstephano +template +void CublasStridedBatchedGemm(mshadow::Stream* s, bool transA, bool transB, + int32_t m, int32_t n, int32_t k, + float alpha, const DType* a, int32_t lda, int32_t strideA, + const DType *b, int32_t ldb, int32_t strideB, float beta, + DType *c, int32_t ldc, int32_t strideC, int32_t batchCount, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) { +#if CUDA_VERSION >= 9010 + using namespace mxnet::common::cuda; + CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; + + cublasHandle_t blas_handle = mshadow::Stream::GetBlasHandle(s); + auto err = CUBLAS_STATUS_SUCCESS; + // TODO(cfujitsang): handle computation_precision + err = cublasGemmStridedBatchedEx( + blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB), + static_cast(m), static_cast(n), static_cast(k), + reinterpret_cast(&alpha), + a, CublasType::kCudaFlag, static_cast(lda), strideA, + b, CublasType::kCudaFlag, static_cast(ldb), strideB, + reinterpret_cast(&beta), + c, CublasType::kCudaFlag, static_cast(ldc), strideC, + static_cast(batchCount), CUDA_R_32F, algo); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx 
fail."; +#else + LOG(FATAL) << "Not implemented with CUDA < 9.1"; +#endif +} + +template +void gemm_switch_fp32accum(mshadow::Stream* s, bool transA, bool transB, + int32_t m, int32_t n, int32_t k, + float alpha, const DType *a, int32_t lda, + int32_t strideA, const DType *b, int32_t ldb, + int32_t strideB, float beta, DType *c, int32_t ldc, + int32_t strideC, int32_t batchCount) { + cudaStream_t stream = mshadow::Stream::GetStream(s); + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); + } else { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount); + } + CHECK_CUDA_ERROR("Error at InterleavedMatMul"); +} + +// TODO(cfujitsang): use scale as optional ? +void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_dim = inputs[0].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + true, + false, + qkv_seq_len, + qkv_seq_len, + head_dim, + scale, + queries_keys_values + head_dim, + lead_dim, + batch_stride, + queries_keys_values, + lead_dim, + batch_stride, + beta, + output, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + attn_batches); + }) +} + +void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* queries_keys_values = inputs[1].FlatTo2D(s).dptr_; + DType* queries_keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_dim = inputs[1].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + const float beta = req[0] == kAddTo ? 
1.f : 0.f; + + if (req[0] == kNullOp) + return; + + if (req[0] == kWriteTo) { + cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + + gemm_switch_fp32accum(s, + false, + false, + head_dim, + qkv_seq_len, + qkv_seq_len, + scale, + queries_keys_values + head_dim, + lead_dim, + batch_stride, + output_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads, + lead_dim, + batch_stride, + attn_batches); + gemm_switch_fp32accum(s, + false, + true, + head_dim, + qkv_seq_len, + qkv_seq_len, + scale, + queries_keys_values, + lead_dim, + batch_stride, + output_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads + head_dim, + lead_dim, + batch_stride, + attn_batches); + }) +} + +void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[1].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_dim = inputs[0].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float alpha = 1.f; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + false, + false, + head_dim, + qkv_seq_len, + qkv_seq_len, + alpha, + queries_keys_values + 2 * head_dim, + lead_dim, + batch_stride, + attention_maps, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + output, + head_dim * attn_batches, + head_dim, + attn_batches); + }) +} + +void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* queries_keys_values = inputs[1].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[2].FlatTo2D(s).dptr_; + DType* queries_keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + DType* attention_maps_grads = outputs[1].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_dim = inputs[1].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float alpha = 1.f; + if (req[0] != kNullOp) { + if (req[0] == kWriteTo) { + cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + const float beta = req[0] == kAddTo ? 
1.f : 0.f; + gemm_switch_fp32accum(s, + false, + true, + head_dim, + qkv_seq_len, + qkv_seq_len, + alpha, + output_grads, + head_dim * attn_batches, + head_dim, + attention_maps, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads + 2 * head_dim, + lead_dim, + batch_stride, + attn_batches); + } + if (req[1] != kNullOp) { + const float beta = req[1] == kAddTo ? 1.f : 0.f; + gemm_switch_fp32accum(s, + true, + false, + qkv_seq_len, + qkv_seq_len, + head_dim, + alpha, + queries_keys_values + 2 * head_dim, + lead_dim, + batch_stride, + output_grads, + head_dim * attn_batches, + head_dim, + beta, + attention_maps_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + attn_batches); + } + }) +} + + +void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* queries = inputs[0].FlatTo2D(s).dptr_; + const DType* keys_values = inputs[1].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t q_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_q_dim = inputs[0].shape_[2]; + const int32_t kv_seq_len = inputs[1].shape_[0]; + const int32_t output_lin_kv_dim = inputs[1].shape_[2]; + const int32_t embed_dim = output_lin_q_dim; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim_q = attn_batches * head_dim; + const int32_t lead_dim_kv = attn_batches * 2 * head_dim; + const int32_t batch_stride_q = head_dim; + const int32_t batch_stride_kv = head_dim * 2; + const float beta = req[0] == kAddTo ? 
1.f : 0.f; + const float scale = 1.f / sqrt(static_cast(head_dim)); + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + true, + false, + kv_seq_len, + q_seq_len, + head_dim, + scale, + keys_values, + lead_dim_kv, + batch_stride_kv, + queries, + lead_dim_q, + batch_stride_q, + beta, + output, + kv_seq_len, + kv_seq_len * q_seq_len, + attn_batches); + }) +} + +void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* queries = inputs[1].FlatTo2D(s).dptr_; + const DType* keys_values = inputs[2].FlatTo2D(s).dptr_; + DType* queries_grads = outputs[0].FlatTo2D(s).dptr_; + DType* keys_values_grads = outputs[1].FlatTo2D(s).dptr_; + const int32_t q_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_q_dim = inputs[1].shape_[2]; + const int32_t kv_seq_len = inputs[2].shape_[0]; + const int32_t output_lin_kv_dim = inputs[2].shape_[2]; + const int32_t embed_dim = output_lin_q_dim; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim_q = attn_batches * head_dim; + const int32_t lead_dim_kv = attn_batches * 2 * head_dim; + const int32_t batch_stride_q = head_dim; + const int32_t batch_stride_kv = head_dim * 2; + const float scale = 1.f / sqrt(static_cast(head_dim)); + + if (req[0] != kNullOp) { + const float beta = req[0] == kAddTo ? 1.f : 0.f; + gemm_switch_fp32accum(s, + false, + false, + head_dim, + q_seq_len, + kv_seq_len, + scale, + keys_values, + lead_dim_kv, + batch_stride_kv, + output_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + queries_grads, + lead_dim_q, + batch_stride_q, + attn_batches); + } + if (req[1] != kNullOp) { + if (req[1] == kWriteTo) { + cudaMemsetAsync(keys_values_grads, 0, outputs[1].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + const float beta = req[1] == kAddTo ? 
1.f : 0.f; + gemm_switch_fp32accum(s, + false, + true, + head_dim, + kv_seq_len, + q_seq_len, + scale, + queries, + lead_dim_q, + batch_stride_q, + output_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + keys_values_grads, + lead_dim_kv, + batch_stride_kv, + attn_batches); + } + }) +} + +void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* keys_values = inputs[0].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[1].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t kv_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_kv_dim = inputs[0].shape_[2]; + const int32_t attn_batches = inputs[1].shape_[0]; + const int32_t q_seq_len = inputs[1].shape_[1]; + const int32_t embed_dim = output_lin_kv_dim / 2; + int32_t head_dim = embed_dim / params.heads; + const int32_t lead_dim_kv = attn_batches * head_dim * 2; + const int32_t batch_stride_kv = 2 * head_dim; + const float alpha = 1.f; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + false, + false, + head_dim, + q_seq_len, + kv_seq_len, + alpha, + keys_values + head_dim, + lead_dim_kv, + batch_stride_kv, + attention_maps, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + output, + head_dim * attn_batches, + head_dim, + attn_batches); + }) +} + +void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* keys_values = inputs[1].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[2].FlatTo2D(s).dptr_; + DType* keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + DType* attention_maps_grads = outputs[1].FlatTo2D(s).dptr_; + const int32_t kv_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_kv_dim = inputs[1].shape_[2]; + const int32_t attn_batches = inputs[2].shape_[0]; + const int32_t q_seq_len = inputs[2].shape_[1]; + const int32_t embed_dim = output_lin_kv_dim / 2; + int32_t head_dim = embed_dim / params.heads; + const int32_t lead_dim_kv = attn_batches * head_dim * 2; + const int32_t batch_stride_kv = 2 * head_dim; + const float alpha = 1.f; + + if (req[0] != kNullOp) { + if (req[0] == kWriteTo) { + cudaMemsetAsync(keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + const float beta = req[0] == kAddTo ? 1.f : 0.f; + gemm_switch_fp32accum(s, + false, + true, + head_dim, + kv_seq_len, + q_seq_len, + alpha, + output_grads, + head_dim * attn_batches, + head_dim, + attention_maps, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + keys_values_grads + head_dim, + lead_dim_kv, + batch_stride_kv, + attn_batches); + } + if (req[1] != kNullOp) { + const float beta = req[1] == kAddTo ? 
1.f : 0.f; + gemm_switch_fp32accum(s, + true, + false, + kv_seq_len, + q_seq_len, + head_dim, + alpha, + keys_values + head_dim, + lead_dim_kv, + batch_stride_kv, + output_grads, + head_dim * attn_batches, + head_dim, + beta, + attention_maps_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + attn_batches); + } + }) +} + +NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk) +.set_attr("FCompute", InterleavedMatMulSelfAttQKGPU); + +NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt) +.set_attr("FCompute", InterleavedMatMulSelfAttValAttGPU); + +NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk) +.set_attr("FCompute", InterleavedMatMulEncDecQKGPU); + +NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt) +.set_attr("FCompute", InterleavedMatMulEncDecValAttGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk) +.set_attr("FCompute", BackwardInterleavedMatMulSelfAttQKGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt) +.set_attr("FCompute", BackwardInterleavedMatMulSelfAttValAttGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk) +.set_attr("FCompute", BackwardInterleavedMatMulEncDecQKGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt) +.set_attr("FCompute", BackwardInterleavedMatMulEncDecValAttGPU); + // relu NNVM_REGISTER_OP(_contrib_div_sqrt_dim) .set_attr("FCompute", DivSqrtDimForward_); diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index d73fa1be54a4..3d81cfc0d967 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -134,8 +134,7 @@ class LeakyReLUOp : public Operator { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape, in_data[leakyrelu::kData].dptr(), in_data[leakyrelu::kGamma].dptr(), out_data[leakyrelu::kOut].dptr()); diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index c5a2b1308c73..1ece97b0efd8 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -132,6 +132,26 @@ struct true_divide : public mxnet_op::tunable { MSHADOW_XINLINE static float Map(DType a, DType b) { return static_cast(a) / static_cast(b); } + +#ifndef _WIN32 + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast(a) / b; + } + + template::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast(a) / b; + } + + template::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast(a) / b; + } +#endif }; struct rtrue_divide : public mxnet_op::tunable { @@ -146,6 +166,26 @@ struct rtrue_divide : public mxnet_op::tunable { MSHADOW_XINLINE static float Map(DType a, DType b) { return static_cast(b) / static_cast(a); } + +#ifndef _WIN32 + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return b / static_cast(a); + } + + template::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return b / static_cast(a); + } + + template::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return b / static_cast(a); + } +#endif }; MXNET_BINARY_MATH_OP_NC(left, a); diff --git a/src/operator/mxnet_op.h 
b/src/operator/mxnet_op.h index 91478660a123..5d297a547c8f 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -471,6 +471,69 @@ struct AccType { {__VA_ARGS__} \ } \ break; \ + case mshadow::kBool: \ + { \ + typedef bool DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...)\ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + LOG(FATAL) << "This operation only support " \ + "integer types, not float32"; \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + LOG(FATAL) << "This operation only support " \ + "integer types, not float64"; \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + LOG(FATAL) << "This operation only support " \ + "integer types, not float16"; \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + LOG(FATAL) << "This operation only support " \ + "integer types, not uint8"; \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + LOG(FATAL) << "This operation only support " \ + "integer types, not int8"; \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kBool: \ + { \ + LOG(FATAL) << "This operation only support " \ + "integer types, not bool"; \ + } \ + break; \ default: \ LOG(FATAL) << "Unknown type enum " << type; \ } @@ -783,6 +846,56 @@ struct op_with_req { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value)); } +#ifndef _WIN32 + /*! \brief inputs are two tensors with a half_t output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, + mshadow::half::half_t *out, + const DType *lhs, + const mshadow::half::half_t *rhs) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); + } + + /*! \brief inputs are two tensors with a float output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); + } + + /*! \brief inputs are two tensors with a double output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); + } + + /*! \brief inputs are two tensors with a half_t output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, + mshadow::half::half_t *out, + const DType *lhs, + const mshadow::half::half_t value) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); + } + + /*! \brief inputs are two tensors with a float output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); + } + + /*! \brief inputs are two tensors with a double output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); + } +#endif + /*! 
\brief inputs are two tensors with a float output tensor */ template::value, int>::type = 0> diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 61239d33800c..1eff5cd8591d 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -394,8 +394,7 @@ class DropoutOp { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[dropout::kOut], lstride, rstride, oshape, in.dptr(), @@ -463,8 +462,7 @@ class DropoutOp { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, grad.dptr(), mask.dptr(), gdata.dptr()); }); diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index f9dbe5bbfd8f..6eda2aa33b34 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -127,7 +127,7 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { } } -static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) { +static inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) { if ((x + padl + padr - k) % s != 0) { return (padr + s - ((x + padl + padr - k) % s)); } else { diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index dba10f8b6cd5..575554a25c88 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -41,7 +41,7 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, mkldnn::memory::dims dims(N); mkldnn::memory::dims offsets(N); for (int i = 0; i < N; ++i) { - int s = 0; + dim_t s = 0; if (i < param.begin.ndim() && param.begin[i]) { s = *param.begin[i]; if (s < 0) s += ishape[i]; diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 2ec38d586552..ee9c06d49744 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -73,7 +73,7 @@ class MKLDNNTransposeForward { mkldnn_dims_t strides; mkldnn_dims_t sh; - unsigned int total_stride = 1; + dim_t total_stride = 1; for (int i = data_ndim - 1; i >= 0; i--) { sh[i] = shape[i]; strides[axes[i]] = total_stride; diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 601a0526650c..89da570c133b 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -790,7 +790,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, << "Mask needs to be provided when using softmax with use_length=True."; type = inputs[1].type_flag_; } - MXNET_INT_TYPE_SWITCH(type, IType, { + MXNET_INT32_INT64_TYPE_SWITCH(type, IType, { IType* mask_ptr = nullptr; if (param.use_length.value()) { mask_ptr = inputs[1].dptr(); @@ -834,7 +834,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mxnet_op; if (softmax_use_length(attrs)) { - MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, { + MXNET_INT32_INT64_TYPE_SWITCH(inputs[2].type_flag_, IType, { if (req[1] != kNullOp) { mxnet_op::Kernel::Launch( ctx.get_stream(), outputs[1].Size(), outputs[1].dptr()); @@ -856,7 +856,7 @@ void 
SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MXNET_INT_TYPE_SWITCH(itype, IType, { + MXNET_INT32_INT64_TYPE_SWITCH(itype, IType, { IType * length_ptr = nullptr; if (softmax_use_length(attrs)) { length_ptr = inputs[2].dptr(); diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h index cc74e19aef8f..0bc60a08803e 100644 --- a/src/operator/numpy/np_true_divide-inl.h +++ b/src/operator/numpy/np_true_divide-inl.h @@ -43,30 +43,42 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, CHECK_EQ(outputs.size(), 1U); if (req[0] == kNullOp || outputs[0].Size() == 0U) return; using namespace mshadow; + using namespace mxnet_op; using namespace mshadow::expr; Stream *s = ctx.get_stream(); const double alpha = nnvm::get(attrs.parsed); - if (common::is_float(inputs[0].type_flag_)) { + const TBlob& data = inputs[0]; + const TBlob& out = outputs[0]; + if (out.type_flag_ == data.type_flag_) { MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + Kernel, xpu>::Launch( + s, data.Size(), out.dptr(), data.dptr(), DType(alpha)); }); }); } else { +#ifndef _WIN32 CHECK_EQ(outputs[0].type_flag_, kFloat32) << "true_divide only supports float32 output " "when input's dtype is " << type_string(inputs[0].type_flag_); MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + Kernel, xpu>::Launch( + s, data.Size(), out.dptr(), data.dptr(), + static_cast(alpha)); }); }); +#else + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(data.Size()), s); + TBlob temp_tblob(temp_tensor); + CastCompute(attrs, ctx, {data}, {kWriteTo}, {temp_tblob}); + TrueDivideScalarCompute(attrs, ctx, {temp_tblob}, req, outputs); +#endif } } -template +template void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, @@ -77,66 +89,254 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, Stream *s = ctx.get_stream(); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - if (common::is_float(inputs[0].type_flag_)) { - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Kernel, xpu>::Launch(s, outputs[0].Size(), - outputs[0].dptr(), - inputs[0].dptr(), - inputs[1].dptr()); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + if (lhs.type_flag_ == rhs.type_flag_) { + // Case when types of the 2 input tensors are the same + if (common::is_float(lhs.type_flag_)) { + // If both are the same floats, normal launch + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, { + Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), lhs.dptr(), rhs.dptr()); + }); }); } else { - CHECK_EQ(outputs[0].type_flag_, kFloat32) << "true_divide only supports float32 output " - "when input's dtype is " - << type_string(inputs[0].type_flag_); - MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Kernel, xpu>::Launch(s, outputs[0].Size(), - outputs[0].dptr(), - inputs[0].dptr(), - inputs[1].dptr()); + // If both are the same integers, output is float32 + 
CHECK_EQ(out.type_flag_, kFloat32) << "true_divide only supports float32 output " + "when input's dtype is " + << type_string(lhs.type_flag_); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { + Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), lhs.dptr(), rhs.dptr()); + }); }); } - }); + } else { +#ifndef _WIN32 + // Non-windows case: no usage of temporary space + // Case when types of the 2 input tensors are different + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + // both lhs and rhs are float types, output type is the more precise one + LOG(ERROR) << "not implemented yet..."; + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + // one is float type, the other is integer type, the output type should be the same as float + CHECK_EQ(out.type_flag_, + common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_) + << "This case out type should be same as the float type"; + if (common::is_float(lhs.type_flag_)) { + // lhs is the float one + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), rhs.dptr(), lhs.dptr()); + }); + }); + }); + } else { + // rhs is the float one + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), lhs.dptr(), rhs.dptr()); + }); + }); + }); + } + } else { + // lhs is integer type, rhs is integer type, output type should be float + LOG(ERROR) << "not implemented yet..."; + } +#else + // Windows case: using temp space for casting the type + // Case when types of the 2 input tensors are different + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + // both lhs and rhs are float types, output type is the more precise one + LOG(ERROR) << "not implemented yet..."; + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + // lhs is float type, rhs is integer type, the output type should be the same as lhs + CHECK_EQ(out.type_flag_, + common::is_float(lhs.type_flag_) ? 
lhs.type_flag_ : rhs.type_flag_) + << "This case out type should be same as the float type"; + TBlob temp_tblob; + if (common::is_float(lhs.type_flag_)) { + // lhs is the float one + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + TrueDivideElemwiseCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + // rhs is the float one + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + TrueDivideElemwiseCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + } else { + // lhs is integer type, rhs is integer type, output type should be float + LOG(ERROR) << "not implemented yet..."; + } +#endif + } } -template +template void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + using namespace mxnet_op; if (outputs[0].shape_.Size() == 0U) return; + CHECK_EQ(inputs.size(), 2U); mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape); if (!ndim) { - TrueDivideElemwiseCompute(attrs, ctx, inputs, req, outputs); + TrueDivideElemwiseCompute(attrs, ctx, inputs, req, outputs); } else { if (req[0] == kNullOp) return; mshadow::Stream *s = ctx.get_stream(); + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; +#ifndef _WIN32 BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - if (common::is_float(inputs[0].type_flag_)) { - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), - outputs[0].dptr()); - }); - } else { - CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32) + mshadow::Shape lstride = calc_stride(new_lshape.get()); + mshadow::Shape rstride = calc_stride(new_rshape.get()); + if (lhs.type_flag_ == rhs.type_flag_) { + // When the both inputs have the same data types + if (common::is_float(lhs.type_flag_)) { + // If both inputs are the same float types, output is the same float type + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, { + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + } else { + CHECK_EQ(out.type_flag_, mshadow::kFloat32) << "true_divide only supports float32 output when input's dtype is " - << type_string(inputs[0].type_flag_); - MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, { - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), - outputs[0].dptr()); - }); + << type_string(lhs.type_flag_); + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { + // If both inputs are the same integer types, output is float type + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + } + } else { + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + // lhs and rhs have different float types, the output is the more precise one + LOG(ERROR) << "not implemented yet..."; + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + // one of lhs and rhs is float, the output is the same type as the float one + if (common::is_float(lhs.type_flag_)) { + // lhs is float type, output will be the same float type + CHECK_EQ(lhs.type_flag_, out.type_flag_) + << "lhs should have the same type as out, infer type broken?"; + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape, + rhs.dptr(), lhs.dptr(), out.dptr()); + }); + }); + } else { + // rhs is float type, output will be the same float type + CHECK_EQ(rhs.type_flag_, out.type_flag_) + << "rhs should have the same type as out, infer type broken?"; + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + }); + } + } else { + // lhs and rhs have different integer types, the output is float type + LOG(ERROR) << "not implemented yet..."; + } } }); +#else + if (lhs.type_flag_ == rhs.type_flag_) { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = calc_stride(new_lshape.get()); + mshadow::Shape rstride = calc_stride(new_rshape.get()); + // When the both inputs have the same data types + if (common::is_float(lhs.type_flag_)) { + // If both inputs are the same float types, output is the same float type + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, { + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + } else { + CHECK_EQ(out.type_flag_, mshadow::kFloat32) + << "true_divide only supports float32 output when input's dtype is " + << type_string(lhs.type_flag_); + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { + // If both inputs are the same integer types, output is float type + Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); + }); + } + }); + } else { + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + // lhs and rhs have different float types, the output is the more precise one + LOG(ERROR) << "not implemented yet..."; + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + // one of lhs and rhs is float, the output is the same type as the float one + TBlob temp_tblob; + if (common::is_float(lhs.type_flag_)) { + // lhs is float type, output will be the same float type + CHECK_EQ(lhs.type_flag_, out.type_flag_) + << "lhs should have the same type as out, infer type broken?"; + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + TrueDivideBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + // rhs is float type, output will be the same float type + CHECK_EQ(rhs.type_flag_, out.type_flag_) + << "rhs 
should have the same type as out, infer type broken?"; + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + TrueDivideBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + } else { + // lhs and rhs have different integer types, the output is float type + LOG(ERROR) << "not implemented yet..."; + } + } +#endif } } diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 5a4634c3ff8c..d2135befef42 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -28,26 +28,35 @@ namespace mxnet { namespace op { +int TrueDivideOutType(int ltype, int rtype) { + if (common::is_float(ltype) && common::is_float(rtype)) { + // If both inputs are float, return the one with the higher precision + return common::more_precise_type(ltype, rtype); + } else if (common::is_float(ltype) || common::is_float(rtype)) { + // If only one of the inputs is float, return that float type + return (common::is_float(ltype)) ? ltype : rtype; + } + // If neither of the inputs is float, return the default float32 type + return mshadow::kFloat32; +} + template bool TrueDivideType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(num_inputs)); + CHECK_GT(in_attrs->size(), 0U); CHECK_EQ(out_attrs->size(), 1U); + for (const int dtype : *in_attrs) { if (dtype == -1) return false; } - if (num_inputs == 2) { - const int lhs_dtype = in_attrs->at(0); - const int rhs_dtype = in_attrs->at(1); - CHECK_EQ(lhs_dtype, rhs_dtype) - << "true_divide currently only supports same dtype for dividend and divisor"; - } - if (common::is_float(in_attrs->at(0))) { - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - } else { - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); - } + + const int lhs_dtype = in_attrs->at(0); + const int rhs_dtype = (num_inputs == 2) ? + in_attrs->at(1) : + (common::is_float(lhs_dtype) ? 
lhs_dtype : mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 0, TrueDivideOutType(lhs_dtype, rhs_dtype)); return true; } @@ -64,7 +73,13 @@ NNVM_REGISTER_OP(_npi_true_divide) [](const NodeAttrs& attrs){ return std::vector >{{0, 0}, {1, 0}}; }) -.set_attr("FCompute", TrueDivideBroadcastCompute) +#ifdef _WIN32 +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +#endif +.set_attr("FCompute", TrueDivideBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"}) .add_argument("lhs", "NDArray-or-Symbol", "Dividend array") .add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); @@ -81,6 +96,12 @@ NNVM_REGISTER_OP(_npi_true_divide_scalar) [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) +#ifdef _WIN32 +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +#endif .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") @@ -98,6 +119,12 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) +#ifdef _WIN32 +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +#endif .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu index c026d689233d..7211f4a0a006 100644 --- a/src/operator/numpy/np_true_divide.cu +++ b/src/operator/numpy/np_true_divide.cu @@ -29,7 +29,7 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_true_divide) -.set_attr("FCompute", TrueDivideBroadcastCompute); +.set_attr("FCompute", TrueDivideBroadcastCompute); NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_attr("FCompute", TrueDivideScalarCompute); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 3d3bcfacbd05..ad06df8d92be 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -187,9 +187,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: } namespace mxnet_op { -template +template struct binary_broadcast_kernel { /*! \brief Map function for binary_broadcast_kernel */ + template MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, const Shape &oshape, IType *lhs, IType *rhs, @@ -208,6 +209,7 @@ struct binary_broadcast_kernel { } /*! \brief Map function for binary_broadcast_kernel */ + template MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, const Shape &oshape, IType lhs, IType *rhs, @@ -224,6 +226,49 @@ struct binary_broadcast_kernel { KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); } } + +#ifndef _WIN32 + /*! 
\brief Map function for binary_broadcast_kernel */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, IType *lhs, DType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); + } + } + + /*! \brief Map function for binary_broadcast_kernel */ + template::value && + !std::is_pointer::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, IType lhs, DType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); + } + } +#endif }; template @@ -307,7 +352,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); }); @@ -336,7 +381,7 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs, mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: + mxnet_op::Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); @@ -444,11 +489,11 @@ void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx, Shape lstride = calc_stride(new_csrshape.get()); Shape rstride = calc_stride(new_dnsshape.get()); if (reverse && std::is_same::value) { - Kernel, xpu>:: + Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape, DType(0), dns_data.dptr(), out_data.dptr()); } else { - Kernel, xpu>:: + Kernel, xpu>:: template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape, DType(0), dns_data.dptr(), out_data.dptr()); } @@ -658,7 +703,7 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs) { \ return std::vector{"lhs", "rhs"}; \ }) \ - .set_attr("FInferShape", BinaryBroadcastShape) \ + .set_attr("FInferShape", BinaryBroadcastShape) \ .set_attr("FInferType", ElemwiseType<2, 1>) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs){ \ diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h 
b/src/operator/tensor/elemwise_binary_scalar_op.h index 02b005eed995..834bbdbfc3d1 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -256,7 +256,7 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow::expr; Stream *s = ctx.get_stream(); const double alpha = nnvm::get(attrs.parsed); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 8b6928a2aa39..fe74eed727e5 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2502,13 +2502,327 @@ def test_arange_like_dtype(): x = mx.sym.Variable('x', dtype=t) y = mx.sym.reshape(x, shape=(0, 0, -1)) z = mx.sym.contrib.arange_like(y, axis=-1) - + mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null') mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t) out = mod.forward(is_train=False) for v in out: assert v.dtype == t +@with_seed() +def check_multihead_attention_selfatt(dtype): + def convert_weight(F, q_weight, k_weight, v_weight, num_heads): + q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True) + k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) + v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) + all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2) + all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) + return all_weights + + def convert_bias(F, q_bias, k_bias, v_bias, num_heads): + q_bias = F.reshape(q_bias, shape=(num_heads, -1)) + k_bias = F.reshape(k_bias, shape=(num_heads, -1)) + v_bias = F.reshape(v_bias, shape=(num_heads, -1)) + all_bias = F.stack(q_bias, k_bias, v_bias, axis=1) + all_bias = F.reshape(all_bias, shape=(-1,)) + return all_bias + + batch_size = 2 + qkv_length = 7 # length of a sequence + qkv_dim = 9 # dimension of encoding + num_heads = 3 # number of attention head + head_dim = 5 # head size + out_dim = 13 * num_heads + qkv_units = num_heads * head_dim + + arg_params = { + 'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), + 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), + } + + qkv = mx.sym.Variable('qkv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = 
mx.sym.Variable('out_bias') + qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads) + qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads) + qkv = mx.sym.transpose(qkv, axes=(1, 0, 2)) + qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False, + num_hidden=qkv_units * 3, no_bias=False) + att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk( + qkv_proj, heads=num_heads) + att_score = att_score + sonde + weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt( + qkv_proj, att_score, heads=num_heads) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.transpose(output, axes=(1, 0, 2)) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + qkv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + type_dict={'qkv': dtype, + 'q_weight': dtype, + 'k_weight': dtype, + 'v_weight': dtype, + 'q_bias': dtype, + 'k_bias': dtype, + 'v_bias': dtype, + 'sonde': dtype}, + grad_req='write', force_rebind=True) + output_shape = executor.outputs[0].shape + output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_opti = executor.outputs[0].asnumpy() + att_score_opti = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), + mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) + grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} + qkv = mx.sym.Variable('qkv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + + q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) + q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) + q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) + k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) + k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) + k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) + q = mx.sym.contrib.div_sqrt_dim(q) + att_score = mx.sym.batch_dot(q, k, transpose_b=True) + att_score = att_score + sonde + v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) + v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) + v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) + weighted_value = mx.sym.batch_dot(att_score, v) + weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), + reverse=True) + weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) + weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, 
no_bias=False) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + qkv=(batch_size, qkv_length, qkv_dim), + type_dict={'qkv': dtype}, + grad_req='write', force_rebind=True) + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_orig = executor.outputs[0].asnumpy() + att_score_orig = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), + mx.nd.zeros(att_score_orig.shape, dtype=dtype)]) + grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} + assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) + assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) + + for k in grads_opti.keys(): + assert(grads_orig[k].dtype == grads_opti[k].dtype) + assert(grads_orig[k].shape == grads_opti[k].shape) + assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) + +def test_multihead_attention_selfatt(): + for dtype in ['float16', 'float32']: + check_multihead_attention_selfatt(dtype=dtype) + +def check_multihead_attention_encdec(dtype): + def convert_weight(F, k_weight, v_weight, num_heads): + k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) + v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) + all_weights = F.concat(k_weight, v_weight, dim=-2) + all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) + return all_weights + + def convert_bias(F, k_bias, v_bias, num_heads): + k_bias = F.reshape(k_bias, shape=(num_heads, -1)) + v_bias = F.reshape(v_bias, shape=(num_heads, -1)) + all_bias = F.stack(k_bias, v_bias, axis=1) + all_bias = F.reshape(all_bias, shape=(-1,)) + return all_bias + + batch_size = 2 + qkv_length = 7 # length of a sequence + qkv_dim = 9 # dimension of encoding + num_heads = 3 # number of attention head + head_dim = 5 # head size + out_dim = 13 * num_heads + qkv_units = num_heads * head_dim + + arg_params = { + 'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), + 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), + } + + q = mx.sym.Variable('q') + kv = mx.sym.Variable('kv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads) + kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads) + kv = 
mx.sym.transpose(kv, axes=(1, 0, 2)) + kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False, + num_hidden=qkv_units * 2, no_bias=False) + q = mx.sym.transpose(q, axes=(1, 0, 2)) + q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + att_score = mx.sym.contrib.interleaved_matmul_encdec_qk( + q_proj, kv_proj, heads=num_heads) + att_score = att_score + sonde + weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt( + kv_proj, att_score, heads=num_heads) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.transpose(output, axes=(1, 0, 2)) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + q=(batch_size, qkv_length, qkv_dim), + kv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + out_weight=(out_dim, qkv_units), + out_bias=(out_dim,), + type_dict={'q': dtype, + 'kv': dtype, + 'q_weight': dtype, + 'q_bias': dtype, + 'k_weight': dtype, + 'k_bias': dtype, + 'v_weight': dtype, + 'v_bias': dtype, + 'out_weight': dtype, + 'out_bias': dtype, + }, + grad_req='write', force_rebind=True) + output_shape = executor.outputs[0].shape + output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_opti = executor.outputs[0].asnumpy() + att_score_opti = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) + + grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} + + q = mx.sym.Variable('q') + kv = mx.sym.Variable('kv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + + q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) + q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) + q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) + k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) + k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) + k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) + q = mx.sym.contrib.div_sqrt_dim(q) + att_score = mx.sym.batch_dot(q, k, transpose_b=True) + att_score = att_score + sonde + v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) + v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) + v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) + weighted_value = mx.sym.batch_dot(att_score, v) + weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), + reverse=True) + weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) + weighted_value = mx.sym.reshape(weighted_value, 
shape=(0, 0, -1))
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=mx.gpu(0),
+                                  q=(batch_size, qkv_length, qkv_dim),
+                                  kv=(batch_size, qkv_length, qkv_dim),
+                                  type_dict={'q': dtype,
+                                             'kv': dtype},
+                                  grad_req='write', force_rebind=True)
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_orig = executor.outputs[0].asnumpy()
+    att_score_orig = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+    for k in grads_opti.keys():
+        assert(grads_orig[k].dtype == grads_opti[k].dtype)
+        assert(grads_orig[k].shape == grads_opti[k].shape)
+        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+def test_multihead_attention_encdec():
+    for dtype in ['float16', 'float32']:
+        check_multihead_attention_encdec(dtype=dtype)
 
 if __name__ == '__main__':
     import nose
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index a2716fb5363f..c1a6ed567b94 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1940,7 +1940,7 @@ def get_new_shape(shape, axis):
         with mx.autograd.record():
             y = test_concat(a, b, c, d)
-
+
             assert y.shape == expected_ret.shape
             assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
@@ -3735,12 +3735,14 @@ def test_np_true_divide():
         [(2, 3, 1), (1, 4)],
         [(2, 1, 4, 1), (3, 1, 5)],
     ]
-    dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
+    dtypes = [np.bool, np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
+    itypes = [np.bool, np.int8, np.uint8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
     for shape_pair, dtype in itertools.product(shapes, dtypes):
         a = np.random.uniform(3, 50, size=shape_pair[0]).astype(dtype)
         b = np.random.uniform(3, 50, size=shape_pair[-1]).astype(dtype)
         out_mx = a / b
-        if _np.issubdtype(dtype, _np.integer):
+        if _np.issubdtype(dtype, _np.integer) or (dtype is np.bool):
             assert out_mx.dtype == np.float32
         else:
             assert out_mx.dtype == dtype
@@ -3756,6 +3758,20 @@ def test_np_true_divide():
             out_np = _np.true_divide(val, a.asnumpy())
             assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-3, atol=1e-3, use_broadcast=False)
 
+    for shape_pair, itype, ftype in itertools.product(shapes, itypes, ftypes):
+        i_ = np.random.uniform(3, 50, size=shape_pair[0]).astype(itype)
+        f_ = np.random.uniform(3, 50, size=shape_pair[-1]).astype(ftype)
+
+        out_mx = i_ / f_
+        assert out_mx.dtype == ftype
+        out_np = _np.true_divide(i_.asnumpy(), f_.asnumpy())
+        assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-3, atol=1e-3, use_broadcast=False)
+
+        out_mx = f_ / i_
+        assert out_mx.dtype == ftype
+        out_np = _np.true_divide(f_.asnumpy(), i_.asnumpy())
+        assert_almost_equal(out_mx.asnumpy(), out_np, rtol=1e-3, atol=1e-3, use_broadcast=False)
+
 @with_seed()
 @use_np
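The new test_np_true_divide cases above pin down the dtype-promotion rules that TrueDivideOutType introduces: dividing two arrays of the same integer (or boolean) dtype yields float32, and dividing an integer/boolean array by a floating-point array, in either order, yields that floating operand's dtype. The snippet below is not part of the patch; it is a minimal usage sketch of that behaviour, assuming an MXNet build that includes these changes and the mxnet.numpy front end (from mxnet import np, npx).

import numpy as _np
from mxnet import np, npx

npx.set_np()  # enable NumPy-compatible semantics, matching the tests above

# Same-dtype integer division always promotes to float32.
a = np.array([1, 2, 3], dtype=_np.int32)
b = np.array([4, 5, 6], dtype=_np.int32)
assert (a / b).dtype == _np.float32

# Mixed integer/float division takes the floating operand's dtype,
# regardless of operand order.
f = np.array([0.5, 1.5, 2.5], dtype=_np.float16)
assert (a / f).dtype == _np.float16
assert (f / a).dtype == _np.float16

Division of two floating arrays with different dtypes is intentionally not shown: the patch infers the more precise float as the output type, but the corresponding compute branch still only logs "not implemented yet".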