Mixed data type binary ops (apache#16699)
* support mixed-precision binary operations

* improvements to documentation and error messages
haojin2 authored and ptrendx committed Nov 15, 2019
1 parent e896b98 commit 7141e9a
Showing 11 changed files with 745 additions and 21 deletions.
40 changes: 40 additions & 0 deletions python/mxnet/ndarray/numpy/_op.py
@@ -522,6 +522,14 @@ def add(x1, x2, out=None, **kwargs):
-------
add : ndarray or scalar
The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the operation is not yet supported.
"""
return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out)
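A quick illustration of these promotion rules (a hypothetical mxnet.numpy session;
exact dtype reprs may vary by version):

>>> a = np.ones((2, 2), dtype='float16')
>>> b = np.ones((2, 2), dtype='float32')
>>> np.add(a, b).dtype   # both floating-point: the more precise type wins
dtype('float32')
>>> c = np.ones((2, 2), dtype='int32')
>>> np.add(a, c).dtype   # one floating-point input: its type is the result type
dtype('float16')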

@@ -548,6 +556,14 @@ def subtract(x1, x2, out=None, **kwargs):
-------
subtract : ndarray or scalar
The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the operation is not yet supported.
"""
return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar,
_npi.rsubtract_scalar, out)
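Note the extra _npi.rsubtract_scalar helper here: subtraction is not commutative,
so scalar-on-the-left expressions need a reversed kernel. A sketch (hypothetical
session; the result dtype follows the rules above):

>>> x = np.ones((3,), dtype='float16')
>>> (5.0 - x).dtype   # dispatched through the reversed scalar op
dtype('float16')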
@@ -575,6 +591,14 @@ def multiply(x1, x2, out=None, **kwargs):
out : ndarray or scalar
The product of x1 and x2, element-wise. This is a scalar if both x1 and x2
are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the operation is not yet supported.
"""
return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out)

@@ -602,6 +626,14 @@ def divide(x1, x2, out=None, **kwargs):
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the output is of float32 type.
"""
return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
_npi.rtrue_divide_scalar, out)
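Division is the one operation here that accepts integer-only inputs, promoting
them to float32. A sketch under the rules above:

>>> x = np.arange(5, dtype='int32')
>>> (x / 2).dtype   # integer inputs promote to float32
dtype('float32')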
@@ -632,6 +664,14 @@ def true_divide(x1, x2, out=None):
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the output is of float32 type.
"""
return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
_npi.rtrue_divide_scalar, out)
40 changes: 40 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -2399,6 +2399,14 @@ def add(x1, x2, out=None, **kwargs):
add : ndarray or scalar
The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the operation is not yet supported.
Examples
--------
>>> np.add(1.0, 4.0)
@@ -2437,6 +2445,14 @@ def subtract(x1, x2, out=None, **kwargs):
subtract : ndarray or scalar
The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the operation is not yet supported.
Examples
--------
>>> np.subtract(1.0, 4.0)
@@ -2473,6 +2489,14 @@ def multiply(x1, x2, out=None, **kwargs):
out : ndarray or scalar
The product of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the operation is not yet supported.
Examples
--------
>>> np.multiply(2.0, 4.0)
@@ -2511,6 +2535,14 @@ def divide(x1, x2, out=None, **kwargs):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the output is of float32 type.
Examples
--------
>>> np.divide(x, 4)
@@ -2545,6 +2577,14 @@ def true_divide(x1, x2, out=None):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type is
determined according to the following rules:
* If both inputs are floating-point types, the output is the more precise type.
* If only one of the inputs is a floating-point type, the result is that type.
* If both inputs are integer types (including boolean), the output is of float32 type.
Examples
--------
>>> x = np.arange(5)
30 changes: 27 additions & 3 deletions src/common/utils.h
@@ -365,6 +365,30 @@ inline bool ContainsStorageType(const std::vector<int>& ndstypes,
return false;
}

/*! \brief get the C type name for a dtype enum (for readable error messages) */
inline std::string dtype_string(const int dtype) {
switch (dtype) {
case mshadow::kFloat32:
return "float";
case mshadow::kFloat64:
return "double";
case mshadow::kFloat16:
return "half";
case mshadow::kUint8:
return "unsigned char";
case mshadow::kInt8:
return "char";
case mshadow::kInt32:
return "int";
case mshadow::kInt64:
return "long long";
case mshadow::kBool:
return "bool";
default:
LOG(FATAL) << "Unknown type enum " << dtype;
}
return "unknown";
}

/*! \brief get string representation of dispatch_mode */
inline std::string dispatch_mode_string(const DispatchMode x) {
switch (x) {
@@ -842,7 +866,7 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

-inline int more_precise_type(const int type1, const int type2) {
+inline int get_more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
@@ -870,12 +894,12 @@ inline int more_precise_type(const int type1, const int type2) {
return mshadow::kInt8;
}

-inline int np_binary_out_type(const int type1, const int type2) {
+inline int np_binary_out_infer_type(const int type1, const int type2) {
if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
return mshadow::kInt32;
}
-return more_precise_type(type1, type2);
+return get_more_precise_type(type1, type2);
}
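Restated in Python for readability (an illustrative sketch, not part of the commit;
string type names stand in for the mshadow enums):

def np_binary_out_infer_type(t1, t2):
    # uint8 and int8 have no common 8-bit type that covers both ranges,
    # so this pairing widens all the way to int32
    if {t1, t2} == {'uint8', 'int8'}:
        return 'int32'
    # otherwise defer to the float/integer precedence rules above
    return get_more_precise_type(t1, t2)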

} // namespace common
94 changes: 94 additions & 0 deletions src/operator/mshadow_op.h
@@ -194,6 +194,100 @@ MXNET_BINARY_MATH_OP_NC(right, b);

MXNET_BINARY_MATH_OP_NC(mul, a * b);

#ifndef _WIN32
/*! \brief mixed-type a + b: the lhs has the narrower type, the rhs has the wider
 *         floating-point type, which also determines the output type */
struct mixed_plus {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) + b;
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) + b;
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, float>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) + b;
}
};

/*! \brief mixed-type a - b (lhs narrower, rhs the wider floating-point type) */
struct mixed_minus {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) - b;
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) - b;
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, float>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) - b;
}
};

/*! \brief mixed-type b - a, for reversed operands (lhs narrower, rhs wider) */
struct mixed_rminus {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return b - static_cast<mshadow::half::half_t>(a);
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return b - static_cast<float>(a);
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, float>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return b - static_cast<double>(a);
}
};

/*! \brief mixed-type a * b (lhs narrower, rhs the wider floating-point type) */
struct mixed_mul {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) * b;
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) * b;
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, float>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) * b;
}
};
#endif
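Each mixed_* struct above follows the same pattern: one Map overload per output
type (half, float, double), with enable_if admitting only lhs types narrower than
the rhs, so the rhs always carries the wider type and fixes the output type. A
Python paraphrase of the idea (illustrative only):

def mixed_binary(a, b, op):
    # b is guaranteed the wider type; cast a up and compute in that type
    return op(type(b)(a), b)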

MXNET_BINARY_MATH_OP_NC(div, a / b);

MXNET_BINARY_MATH_OP_NC(plus, a + b);
14 changes: 10 additions & 4 deletions src/operator/mxnet_op.h
@@ -859,14 +859,17 @@ struct op_with_req {

/*! \brief inputs are two tensors with a float output tensor */
template<typename DType,
-typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are two tensors with a double output tensor */
template<typename DType,
-typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+std::is_same<DType, float>::value ||
+std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}
@@ -883,14 +886,17 @@

/*! \brief inputs are a tensor and a scalar with a float output tensor */
template<typename DType,
-typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}

/*! \brief inputs are a tensor and a scalar with a double output tensor */
template<typename DType,
-typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+std::is_same<DType, float>::value ||
+std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}