Skip to content

Commit

Permalink
[PASS] Improve GraphFuse to include five patterns (apache#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 2e9b6b9 commit 2b3d2e2
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 46 deletions.
13 changes: 10 additions & 3 deletions nnvm/docs/top.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
NNVM Core Primitives
====================
NNVM Core Tensor Operators
==========================

**Level 1: Basic Ops**
**Level 1: Basic Operators**
This level enables fully connected multi-layer perceptron.

.. autosummary::
:nosignatures:
Expand All @@ -12,12 +13,14 @@ NNVM Core Primitives
nnvm.symbol.sigmoid
nnvm.symbol.exp
nnvm.symbol.log
nnvm.symbol.sqrt
nnvm.symbol.elemwise_add
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
nnvm.symbol.split
nnvm.symbol.dropout
nnvm.symbol.batch_norm
Expand All @@ -27,6 +30,8 @@ NNVM Core Primitives

**Level 2: Convolutions**

This level enables typical convnet models.

.. autosummary::
:nosignatures:

Expand Down Expand Up @@ -78,12 +83,14 @@ NNVM Core Primitives
.. autofunction:: nnvm.symbol.sigmoid
.. autofunction:: nnvm.symbol.exp
.. autofunction:: nnvm.symbol.log
.. autofunction:: nnvm.symbol.sqrt
.. autofunction:: nnvm.symbol.elemwise_add
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
.. autofunction:: nnvm.symbol.split
.. autofunction:: nnvm.symbol.dropout
.. autofunction:: nnvm.symbol.batch_norm
Expand Down
19 changes: 13 additions & 6 deletions nnvm/include/nnvm/compiler/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ using ::tvm::Tensor;
using ::tvm::Schedule;

/*! \brief operator pattern used in graph fusion */
enum OpPatternKind : int {
enum OpPatternKind {
// Elementwise operation
kElemWise = 0,
// Broadcast operation
// Broadcasting operator, can always map output axis to the input in order.
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axis need to be in order so transpose is not a bcast operator.
kBroadcast = 1,
// Complex operation, can fuse bcast in input/outputs
// Injective operator, can always injectively map output axis to a single input axis.
// All injective operator can still be safely fused to injective and reduction.
kInjective = 2,
// Communicative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kComplex = 2,
// Extern operation, cannot fuse anything.
kExtern = 3
kOutEWiseFusable = 4,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};

/*! \brief the operator pattern */
Expand Down
22 changes: 17 additions & 5 deletions nnvm/python/nnvm/compiler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,24 @@
import tvm

class OpPattern(object):
ELEM_WISE = 0
"""Operator generic patterns
See Also
--------
top.tag : Contains explaination of the tag type.
"""
# Elementwise operator
ELEMWISE = 0
# Broadcast operator
BROADCAST = 1
# Complex means we can fuse elemwise to it
COMPLEX = 2
# Extern means the op is not fusable
EXTERN = 3
# Injective mapping
INJECTIVE = 2
# Comunication
COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Not fusable opaque op
OPAQUE = 8

_register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
Expand Down
18 changes: 13 additions & 5 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _):
return topi.nn.relu(inputs[0])

reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEM_WISE)
reg.register_pattern("relu", OpPattern.ELEMWISE)

# leaky_relu
@reg.register_compute("leaky_relu")
def compute_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.leaky_relu(inputs[0])

reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)

# flatten
@reg.register_compute("flatten")
Expand All @@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _):
return topi.nn.flatten(inputs[0])

reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.COMPLEX)
reg.register_pattern("flatten", OpPattern.INJECTIVE)


# softmax
Expand All @@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target):
return tvm.create_schedule([x.op for x in outs])

# Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("softmax", OpPattern.EXTERN)
reg.register_pattern("softmax", OpPattern.OPAQUE)


# dense
Expand All @@ -67,7 +75,7 @@ def schedule_dense(_, outs, target):
return tvm.create_schedule([x.op for x in outs])

# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.EXTERN)
reg.register_pattern("dense", OpPattern.OPAQUE)


# conv
Expand Down Expand Up @@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target):
# naive schedule
return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("conv2d", OpPattern.COMPLEX)
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
40 changes: 22 additions & 18 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from ..compiler import registry as reg
from ..compiler import OpPattern

def _schedule_broadcast(_, outs, target):
def _schedule_injective(_, outs, target):
"""Generic schedule for binary bcast"""
if target == "cuda":
return topi.cuda.schedule_elemwise(outs)
return topi.cuda.schedule_injective(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s

def _compute_binary_scalar(f):
Expand Down Expand Up @@ -42,89 +44,91 @@ def _compute(attrs, x, _):
return _compute


_fschedule_broadcast = tvm.convert(_schedule_broadcast)
_fschedule_injective = tvm.convert(_schedule_injective)
_fschedule_broadcast = _fschedule_injective
_fschedule_elemwise = _fschedule_injective

# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEM_WISE)
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)

# exp
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEM_WISE)
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)

# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
reg.register_pattern("sqrt", OpPattern.ELEMWISE)
reg.register_schedule("sqrt", _fschedule_broadcast)

# log
reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEM_WISE)
reg.register_pattern("log", OpPattern.ELEMWISE)
reg.register_schedule("log", _fschedule_broadcast)

# tanh
reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEM_WISE)
reg.register_pattern("tanh", OpPattern.ELEMWISE)
reg.register_schedule("tanh", _fschedule_broadcast)

# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEM_WISE)
reg.register_pattern("negative", OpPattern.ELEMWISE)
reg.register_schedule("negative", _fschedule_broadcast)

# sigmoid
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)

# add_scalar
reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)

# sub_calar
reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)

# rsub_scalar
reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)

# mul_scalar
reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)

# div_scalar
reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)

# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)

# pow_scalar
reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)

# rpow_scalar
reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)

# elemwise_add
Expand Down
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info):
oshape = out_info[0].shape
x = inputs[0]
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
reg.register_pattern("reshape", OpPattern.COMPLEX)
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_broadcast)
31 changes: 24 additions & 7 deletions nnvm/src/compiler/graph_fuse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
ref_count[e.node_id] += 2;
}
// Pattern for the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
// Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
// Master node id of fusion segment.
Expand All @@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
if (inode.source->is_variable()) {
fuse_vec[nid] = FuseRule::kRealize; continue;
}
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);
TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque);

if (pt <= kBroadcast) {
// Try to check if we can fuse to the master.
int chosen_master = -1;
bool ewise = inode.source->num_outputs() == 1;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt != kElemWise) ewise = false;
if (ipt <= kBroadcast) {
if (ipt <= kInjective) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else if (ipt == kComplex && chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
} else if (ipt == kOutEWiseFusable &&
chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
Expand All @@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
}
master_vec[nid] = chosen_master;
if (chosen_master != -1) {
pt = kComplex;
pt = kOutEWiseFusable;
} else {
pt = ewise ? kElemWise : kBroadcast;
}
} else if (pt == kInjective || pt == kCommReduce) {
// fuse to the comm reduce or injective
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt <= kInjective) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
fuse_vec[e.node_id] = FuseRule::kRealize;
}
}
}
if (pt == kCommReduce) {
master_vec[nid] = nid;
}
} else {
// realize
master_vec[nid] = nid;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
Expand All @@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
}
}


// point to the group root id of each node
std::vector<int> group_vec(idx.num_nodes(), -1);
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
Expand Down
2 changes: 1 addition & 1 deletion nnvm/src/compiler/layout_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {

// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
Expand Down
Loading

0 comments on commit 2b3d2e2

Please sign in to comment.