Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support mixed-precision binary operations
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 3, 2019
1 parent e139442 commit 28f07a7
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
inline int get_the_more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
Expand Down Expand Up @@ -870,12 +870,12 @@ inline int more_precise_type(const int type1, const int type2) {
return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
/*!
 * \brief Infer the output dtype of a numpy-style binary op from its two input dtypes.
 *
 * Handles the one case the generic precision rule cannot express: numpy promotes
 * a mix of signed and unsigned 8-bit integers to int32. Every other combination
 * defers to get_the_more_precise_type.
 *
 * \param type1 dtype flag of the first operand
 * \param type2 dtype flag of the second operand
 * \return the promoted output dtype flag
 */
inline int np_binary_out_infer_type(const int type1, const int type2) {
  const bool u8_mixed_with_i8 =
      (type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
      (type1 == mshadow::kInt8 && type2 == mshadow::kUint8);
  if (u8_mixed_with_i8) {
    // int8/uint8 mix: neither type can hold the other's full range, widen to int32.
    return mshadow::kInt32;
  }
  return get_the_more_precise_type(type1, type2);
}

} // namespace common
Expand Down
10 changes: 10 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ MXNET_BINARY_MATH_OP_NC(right, b);

MXNET_BINARY_MATH_OP_NC(mul, a * b);

#ifndef _WIN32
/*!
 * \brief Multiplication where the first operand is bool and the second is
 *        another arithmetic type.
 *
 * Used by the mixed-precision binary kernels: the bool operand is promoted to
 * the other operand's dtype before the multiply. Excluded on Windows builds
 * (see the `#ifdef _WIN32` registration paths, which fall back to `mul`).
 */
struct mixed_mul {
  template<typename DType,
           typename std::enable_if<!std::is_pointer<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static DType Map(bool a, DType b) {
    // Multiplication over arithmetic types commutes, so promote-and-multiply
    // in either order yields the same result.
    return b * static_cast<DType>(a);
  }
};
#endif

MXNET_BINARY_MATH_OP_NC(div, a / b);

MXNET_BINARY_MATH_OP_NC(plus, a + b);
Expand Down
72 changes: 68 additions & 4 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
* \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator.
*/

#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"
#include "./np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -55,6 +54,61 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")

/*!
 * \brief Type inference for numpy binary ops that accept mixed-precision inputs.
 *
 * \param attrs     node attributes (unused beyond forwarding)
 * \param in_attrs  dtype flags of the two inputs (-1 = unknown)
 * \param out_attrs dtype flag of the single output
 * \return true when inference succeeded
 */
bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int>* in_attrs,
                                   std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int ltype = in_attrs->at(0);
  const int rtype = in_attrs->at(1);
  // Unless both input dtypes are known and actually differ, defer to the
  // standard same-type elementwise inference.
  if (ltype == -1 || rtype == -1 || ltype == rtype) {
    return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
  }
  // Mixed-precision mode: promote the output to the numpy-style common dtype.
  TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype));
  return true;
}

// Registration boilerplate for a numpy binary op that supports mixed-precision
// inputs: two inputs "lhs"/"rhs", one output, broadcast shape inference, and
// the mixed-precision type inference above. Two variants are needed because
// the implementations differ per platform (see MixedBinaryBroadcastCompute).
#ifdef _WIN32
// Windows variant: the compute path casts the bool input into temporary
// workspace, so the op must also request kTempSpace.
#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#else
// Non-Windows variant: mixed-type kernels operate on the inputs directly, so
// no temporary workspace request is needed.
#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#endif

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
Expand All @@ -64,8 +118,18 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
// CPU registration of numpy multiply with mixed-precision support.
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
#ifndef _WIN32
// Dedicated bool-with-other-type kernels (mixed_mul) handle the mixed case
// without a casting pass.
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
#else
// On Windows mixed_mul is unavailable; plain mul is passed for all three op
// slots and the mixed case is handled by casting inside the compute function.
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mul,
op::mshadow_op::mul>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
Expand Down
16 changes: 13 additions & 3 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
* \file np_elemwise_broadcast_op.cu
* \brief GPU Implementation of basic functions for elementwise binary broadcast operator.
*/
#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"

#include "./np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {
Expand All @@ -35,7 +35,17 @@ NNVM_REGISTER_OP(_npi_subtract)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);

// GPU compute registration for numpy multiply; mirrors the CPU registration's
// platform split (mixed_mul kernels off-Windows, cast-based fallback on Windows).
NNVM_REGISTER_OP(_npi_multiply)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>);
#else
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mul,
op::mshadow_op::mul>);
#endif

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
Expand Down
181 changes: 181 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_elemwise_binary_op.h
* \brief
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

#include <vector>

#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"

