Backport to 1.6 (#16773, #16781, #16783, #16716, #16699, #16728, #16769, #16792) (#16832)

* Fix nightly build (#16773)

* Remove dependency on tvmop.conf

* Fix binary dependencies for the nightly

* Add comments

* Update tvmop.py

* Fix rebase

* Fix (#16781)

* Speed up fused_op compilation by caching PTX and JIT-compiled functions (#16783)

* [Numpy] Fix collect_params().zero_grad() in gluon numpy interface (#16716)

* fix zero_grad

* Update parameter.py

* add test

* fix

* Mixed data type binary ops (#16699)

* support mixed-precision binary operations

* improve documentation and error messages

* Support boolean elemwise/broadcast binary add, multiply and true_divide (#16728)

* support pure boolean elemwise/broadcast binary ops

* switch to unique_ptr

* fix the test error

* Fix rtrue_divide grad (#16769)

* Fix rtrue_divide_scalar

* More tests

* Fix numpy-compatible mean output type for integer inputs (#16792) (see the sketch after this list)

* fix mean output type for integer inputs

* enable for windows
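
A rough illustration of the mean fix (#16792): after this backport, the mean of an integer array is expected to come back as a floating-point value. This is a sketch only; the mean change itself is not among the diffs shown below, and the exact float width printed is an assumption.

from mxnet import np, npx
npx.set_np()

x = np.array([1, 2, 2, 4], dtype='int64')
print(x.mean())        # expected: 2.25 instead of a truncated integer result
print(x.mean().dtype)  # expected: a float dtype (float32 or float64; the width is assumed)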
ptrendx authored Nov 16, 2019
1 parent 867c98d commit 9abb151
Showing 32 changed files with 1,390 additions and 291 deletions.
12 changes: 8 additions & 4 deletions python/mxnet/gluon/parameter.py
@@ -27,7 +27,6 @@
from collections import OrderedDict, defaultdict
import warnings
import numpy as np
import mxnet as mx

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, context
@@ -896,15 +895,20 @@ def zero_grad(self):
continue
for g in p.list_grad():
if g.stype == 'row_sparse':
mx.ndarray.zeros_like(g, out=g)
ndarray.zeros_like(g, out=g)
else:
arrays[g.context].append(g)

if len(arrays) == 0:
return

for arr in arrays.values():
mx.nd.reset_arrays(*arr, num_arrays=len(arr))
if is_np_array():
for arr in arrays.values():
for ele in arr:
ele[()] = 0
else:
for arr in arrays.values():
ndarray.reset_arrays(*arr, num_arrays=len(arr))

def reset_ctx(self, ctx):
"""Re-assign all Parameters to other contexts.
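
The hunk above is the core of the collect_params().zero_grad() fix (#16716): dense gradients are grouped by context and, in numpy mode, zeroed through basic indexing, since the reset_arrays operator only exists for classic ndarrays. Below is a minimal standalone sketch of the same flow; it assumes params is an iterable of gluon Parameter objects and is a reading aid, not the library code itself.

from collections import defaultdict
from mxnet import ndarray
from mxnet.util import is_np_array

def zero_grad_sketch(params):
    arrays = defaultdict(list)
    for p in params:
        if p.grad_req == 'null':
            continue
        for g in p.list_grad():
            if g.stype == 'row_sparse':
                ndarray.zeros_like(g, out=g)   # sparse gradients are zeroed in place
            else:
                arrays[g.context].append(g)    # group dense gradients by device
    if len(arrays) == 0:
        return
    if is_np_array():
        # numpy-style ndarrays have no reset_arrays operator; assign 0 via basic indexing
        for arr in arrays.values():
            for ele in arr:
                ele[()] = 0
    else:
        for arr in arrays.values():
            ndarray.reset_arrays(*arr, num_arrays=len(arr))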
40 changes: 40 additions & 0 deletions python/mxnet/ndarray/numpy/_op.py
@@ -522,6 +522,14 @@ def add(x1, x2, out=None, **kwargs):
-------
add : ndarray or scalar
The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
"""
return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out)

@@ -548,6 +556,14 @@ def subtract(x1, x2, out=None, **kwargs):
-------
subtract : ndarray or scalar
The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
"""
return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar,
_npi.rsubtract_scalar, out)
@@ -575,6 +591,14 @@ def multiply(x1, x2, out=None, **kwargs):
out : ndarray or scalar
The multiplication of x1 and x2, element-wise. This is a scalar if both x1 and x2
are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
"""
return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out)

@@ -602,6 +626,14 @@ def divide(x1, x2, out=None, **kwargs):
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), the output is of float32 type.
"""
return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
_npi.rtrue_divide_scalar, out)
@@ -632,6 +664,14 @@ def true_divide(x1, x2, out=None):
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), the output is of float32 type.
"""
return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
_npi.rtrue_divide_scalar, out)
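
The Notes added above document the mixed-precision promotion rules from #16699. A hedged illustration of how they are expected to behave at the Python level follows; npx.set_np() enables the numpy semantics, and the dtypes in the comments restate the documented rules rather than captured output.

from mxnet import np, npx
npx.set_np()

a16 = np.ones((2, 2), dtype='float16')
b32 = np.ones((2, 2), dtype='float32')
i64 = np.ones((2, 2), dtype='int64')

print((a16 + b32).dtype)  # float32: both inputs are floating, the more precise type wins
print((b32 * i64).dtype)  # float32: only one input is floating, its type is kept
print((i64 / i64).dtype)  # float32: true_divide of integer inputs yields float32
# Mixed integer pairs such as int32 + int64 are documented above as not supported yet.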
40 changes: 40 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -2399,6 +2399,14 @@ def add(x1, x2, out=None, **kwargs):
add : ndarray or scalar
The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
Examples
--------
>>> np.add(1.0, 4.0)
@@ -2437,6 +2445,14 @@ def subtract(x1, x2, out=None, **kwargs):
subtract : ndarray or scalar
The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
Examples
--------
>>> np.subtract(1.0, 4.0)
@@ -2473,6 +2489,14 @@ def multiply(x1, x2, out=None, **kwargs):
out : ndarray or scalar
The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
Examples
--------
>>> np.multiply(2.0, 4.0)
@@ -2511,6 +2535,14 @@ def divide(x1, x2, out=None, **kwargs):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), the output is of float32 type.
Examples
--------
>>> np.true_divide(x, 4)
@@ -2545,6 +2577,14 @@ def true_divide(x1, x2, out=None):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
Notes
-----
This operator now supports automatic type promotion. The resulting type will be determined
according to the following rules:
* If both inputs are of floating number types, the output is the more precise type.
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), the output is of float32 type.
Examples
--------
>>> x = np.arange(5)
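
The same Notes land in multiarray.py, and #16728 extends add, multiply and true_divide to pure boolean inputs. The sketch below shows the expected behaviour; the dtypes in the comments are assumptions based on the commit titles and the NumPy semantics these operators aim to match, not outputs verified against this build.

from mxnet import np, npx
npx.set_np()

m1 = np.array([True, False, True])
m2 = np.array([True, True, True])

print((m1 * m2).dtype)  # expected bool: elementwise multiply behaves like logical AND
print((m1 + m2).dtype)  # expected bool: elementwise add behaves like logical OR
print((m1 / m2).dtype)  # expected float32: booleans are treated as integers by true_divide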
14 changes: 10 additions & 4 deletions python/mxnet/tvmop.py
@@ -21,6 +21,7 @@

if Features().is_enabled("TVM_OP"):
import json
import logging

from ._ctypes.space import _set_tvm_op_config
from .base import check_call, _LIB, c_str
@@ -31,7 +32,12 @@
check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0])))

# op sch config
_CONF_TVM_OP = find_conf_path("tvmop")
with open(_CONF_TVM_OP[0], "r") as f:
ret = ConfigSpaces.from_json_dict(json.load(f))
_set_tvm_op_config(ret)
try:
_CONF_TVM_OP = find_conf_path("tvmop")
except RuntimeError as e:
logging.warning("TVM config file missing, falling back to default schedule", exc_info=True)
else:
logging.info("TVM op config has been loaded")
with open(_CONF_TVM_OP[0], "r") as f:
ret = ConfigSpaces.from_json_dict(json.load(f))
_set_tvm_op_config(ret)
30 changes: 27 additions & 3 deletions src/common/utils.h
@@ -365,6 +365,30 @@ inline bool ContainsStorageType(const std::vector<int>& ndstypes,
return false;
}

inline std::string dtype_string(const int dtype) {
switch (dtype) {
case mshadow::kFloat32:
return "float";
case mshadow::kFloat64:
return "double";
case mshadow::kFloat16:
return "half";
case mshadow::kUint8:
return "unsigned char";
case mshadow::kInt8:
return "char";
case mshadow::kInt32:
return "int";
case mshadow::kInt64:
return "long long";
case mshadow::kBool:
return "bool";
default:
LOG(FATAL) << "Unknown type enum " << dtype;
}
return "unknown";
}

/*! \brief get string representation of dispatch_mode */
inline std::string dispatch_mode_string(const DispatchMode x) {
switch (x) {
@@ -842,7 +866,7 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
inline int get_more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
Expand Down Expand Up @@ -870,12 +894,12 @@ inline int more_precise_type(const int type1, const int type2) {
return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
inline int np_binary_out_infer_type(const int type1, const int type2) {
if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
return mshadow::kInt32;
}
return more_precise_type(type1, type2);
return get_more_precise_type(type1, type2);
}

} // namespace common
27 changes: 15 additions & 12 deletions src/executor/pointwise_fusion_pass.cc
@@ -83,15 +83,7 @@ namespace {
auto node = nnvm::Node::Create();
subgraph_sym.outputs = subgraph.outputs;
node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
std::ostringstream name_oss;
// the name of the new node will be the concatenation of all the node names in the subgraph
DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) {
if (n->op() != nullptr)
name_oss << n->op()->name << "_";
});
auto subgraph_name = name_oss.str();
subgraph_name.pop_back();
node->attrs.name = subgraph_name;
node->attrs.name = "FusedOp";
node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
node->attrs.op = Op::Get("_FusedOp");
@@ -152,16 +144,16 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& sub
auto it = node->control_deps.begin();
static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
std::vector<nnvm::NodePtr> new_control_deps;
while (it != node->control_deps.end()) {
// Use the first control dependency to get the inferattr helper
if (it != node->control_deps.end()) {
if (subgraph_set.count(it->get())) {
new_control_deps.push_back(*it);
} else {
if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
uint32_t node_id = subgraph_node->control_deps.size();
subgraph_node->control_deps.push_back(*it);
auto helper_node = op::MakeNode("_FusedOpOutHelper",
subgraph_node->attrs.name + "_"
+ node->attrs.name + "_outhelper",
"FusedOp_" + node->attrs.name + "_outhelper",
nullptr,
nullptr,
nullptr);
@@ -180,6 +172,17 @@
}
});

std::ostringstream name_oss;
// the name of the new node will be the concatenation of all the node names in the subgraph
DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) {
if (n->op() != nullptr) {
name_oss << n->op()->name << "_";
}
});
auto subgraph_name = name_oss.str();
subgraph_name.pop_back();
subgraph_node->attrs.name = subgraph_name;

const auto& index = subgraph.indexed_graph();
DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::NodePtr& node) {
for (auto &e : node->control_deps) {
2 changes: 1 addition & 1 deletion src/ndarray/ndarray_function-inl.h
@@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>(
template<>
void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
});
}
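
The one-line switch to MSHADOW_TYPE_SWITCH_WITH_BOOL above widens the scalar-fill kernel so boolean tensors are accepted. A hedged illustration of the user-visible effect follows; it assumes that scalar assignment to a boolean mxnet numpy array is routed through this kernel, which the diff alone does not prove.

from mxnet import np, npx
npx.set_np()

mask = np.ones((4,), dtype='bool')
mask[()] = 0   # scalar fill of a bool array; previously the dtype switch did not cover bool
print(mask)    # expected: [False False False False]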
4 changes: 1 addition & 3 deletions src/operator/fusion/fused_op-inl.h
@@ -982,11 +982,9 @@ const char kernel_begin[] = R"code(
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < N; i+= gridDim.x * blockDim.x) {
int offset = i*nvec;
)code";

const char kernel_end[] = R"code(
}
const char kernel_end[] = R"code(}
}
)code";
