Fix LayoutRewriter #10118

Merged · 5 commits · Feb 3, 2022
12 changes: 0 additions & 12 deletions src/relay/op/nn/nn.cc
@@ -210,10 +210,6 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// Respect input layout, if explicitly specified (for example, "NW").
if (new_in_layouts.size() > 0 && new_in_layouts[0].defined()) {
return InferCorrectLayoutOutput({new_in_layouts[0], "NC"}, {"NC"}, attrs);
}
return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
}

@@ -283,14 +279,6 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
const Array<tvm::relay::Type>& old_in_types) {
auto params = attrs.as<DensePackAttrs>();
ICHECK(params);
// Respect input layout, if explicitly specified (for example, "NW").
// However, a packed layout such as "NC8c" is not supported by dense_pack op. For such cases,
// we insert a layout transform "NC8c" -> "NC".
// We do not expect to get a packed layout like "NW8w", which is not compatitble with "NC",
// since packing is always done on the "C" axis.
if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && new_in_layouts[0].ndim() == 2) {
return InferCorrectLayoutOutput({new_in_layouts[0], params->weight_layout}, {"NC"}, attrs);
}
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
}
Member:

Do we still need DenseInferCorrectLayout and DensePackInferCorrectLayout now? I added these functions to work around alter-layout issues, but that might not be necessary anymore. Can you try removing them and see what happens?

Contributor Author:

I skimmed through your change in 66ac470, and I think it is probably still needed, mainly because when FTVMAlterOpLayout is defined but FInferCorrectLayout is not, the current code logic in

// If there is no FInferCorrectLayout for the type, then we just assume the layout is correct.
static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
if (Op::HasAttrMap("FTVMAlterOpLayout")) {
static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
if (ref_call->op.as<OpNode>()) {
Op op = Downcast<Op>(ref_call->op);
if (falter_layout.count(op) && !finfer_layout.count(op)) {
return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
}
}
}

will assume this op accepts any layout. In the case of Dense, it only accepts a 2D data tensor, and when the producer of the data tensor changes its layout, we need an additional layout transform to convert it back, which is not handled in L310-L320.
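To make this concrete, here is a minimal sketch of the graph shape being described (the shapes, variable names, and the hand-written layout_transform are hypothetical, not taken from this PR): if the data feeding nn.dense has been rewritten into a packed layout such as "NC8c", it must be converted back to "NC" before the dense op, because dense only accepts a 2D data tensor.

```python
import numpy as np
import tvm
from tvm import relay

# Data as produced in a packed "NC8c" layout: (N, C_outer, 8) = (1, 2, 8), i.e. C = 16.
x = relay.var("x", shape=(1, 2, 8), dtype="float32")
# Convert back to the plain 2D "NC" layout that nn.dense expects: (1, 16).
x_nc = relay.layout_transform(x, "NC8c", "NC")
w = relay.const(np.random.uniform(size=(4, 16)).astype("float32"))
y = relay.nn.dense(x_nc, w)
mod = tvm.IRModule.from_expr(relay.Function([x], y))
print(mod)
```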


37 changes: 27 additions & 10 deletions src/relay/transforms/transform_layout.h
@@ -320,7 +320,10 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
}

// old_in, new_in = state[inputs]
Array<Layout> old_in, old_out, new_in, new_out, new_in2;
// naming rule:
// old_in, new_in: the input layouts given by downstream node.
// old_in2, new_in2: the input layouts inferred by the current node.
Array<Layout> old_in, old_in2, old_out, new_in, new_out, new_in2;
for (auto inp : inputs) {
old_in.push_back(inp->old_layout);
new_in.push_back(inp->new_layout);
@@ -336,17 +339,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
InferCorrectLayoutOutput infer_out;
std::tie(infer_out, success) =
InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, types);
old_in = infer_out->input_layouts;
old_in2 = infer_out->input_layouts;
old_out = infer_out->output_layouts;
if (!success) {
return Expr(nullptr);
}
ICHECK_EQ(old_in.size(), new_in.size());
ICHECK_EQ(old_in2.size(), new_in.size());

// if new_in == 'undef': new_in = old_in
for (size_t i = 0; i < new_in.size(); ++i) {
if (!new_in[i].defined()) {
new_in.Set(i, old_in[i]);
Array<Layout> new_in_tmp = new_in; // for backward compatibility of InferCorrectLayouts
// if new_in_tmp == 'undef': new_in_tmp = old_in2
for (size_t i = 0; i < new_in_tmp.size(); ++i) {
if (!new_in_tmp[i].defined()) {
new_in_tmp.Set(i, old_in2[i]);
}
}

@@ -356,7 +360,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
// new_in2, new_out = op.infer(new_in)
if (new_call->op->IsInstance<OpNode>()) {
success = false;
std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types);
std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in_tmp, old_in2, types);
new_in2 = infer_out->input_layouts;
new_out = infer_out->output_layouts;
if (!success) {
@@ -371,6 +375,17 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
ICHECK_EQ(new_in.size(), new_in2.size())
<< "The number of input nodes should keep the same during alter_op_layout";

auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_in, const Layout& old_in2,
const Layout& new_in, const Layout& new_in2) {
if (old_in2.Equals(old_in)) { // the two transforms can be fused to one
arg_item = memorizer.Transform(arg_item, new_in, new_in2);
} else {
if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
Member (@masahi, Feb 1, 2022):

In this code path, old_in != old_in2. So after the transform at L383, how can we apply another transform with old_in2 as the src layout?

Contributor Author:

Good catch. I thought that old_in and old_in2 should be isomorphic, i.e., have the same structure (rank, sub-coordinate factors, etc.) and differ only in how each axis is named (e.g., NW vs. NC), given that they describe the same tensor's layout. In that case, the transform can be applied. A concrete example: with new_in=NW8w, old_in=NW, old_in2=NC, new_in2=NC16c, we apply NW8w->NW and NC->NC16c, which is valid since layout_transform works as long as the layout structure matches the tensor shape. The net outcome is equivalent to a single transform NC8c->NC16c.
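As an illustration of that isomorphism argument (a sketch with made-up shapes, not code from this PR), the two chained transforms below are structurally the same as a single NC8c -> NC16c transform, because "NW" and "NC" differ only in axis naming:

```python
import tvm
from tvm import relay

# Tensor currently laid out as NW8w: (N, W_outer, 8) = (1, 4, 8), i.e. W = 32.
x = relay.var("x", shape=(1, 4, 8), dtype="float32")
# First transform: NW8w -> NW, unpacking to the 2D tensor (1, 32).
y = relay.layout_transform(x, "NW8w", "NW")
# Second transform: NC -> NC16c, repacking the same 2D tensor to (1, 2, 16).
y = relay.layout_transform(y, "NC", "NC16c")
mod = tvm.IRModule.from_expr(relay.Function([x], y))
print(mod)
```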

However, I just hit a bizarre case in BroadcastInferLayout that does not give isomorphic layouts:

} else {
// Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
// while transforming the layout, we expand dims to make C go to NHWC, and then use the
// modified layout of the first operator to call the layout transform. E.g.
// a in NCHWC16c
// b in C
// b = expand_dims(b) from C -> NHWC
// b = layout_transform(b) from NHWC -> NCHW16c
// add(a, b)
layouts.Set(small_idx, ret);
}

This code path may expand the tensor's layout and assign old_in2 something with a larger rank. For example, if the op is a+b, and originally a's layout is NCHW and b's is W, then its consumer (the broadcast node) will infer old_in2=NCHW for b. Now W and NCHW are not really isomorphic... I'm working on a fix, but it does seem pretty weird to me that a tensor of rank 1 is inferred to have a layout of rank 4.
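For reference, here is a small sketch (hypothetical shapes, not from this PR) of the rank-expanding pattern quoted above from BroadcastInferLayout: the rank-1 operand in layout C is first expanded to NHWC and then transformed into the packed layout of the other operand.

```python
import tvm
from tvm import relay

a = relay.var("a", shape=(1, 2, 4, 4, 16), dtype="float32")  # NCHW16c, so C = 32
b = relay.var("b", shape=(32,), dtype="float32")             # rank-1 operand in layout C
# expand_dims: C -> (1, 1, 1, 32), interpreted as NHWC
b_nhwc = relay.expand_dims(b, axis=0, num_newaxis=3)
# layout_transform: NHWC -> NCHW16c, giving (1, 2, 1, 1, 16)
b_packed = relay.layout_transform(b_nhwc, "NHWC", "NCHW16c")
out = relay.add(a, b_packed)
mod = tvm.IRModule.from_expr(relay.Function([a, b], out))
print(mod)
```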

Member:

FYI, many of us are aware of the messy situation our layout-conversion pass is in. I believe an entire rewrite is desirable at some point. I'd love to hear your thoughts on this in the discuss forum, etc.

Contributor Author:

Thanks for including me in your discussion. I agree on an entire rewrite. Ping me on the discuss forum when you decide to do it someday :-), and I'd love to participate.

}
return arg_item;
};

// if (new_in != new_in2): insert transform (new_in -> new_in2)
Array<Expr> transformed_args;
size_t pt = 0;
@@ -380,12 +395,14 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
Array<Expr> transformed_tuple_arg;
transformed_tuple_arg.reserve(tuple_arg->fields.size());
for (auto arg_item : tuple_arg->fields) {
transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
transformed_tuple_arg.push_back(
transform_layout(arg_item, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
pt++;
}
transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
} else {
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
transformed_args.push_back(
transform_layout(arg, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
pt++;
}
}
13 changes: 13 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -1471,5 +1471,18 @@ def test_conv2d_reduce_channels():
relay.build(mod, params=params, target="llvm")


def test_axis_semantic_change():
x = relay.var("x", shape=(1, 1, 24, 48))
w1 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
w2 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
y = relay.nn.conv2d(x, w1, kernel_size=(1, 1), padding=(0, 0), channels=1)
y = relay.transpose(y, (0, 1, 3, 2))
z = relay.nn.conv2d(y, w2, kernel_size=(1, 1), padding=(0, 0), channels=1)
func = relay.Function([x], z)
mod = tvm.IRModule.from_expr(func)
with tvm.transform.PassContext(opt_level=3):
relay.build(mod, target="llvm")


if __name__ == "__main__":
pytest.main([__file__])