namespace mxnet {
namespace op {

#ifndef _WIN32
/*!
 * \brief Elementwise (already shape-compatible, no broadcasting) launch for a
 *        binary op where exactly one input is bool and the other matches the
 *        output dtype.
 * \tparam LOP op invoked as LOP(lhs_bool, rhs_value) when lhs is the bool input
 * \tparam ROP op invoked as ROP(rhs_bool, lhs_value) when rhs is the bool input
 *
 * NOTE(review): the output dtype is dispatched with MSHADOW_REAL_TYPE_SWITCH,
 * which covers floating-point dtypes only — a bool-with-integer pair would not
 * be dispatched here; confirm callers never reach this path with integer output.
 */
template<typename xpu, typename LOP, typename ROP>
void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

// Only the bool-with-another-dtype combination is implemented.
CHECK((lhs.type_flag_ == mshadow::kBool) || (rhs.type_flag_ == mshadow::kBool))
<< "now supports bool with another type only";

Stream<xpu> *s = ctx.get_stream<xpu>();

MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
// Launch count is in units of DType lanes over the smallest of the three sizes.
const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
if (size != 0) {
if (lhs.type_flag_ == kBool) {
// lhs is bool: kernel signature is LOP(bool, DType).
Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(
s, size, out.dptr<DType>(), lhs.dptr<bool>(), rhs.dptr<DType>());
} else {
// rhs is bool: operands are swapped so the bool comes first, ROP(bool, DType).
Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(
s, size, out.dptr<DType>(), rhs.dptr<bool>(), lhs.dptr<DType>());
}
}
});
});
}
#endif

/*!
 * \brief Forward compute for a broadcastable binary op whose inputs may have
 *        different dtypes (currently bool mixed with one other dtype).
 * \tparam OP  op used when both input dtypes are identical
 * \tparam LOP op used (non-Windows) when lhs is the bool input
 * \tparam ROP op used (non-Windows) when rhs is the bool input
 *
 * On Windows the LOP/ROP kernels are unavailable; instead the bool input is
 * cast into requested temp space and the ordinary OP broadcast kernel is run.
 */
template<typename xpu, typename OP, typename LOP, typename ROP>
void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

// Nothing to do for an empty output or a null write request.
if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

// Same-dtype inputs: the regular broadcast kernel handles everything.
if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}

// Mixed dtypes are only implemented for bool combined with another dtype.
CHECK((lhs.type_flag_ == mshadow::kBool) || (rhs.type_flag_ == mshadow::kBool))
<< "now supports bool with another type only";

#ifndef _WIN32
mxnet::TShape new_lshape, new_rshape, new_oshape;
// Collapse adjacent broadcastable dimensions; ndim == 0 means the shapes are
// identical and a plain elementwise launch suffices.
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// NOTE(review): out.dptr<DType>() is also used for the non-bool input below,
// so the inferred output dtype must equal the non-bool input dtype here.
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
if (lhs.type_flag_ == mshadow::kBool) {
// lhs is bool: LOP(bool, DType) with (lstride, rstride) in declaration order.
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, LOP>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
lhs.dptr<bool>(), rhs.dptr<DType>(), out.dptr<DType>());
} else {
// rhs is bool: operands and strides are swapped so the bool comes first.
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, ROP>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape,
rhs.dptr<bool>(), lhs.dptr<DType>(), out.dptr<DType>());
}
});
});
}
#else
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
// NOTE(review): logs and returns without writing the output blob.
LOG(ERROR) << "not implemented yet...";
} else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
TBlob temp_tblob;
// one is float, the other is bool
CHECK_EQ(out.type_flag_,
common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_)
<< "This case out type should be same as the float type";
if (common::is_float(lhs.type_flag_)) {
// Cast the bool rhs into temp space (requested via kTempSpace) as lhs's
// float dtype, then run the ordinary same-dtype broadcast kernel.
MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
Tensor<xpu, 1, LType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
// Symmetric case: cast the bool lhs to rhs's float dtype.
MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
Tensor<xpu, 1, RType> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
// NOTE(review): logs and returns without writing the output blob.
LOG(ERROR) << "not implemented yet...";
}
#endif
}

/*!
 * \brief Backward pass for broadcast binary ops whose forward allows
 *        mixed-precision inputs.
 * \tparam LOP gradient op for the lhs input
 * \tparam ROP gradient op for the rhs input
 *
 * inputs  = {output gradient, forward lhs, forward rhs}
 * outputs = {lhs gradient, rhs gradient}
 *
 * Same-dtype inputs delegate to the standard broadcast backward; the
 * mixed-dtype case is not implemented yet.
 */
template<typename xpu, typename LOP, typename ROP>
void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);

  const TBlob& lhs = inputs[1];
  const TBlob& rhs = inputs[2];
  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
    return;
  }

  // LOG(ERROR) here would return with the gradient blobs never written,
  // silently propagating uninitialized memory as gradients. Fail hard so the
  // unsupported case surfaces immediately to the caller.
  LOG(FATAL) << "Binary operation with mixed input data types does not support backward yet...";
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
2 changes: 1 addition & 1 deletion src/operator/numpy/np_true_divide.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace op {
int TrueDivideOutType(int ltype, int rtype) {
if (common::is_float(ltype) && common::is_float(rtype)) {
// If both inputs are float, return the one with the higher precision
return common::more_precise_type(ltype, rtype);
return common::get_the_more_precise_type(ltype, rtype);
} else if (common::is_float(ltype) || common::is_float(rtype)) {
// If only one of the inputs is float, return that float type
return (common::is_float(ltype)) ? ltype : rtype;
Expand Down
2 changes: 2 additions & 0 deletions src/operator/tensor/elemwise_binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ class ElemwiseBinaryOp : public OpBase {
return a1.var() == a2.var();
}

public:
/*! \brief Minimum of three */
static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) {
return a < b ? (a < c ? a : c) : (b < c ? b : c);
}

private:
template<typename xpu, typename LOP, typename ROP, typename DType>
static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
Expand Down
Loading

0 comments on commit 28f07a7

Please sign in to comment.