Fix LayoutRewriter #10118
src/relay/transforms/transform_layout.h

@@ -320,7 +320,10 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   }

   // old_in, new_in = state[inputs]
-  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
+  // naming rule:
+  // old_in, new_in: the input layouts given by downstream node.
+  // old_in2, new_in2: the input layouts inferred by the current node.
+  Array<Layout> old_in, old_in2, old_out, new_in, new_out, new_in2;
   for (auto inp : inputs) {
     old_in.push_back(inp->old_layout);
     new_in.push_back(inp->new_layout);
@@ -336,17 +339,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   InferCorrectLayoutOutput infer_out;
   std::tie(infer_out, success) =
       InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, types);
-  old_in = infer_out->input_layouts;
+  old_in2 = infer_out->input_layouts;
   old_out = infer_out->output_layouts;
   if (!success) {
     return Expr(nullptr);
   }
-  ICHECK_EQ(old_in.size(), new_in.size());
+  ICHECK_EQ(old_in2.size(), new_in.size());

-  // if new_in == 'undef':  new_in = old_in
-  for (size_t i = 0; i < new_in.size(); ++i) {
-    if (!new_in[i].defined()) {
-      new_in.Set(i, old_in[i]);
+  Array<Layout> new_in_tmp = new_in;  // for backward compatibility of InferCorrectLayouts
+  // if new_in_tmp == 'undef':  new_in_tmp = old_in2
+  for (size_t i = 0; i < new_in_tmp.size(); ++i) {
+    if (!new_in_tmp[i].defined()) {
+      new_in_tmp.Set(i, old_in2[i]);
     }
   }

@@ -356,7 +360,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
     success = false;
-    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types);
+    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in_tmp, old_in2, types);
     new_in2 = infer_out->input_layouts;
     new_out = infer_out->output_layouts;
     if (!success) {
@@ -371,6 +375,17 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   ICHECK_EQ(new_in.size(), new_in2.size())
       << "The number of input nodes should keep the same during alter_op_layout";

+  auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_in, const Layout& old_in2,
+                                       const Layout& new_in, const Layout& new_in2) {
+    if (old_in2.Equals(old_in)) {  // the two transforms can be fused to one
+      arg_item = memorizer.Transform(arg_item, new_in, new_in2);
+    } else {
+      if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
+      arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
Review thread on the line above:

Review comment: In this code path, …

Reply: Good catch. I thought that … However, I just hit a bizarre case where BroadcastInferLayout does not give isomorphic layouts (tvm/src/relay/transforms/infer_layout_utils.h, lines 224 to 234 at 6a274af). That code path may expand the tensor's layout and assign old_in2 to something with larger rank. For example, if the op is a+b, and originally a's layout is NCHW and b's is W, then its consumer (the broadcast node) will infer old_in2=NCHW for b. Now W and NCHW are not really isomorphic... I'm working on a fix, but it does sound pretty weird to me when you say a tensor with rank 1 is inferred to have a layout with rank 4.

Review comment: FYI, many of us are aware of the messy situation our layout convert pass is in. I believe an entire rewrite is desired at some point. I'd love to have your thoughts on this in the discuss forum.

Reply: Thanks for having me in your discussion. I agree on an entire rewrite. Ping me in the discuss forum when you decide to do it some day :-), and I'd love to participate.
+    }
+    return arg_item;
+  };
+
   // if (new_in != new_in2): insert transform (new_in -> new_in2)
   Array<Expr> transformed_args;
   size_t pt = 0;
@@ -380,12 +395,14 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
       Array<Expr> transformed_tuple_arg;
       transformed_tuple_arg.reserve(tuple_arg->fields.size());
       for (auto arg_item : tuple_arg->fields) {
-        transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
+        transformed_tuple_arg.push_back(
+            transform_layout(arg_item, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
         pt++;
       }
       transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
     } else {
-      transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
+      transformed_args.push_back(
+          transform_layout(arg, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
       pt++;
     }
   }
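For readers who want to reproduce the rank-mismatch situation described in the review thread above, here is a minimal Python sketch, not taken from this PR or its tests: a conv2d producing an NCHW tensor is broadcast-added to a rank-1 tensor, and ConvertLayout then forces LayoutRewriter to reconcile the layout the broadcast op infers for the rank-1 input with the layout that input actually has. The shapes, the variable names, and the choice of converting to NHWC are illustrative assumptions.

import tvm
from tvm import relay

# Broadcast add between a rank-4 NCHW tensor and a rank-1 tensor, as in the
# a + b example from the review discussion (shapes are made up).
data = relay.var("data", shape=(1, 8, 32, 32))    # NCHW
weight = relay.var("weight", shape=(8, 8, 3, 3))  # OIHW
bias = relay.var("bias", shape=(32,))             # rank-1, broadcasts over W
conv = relay.nn.conv2d(data, weight, padding=(1, 1),
                       data_layout="NCHW", kernel_layout="OIHW")
out = relay.add(conv, bias)
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

# Converting conv2d to NHWC changes the layout of the add's first input, so
# LayoutRewriter has to decide what to do with the rank-1 second input.
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]}),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)

Printing the converted module shows which layout_transform ops the pass decided to insert around the broadcast add.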
Review comment: Do we still need DenseInferCorrectLayout and DensePackInferCorrectLayout now? I added these functions to work around alter-layout issues, but that might not be necessary anymore. Can you try removing them and see what happens?

Reply: I skimmed through your change in 66ac470, and I think it is probably still needed, mainly because when FTVMAlterOpLayout is defined but FInferCorrectLayout is not, the current code logic in tvm/src/relay/transforms/transform_layout.h (lines 310 to 320 in 6a274af) will assume this op accepts any layout. In the case of Dense, it only accepts a 2D data tensor, and when the producer of the data tensor changes its layout, we need an additional layout transform to convert it back, which is not handled in L310-L320.
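For background on the distinction the reply above draws between FTVMAlterOpLayout and FInferCorrectLayout: an alter-layout hook can be attached to an op from Python through the generic op-attribute registration, while layout inference is a separate attribute (registered in C++ for most ops, e.g. DenseInferCorrectLayout). The sketch below is illustrative only and is not part of this PR; the priority level, the requested layout, and the callback body are assumptions.

import tvm
from tvm import relay

# Attach an FTVMAlterOpLayout hook to nn.conv2d at a high priority level so it
# overrides the built-in handler for this session. If an op had only this hook
# and no FInferCorrectLayout entry, the fallback in transform_layout.h
# referenced above would treat it as accepting any layout.
@tvm.ir.register_op_attr("nn.conv2d", "FTVMAlterOpLayout", level=101)
def alter_conv2d(attrs, inputs, tinfos, out_type):
    data, weight = inputs
    new_attrs = dict(attrs)
    new_attrs["data_layout"] = "NCHW16c"  # request a blocked data layout
    return relay.nn.conv2d(data, weight, **new_attrs)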