Skip to content

Commit

Permalink
[Relay][AlterOpLayout] NHWC to NCHWc pad operator. (#4103)
Browse files Browse the repository at this point in the history
* [Relay][AlterOpLayout] NHWC to NCHWc pad operator.

* Fixing culprit.

* Flaky test 1.

* Flaky test 2.
  • Loading branch information
anijain2305 authored and yzhliu committed Oct 15, 2019
1 parent bc54310 commit 4ee534b
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 1 deletion.
77 changes: 77 additions & 0 deletions src/relay/op/nn/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,82 @@ namespace relay {
// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);

Array<Array<Layout> > PadInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
// NOTE: Discard "const" qualifier here.
PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());

Layout ret;
// If new_in_layouts are defined, this code tries to modify the layout.
bool is_layout_modified = new_in_layouts.defined();
if (new_in_layouts.defined()) {
// Create a map of axis to param_width. For the new layout, a new param_width is generated using
// the map. The new layout is rejected, if the padding is happening along the axis which was
// split.

// 1) Create a map from axis to param_width using old layout.
std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
int index_counter = 0;
CHECK_EQ(new_in_layouts.size(), 1);
CHECK_EQ(old_in_layouts.size(), 1);
for (auto iter_var : old_in_layouts[0]->axes) {
const auto& old_layout_axis = LayoutAxis::Get(iter_var);
axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
index_counter++;
}

// 2) Create new pad width by walking over the new layout and using the map.
tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& new_layout_axis = LayoutAxis::Get(iter_var);
auto axis_name = new_layout_axis.name();
if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
// This is primal axis. So, directly use the original pad_width.
new_pad_width.push_back(axis_pad_width.at(axis_name));
} else {
// This is the axis that got split. So, check that pad_width was [0, 0] originally.
const auto& dual_axis = new_layout_axis.ToPrimal();
auto dual_axis_name = dual_axis.name();
CHECK(axis_pad_width.count(dual_axis_name))
<< "Missing axis " << dual_axis << " in " << old_in_layouts[0].name();
new_pad_width.push_back(axis_pad_width.at(dual_axis_name));

// If any pad_width element is not zero, do not change the layout.
for (auto width : axis_pad_width.at(dual_axis_name)) {
if (auto* width_imm = width.as<IntImm>()) {
if (width_imm->value != 0) {
is_layout_modified = false;
}
} else {
is_layout_modified = false;
}
}
}
}

// If the above conditions satisfied, we can set the newly created pad_width and use the new
// layout.
if (is_layout_modified) {
ret = new_in_layouts[0];
params->pad_width = new_pad_width;
}
}

if (!is_layout_modified) {
if (old_in_layouts.defined()) {
CHECK_EQ(old_in_layouts.size(), 1);
ret = old_in_layouts[0];
} else {
ret = Layout::Undef();
}
}

return Array<Array<Layout> >{{ret}, {ret}};
}

bool PadRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
Expand Down Expand Up @@ -133,6 +209,7 @@ RELAY_REGISTER_OP("nn.pad")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);

Expand Down
4 changes: 3 additions & 1 deletion tests/python/frontend/coreml/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
model = cm.models.MLModel(model_file)
x = model_zoo.get_cat_image()
shape_dict = {input_name : x.shape}
mod, params = relay.frontend.from_coreml(model, shape_dict)
# Some Relay passes change operators on the fly. Ensuring that we generate
# new graph for each target.
for target, ctx in ctx_list():
mod, params = relay.frontend.from_coreml(model, shape_dict)
tvm_output = get_tvm_output(mod["main"], x, params, target, ctx)
print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))

Expand Down
115 changes: 115 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,120 @@ def expected():
assert(analysis.alpha_equal(a, b))


def test_alter_layout_pad():
""" Check NCHW, NHWC and corner case for pad layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=112)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)

# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1))
ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
y = relay.Function(analysis.free_vars(ret), ret)
return y

def expected_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
y = relay.Function(analysis.free_vars(ret), ret)
return y

a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())

b = expected_nchw()
b = run_opt_pass(b, transform.InferType())

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

# Check NHWC conversion.
def before_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC')
ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)))
y = relay.Function(analysis.free_vars(ret), ret)
return y

def expected_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NHWC", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
y = relay.Function(analysis.free_vars(ret), ret)
return y

a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())

b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

# Check that conversion does not happen when padding along split axis..
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1))
ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
y = relay.Function(analysis.free_vars(ret), ret)
return y

def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.layout_transform(y, "NCHW16c", "NCHW")
ret = relay.nn.pad(ret, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
y = relay.Function(analysis.free_vars(ret), ret)
return y

a = before()
a = run_opt_pass(a, transform.AlterOpLayout())

b = expected()
b = run_opt_pass(b, transform.InferType())

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_pool():
""" Check NCHW, NHWC pool layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
Expand Down Expand Up @@ -815,5 +929,6 @@ def expected_nhwc():
test_alter_layout_strided_slice()
test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
test_alter_layout_pad()
test_alter_layout_pool()
test_alter_layout_sum()

0 comments on commit 4ee534b

Please sign in to comment.