
Add lift_if_then_else pass #3865

Merged: 8 commits into apache:master on Oct 18, 2019

Conversation

kevinthesun (Contributor):

Since nvcc cannot do loop-invariant optimization in some cases (https://discuss.tvm.ai/t/expr-simplifier-for-tvm-var/3669/10), we need an extra pass to detect loop-invariant if statements and lift them out of the loop. This pass is not used right now; later it will be useful for symbolic shape compilation on CUDA.
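To illustrate the intent, here is a hand-written sketch of the kind of generated code before and after such a transformation. All names are made up for illustration; this is not the pass's actual input or output IR.

// Hand-written illustration only: the condition is loop invariant, but nvcc may
// not move it out of the loop on its own.
void before(const int* a, const int* b, int* out, int n, bool cond) {
  for (int i = 0; i < n; ++i) {
    if (cond) {              // cond does not depend on i
      out[i] = a[i] + b[i];
    } else {
      out[i] = a[i];
    }
  }
}

// After hoisting: the branch is evaluated once, and each arm gets a branch-free loop.
void after(const int* a, const int* b, int* out, int n, bool cond) {
  if (cond) {
    for (int i = 0; i < n; ++i) out[i] = a[i] + b[i];
  } else {
    for (int i = 0; i < n; ++i) out[i] = a[i];
  }
}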

@tqchen @wweic

ajtulloch (Contributor):

@hlu1 could this help with Concat codegen on x86?

yzhliu (Member) left a comment:

I will revisit GenerateInternalData later.
I suggest checking p != nullptr each time we do p = s.as<xxx>(), so that at least we know where the problem is when something goes wrong, instead of hitting a silent segfault.

bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
  std::vector<size_t> if_hash_list;

  PostOrderVisit(for_stmt.as<For>()->body, [&](const NodeRef& node) {
Review comment (Member):

Suggest checking for_stmt.as<For>() first to avoid a segfault when it is misused. Or how about taking a For* as the argument?


PostOrderVisit(for_stmt.as<For>()->body, [&](const NodeRef& node) {
  if (node.as<IfThenElse>()) {
    if_hash_list.push_back(node.hash());
Review comment (Member):

Safer to use node.get()? I think you need strict (pointer) equality instead of hash equality.
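A hedged sketch combining the two suggestions above (guard the as<For>() cast and compare node pointers rather than structural hashes). The helper name CollectIfNodes and the CHECK message are illustrative, not taken from the merged pass:

// Illustrative helper, not from the merged pass: guard the cast, then record
// node identity with raw pointers instead of hashes so comparisons are strict.
std::vector<const Node*> CollectIfNodes(const Stmt& for_stmt) {
  const For* for_node = for_stmt.as<For>();
  CHECK(for_node != nullptr) << "Expected a For statement";
  std::vector<const Node*> if_node_list;
  PostOrderVisit(for_node->body, [&](const NodeRef& node) {
    if (node.as<IfThenElse>()) {
      if_node_list.push_back(node.get());
    }
  });
  return if_node_list;
}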


PostOrderVisit(parent_for_stmt.as<For>()->body, [&](const NodeRef& node) {
  if (node.as<For>()) {
    for_hash_list.push_back(node.hash());
Review comment (Member):

Same issue here with as<For>() and node.hash().

}

// Generate internal data structures for lifter.
void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) {
Review comment (Member):

Can we separate the function into smaller units and make the function names more meaningful?
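For what it is worth, the merged version does split the work into smaller units. The declarations below are copied from snippets later in this thread; the one-line descriptions are my informal reading of them, not official documentation.

void SelectCandidates(const Stmt& stmt);                          // walk the IR and collect For loops with candidate if statements
size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt);  // bookkeeping for loops that have already been rewritten
Stmt HoistIf(const Stmt& if_stmt);                                // hoist a single if statement out of its invariant loops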

anijain2305 (Contributor):

Should we use hoist instead of lift?

kevinthesun (Contributor, author):

> Should we use hoist instead of lift?

Renamed.

anijain2305 (Contributor) left a comment:

Thanks for the contribution. Please ignore my comments if they do not make sense; I have not worked thoroughly with the TVM IR, so many of my comments might not be very useful.

I have some other top-level comments. Is it possible to organize the functions as a two-step process: DetectLoopInvariantStmt and Hoist? We can restrict ourselves to if statements for now, but in the future we can hoist other statements as well if need be. A rough sketch of this split follows below.
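A hedged sketch of what that two-step split could look like. The names DetectLoopInvariantStmt and Hoist come from the comment above; the class name and signatures are hypothetical and not the merged design.

// Hypothetical two-step organization following the suggestion above.
class StmtHoister {
 public:
  // Step 1: collect if statements inside for_stmt whose condition does not
  // reference the loop variable (i.e. is loop invariant).
  std::vector<Stmt> DetectLoopInvariantStmt(const Stmt& for_stmt);

  // Step 2: rebuild the loop with if_stmt moved above it, duplicating the
  // loop body into the then/else branches.
  Stmt Hoist(const Stmt& for_stmt, const Stmt& if_stmt);
};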

}

Stmt HoistIfThenElse(Stmt stmt) {
  return IfThenElseLHoist().VisitAndMutate(stmt);
Review comment (Contributor):

What does 'L' stand for in IfThenElseLHoist?

Reply (Contributor, author):

This is a typo.

void IfThenElseLHoist::SelectCandidates(const Stmt& stmt) {
  PostOrderVisit(stmt, [&](const NodeRef& node){
    const For* for_node = node.as<For>();
    if (for_node) {
Review comment (Contributor):

Nit: how about returning early if it is not a for loop? It might improve the readability of the code.

if (for_node == nullptr) return;

std::queue<Stmt> tracker;
tracker.push(for_node->body);
Stmt for_stmt = Downcast<Stmt, NodeRef>(node);
for2if_map_.insert({for_stmt.get(), std::vector<Stmt>()});
Review comment (Contributor):

Just trying to understand: it seems for_stmt.get() gives you the underlying Node pointer, so why not directly use node as the key here?

Reply (Contributor, author):

The key type of for2if_map_ is const Node*. Did you mean the value type?
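For context, a sketch of the shape these bookkeeping containers appear to have, inferred from the snippets in this thread; the actual typedef in the pass may differ.

// Inferred from the surrounding snippets, not copied from the merged source.
using HoistMap = std::unordered_map<const Node*, std::vector<Stmt>>;
HoistMap for2if_map_;  // For node -> IfThenElse statements found inside its body
HoistMap if2for_map_;  // IfThenElse node -> For loops its condition is invariant to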

tracker.pop();
if (head->is_type<For>()) {
  for (const auto& if_stmt : for2if_map_.at(head.get())) {
    for2if_map_[for_stmt.get()].push_back(if_stmt);
Review comment (Contributor):

Trying to understand: so here, if there is a child for loop whose if statements have already been collected, this portion copies those if statements to the parent for loop's map entry as well. Correct?

kevinthesun (Contributor, author), Sep 4, 2019:

Yes. That's why nodes are visited in post order.

size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt);
Stmt HoistIf(const Stmt& if_stmt);

HoistMap if2for_map_;
Review comment (Contributor):

A brief comment for each member would be really helpful :)

continue;
}
}
ordered_for_list_.emplace_back(Downcast<Stmt, NodeRef>(node));
Review comment (Contributor):

Trying to understand: most of the other bookkeeping is keyed on Node*, but this one stores Stmt. Any reason we want to do that?

Reply (Contributor, author):

When generating if2for_map_, we need to push a Stmt object into each value vector. Here we directly store the Stmt object so that we don't need to use For::make to build a Stmt again.

top_for_var_map_.insert({for_node->loop_var.get(), if_list});
for (const Stmt& if_stmt : if_list) {
  const Node* if_node = if_stmt.get();
  if (!if2for_map_.count(if_node)) {
Review comment (Contributor):

No need for the count check. operator[] will insert the mapped value when the key is missing.
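A self-contained illustration of the operator[] point, using only standard-library types (the key and value here are stand-ins, not TVM types):

#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<const void*, std::vector<int>> m;
  int key = 0;
  // operator[] value-initializes the mapped vector when the key is absent,
  // so no count()/insert() guard is needed before push_back.
  m[&key].push_back(42);
  return 0;
}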


// Map of all For nodes to all child IfThenElse nodes.
HoistMap for2if_map_;
// Map of all IfThenElse nodes to all For nodes which are loop invariant.
Review comment (Contributor):

Is there any order in the vector? From outermost to innermost for loops?

Reply (Contributor, author):

No. We keep the order in ordered_for_list_.

// With this function we only need to visit and mutate top level For node
// in the main VisitAndMutate function.
Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
  std::vector<const Node*> for_node_list;
Review comment (Contributor):

Why not just use a Node*?

Reply (Contributor, author):

Since noderef.get() returns const Node*, for_node_list aligns with the type.

tests/python/unittest/test_pass_hoist_if.py (resolved comment thread)
// With this function we only need to visit and mutate top level For node
// in the main VisitAndMutate function.
Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
  std::vector<const Node*> for_node_list;
Review comment (Contributor):

I don't think you need a vector; just keeping a Node* should work, right?

tqchen (Member) commented Oct 10, 2019:

Ping @kevinthesun @yzhliu, please follow up on this PR.

kevinthesun (Contributor, author) commented Oct 11, 2019:

@yzhliu @wweic @anijain2305 Just rebased with master. Can you take another look?

yzhliu (Member) left a comment:

@wweic @anijain2305 Could you double check?

wweic (Contributor) commented Oct 16, 2019:

I think for_node_list in function update_for can just be a Node*; no need to use a vector. Otherwise LGTM.

anijain2305 (Contributor) left a comment:

LGTM.

It is a little uncomfortable to see that we have to do so much bookkeeping to perform a transformation. Analysis doesn't seem that easy with the TVM IR. Maybe in the future, we can use this pass as a testbed when we improve the TVM IR to simplify analysis.

zhiics (Member) left a comment:

LGTM

zhiics merged commit 687d4a8 into apache:master on Oct 18, 2019.
zhiics (Member) commented Oct 18, 2019:

Thanks everyone. This is now merged.

kevinthesun deleted the LiftIf branch on October 18, 2019 at 22:25.
kevinthesun added a commit to kevinthesun/tvm that referenced this pull request Oct 30, 2019
* Add LiftIfThenElse pass

* Add more comments

* Rename and refactor

* Add description for internal data structure

* Rename a test

* Minor change

* Address comments

* Improve update_for
kevinthesun added a commit to neo-ai/tvm that referenced this pull request Oct 31, 2019
* [relay][vm] Separate VM runtime with executable (apache#4100)

* [relay][vm] Separate VM runtime with executable

* Address comments

* move ctx back to vm

* make only vm related fields and methods protected

* integrate seriliaztion/deserialization to executable

* create stream

* [Relay][Frontend][TF] Add tensor array ops (apache#3798)

* [Relay][Frontend][TF] Add tensor array ops

* rename

* delete test

* Move utility function

* Refactor

* fix tensor array ops

* fix test

* fix rebase

* Fix serializer bug

* Improve tf convert name lookup to use prelude api

* Fix lint

* Fix test

* Fix typo (apache#4144)

* [CI] Pin NNPack pthreadtools version (apache#4152)

* [QNN][TFLite] Parsing QNN Add op. Adding MobilenetV2. (apache#4142)

* Add lift_if_then_else pass (apache#3865)

* Add LiftIfThenElse pass

* Add more comments

* Rename and refactor

* Add description for internal data structure

* Rename a test

* Minor change

* Address comments

* Improve update_for

* [CI] Update cpu docker (apache#4153)

* [Refactor] Rename Datatype to ADT (apache#4156)

We think it will reduce the confusion with the meaning.

https://discuss.tvm.ai/t/discuss-consider-rename-vm-datatype/4339

* [Runtime] Enable option to use OpenMP thread pool (apache#4089)

* [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. (apache#4161)

* [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol.

This PR removes the original node system, and make node as a subclass of Object.
This is a major refactor towards a better unified runtime object system.

List of changes in the refactor:

- We now hide data_ field, use Downcast explicitly to get a sub-class object.
- Removed the node system FFI in python.
- Removed the node C API, instead use PackedFunc for list and get attrs.
- Change relay::Op::set_attr_type_key(attr_key_name) to relay::Op::set_attr_type<AttrType>().
  - This change was necessary because of the new Object registration mechanism.
  - Subsequent changes to the op registrations
  - The change revealed a few previous problems that is now fixed.
- Patched up a few missing node type registration.
  - Now we will raise an error if we register object that is not registered.
- The original node.h and container.h are kept in the same location.
- Calling convention: kObjectHandle now equals the old kNodeHandle, kNodeHandle is removed.
- IRFunctor now dispatches on ObjectRef.
- Update to the new type checking API: is_type, derived_from are replaced by IsInstance.
- Removed .hash member function, instead use C++ convention hasher functors.

* Address review comments

* [CI] Move golang tests to the end (apache#4164)

* Add support for quantized multiply to Relay (apache#4141)

This patch adds multiply operator for quantized tensors.
The details of the quantized multiplication are outlined
in the code.

This builds on pull request 3927 and includes the changes
Animesh mentions in the comments on that request.

Change-Id: I555715b53d0266a91d5c03dc3dfe8fc31e7ce4e1

* Fix missspelling (apache#4166)

FIX "After connecting he usb" with "After connecting the usb"

* [Relay][Pass] Count MAC for BatchMatMul (apache#4157)

* count MAC for BatchMatMul

* update doc

* [Relay][QNN] Add unit test for int8 (apache#4159)

* [bugfix][codegen] fix casting bug in llvm codegen

* update example

* retrigger ci

* check llvm version

* [relay][vm] Reuse allocated device memory (apache#4170)

* add missing gradient check to gradient pass (apache#4169)

* merge extract_from_program and extract_from_multiple_progam (apache#4173)

* [TOPI] Added support for Mali Bifrost target (apache#4047)

* [Relay][Frontend][TF] Fix Size operator (apache#4175)

* [Relay][Frontend][TF] Fix Size operator

* Uncomment tests

* [Pass] Remove dead code (apache#4177)

* [rpc] use callback func to do send & recv (apache#4147)

* [rpc] use callback func to do send & recv. don't get fd from sock as it is deprecated in java

* fix java build

* fix min/max macro define in windows

* keep the old rpc setup for py

* add doc for CallbackChannel

* Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend. (apache#4172)

* [DOCS] Add TensorFlow frontend docs (apache#4154)

* Start to update TF frontend docs

* Add rst

* Remove markdown

* Update wording

* Resolve comments

* Revert "[Relay][QNN] Add unit test for int8 (apache#4159)" (apache#4192)

This reverts commit 6f9d028.

* [cmake][ANTLR] Support setting path to ANTLR jar (apache#4176)

* Support setting path to ANTLR jar

* Update comment

* Split adaptive_pool2d_avg into sum and div (apache#4186)

* [Documentation]Fix example code in comment of tvm.build_module.build() (apache#4195)

* Fix example code in comment of tvm.build_module.build()

* Update build_module.py

* [relay] use time_evaluator for measurement (apache#4191)

* Add parser support for SUM tflite operator (apache#4182)

* [Relay] Fix memory leak in the interpreter (apache#4155)

* save

lint

* address reviewer comment

* [TOPI] Tunable Template for Conv2D HWCN on CUDA (apache#4168)

* support conv2d HWCN in AutoTVM and Relay

* fix lint

* fix comments and unit tests

* TensorCore Support using Intrinsic (apache#4136)

* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic

* [NODE][REFACTOR] Refactor reflection system in node. (apache#4189)

* [NODE][REFACTOR] Refactor reflection system in node.

- Removed the old Node, Node is now just an alias of runtime::Object
- Introduce ReflectionVTable, a new columnar dispatcher to support reflection
  - This allows us to remove vtable from most node objects
  - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE,
    they are no longer virtual.
- Consolidated serialization and reflection features into node.

* Explicit type qualification when calling destructor.

* Fix SPIRV, more comments

* hotfix the ci (apache#4199)

* [TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI instructions. (apache#4196)

* [Relay] crossentropy_with_logits and its gradient (apache#4075)

* save

* lint

* [hotfix] missing include headers (apache#4204)

* [Relay][Training] Add checkpoint annotation for checkpointing memory optimization (apache#4146)

* add checkpoint annotation for checkpointing memory optimization

* add alpha-equivalence checkpoint test and fix gradient type issue

* fix build issues

* ignore checkpoint annotation when checking missing gradients

* refactor, fix checkpoint compute for tuple and add tests

* [Relay][Params] Add APIs for storing and retrieving parameters from individual functions. (apache#4194)

* Add support for attaching params

* Fix types

* Fix test

* [Relay][Frontend][ONNX] Add support for op Where (apache#4184)

* Add support for op Where

* Update impl version

* [VTA][Chisel] TSIM VTA Source Refactor (apache#4163)

* app init push

* fix on readme

* change name, add bit serial explanantion

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax

* init commit

* fix empty line

* fix typo

* [RUNTIME] Separate runtime related contrib into runtime/contrib (apache#4207)

* Fix type var docs (apache#4208)

* [Relay] Setting Legalize opt_level to 1. (apache#4198)

* [TOPI] Fix flaky testcase for check round (apache#4211)

* [Relay][Op] Enhance Upsample Operator to support float scales   (apache#4206)

* :add scale2 for upsample

* update unit test for upsampling

* support latest upsample op for multiple frontend

* fix lint

* fix lint

* fix lint

* fix lint

* update scale description and rebase

* [Relay][Quantize] Use fixed point mulplications (apache#4160)

* Update have_int8 condition to run on compute capability 7.x devices (apache#4214)

* Optimizing autotvm task extraction speed (apache#4138)

* Optimize task extraction speed

* correct pylint errors

* Delete unused function

* remove unnecessary argument

* resolve code review comments

* corrent cpp lint errors

* remove one more graph_json return value

* fix test bugs

* [Relay] Add Python type functor and tests (apache#4209)

* Add Python type functor and tests

* Lint roller

* Fix typo in packed_func.h (apache#4219)

* Improve the lowering of Qnn Dense (apache#4213)

* [QNN] Improving Dense lowering.

* - Moving get_shape method to util
- Finalizing the test cases and the code structure for optimized dense computation.

* - Fixing cpplint.

* - Addressing review comments.

* - Renaming the variables correctly.

* - Renaming the variables correctly.

* [ARITH] Fix the rule y < x && x <= y (apache#4220)

* [PYTHON] Add __init__ to the generated grammar so that it can be installed properly (apache#4223)

* [Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (apache#4197)

* Added slice v10

* Added constantofshape operation and small refactor.

* Finished one_hot implementation.

* Reshape working across all bert layers.

* Fixed constantofshape and removed code duplication.

* onnx model fully ingested.

* Working on improving onnx tests.

* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.

* Add arbitrary output nodes to onnx frontend.

* Added v6 tiling for bert squad 8 support.

* Small syntax fixes

* Reduced code duplication in split opset versions.

* Added batch matmul test

* Added unstack split testing.

* Adde onehot test, needs a little cleanup probably.

* Replaced deprecated constant fill with constantofshape and updated tests accordingly.

* Added tests for new opset version of slice and tile.

* lint clean up

* Lint fixes

* Changed onnx dependency

* Went back to caffe2 runtime for CI integration.

* Rebase and small typo/syntax changes.

* Added hard casting of onehot attributes to int.

* [Relay][Topi][TensorFlow][ONNX][Lang] Add support for Any op (apache#4205)

* Add support for Any op

* Support ONNX frontend

* Add doc

* Add to relay docs

* Dummy change to retrigger CI

*  Update dmlc_tvm_commit_id.txt

* Merge from upstream
roastduck (Contributor):

Hi everyone, I'm wondering if we can add this pass to the default lowering procedure. Hoisting if statements can be very helpful for sparse applications, since LoopPartition cannot eliminate if statements whose conditions are dynamic (unknown at compile time). If there is no problem, I can make a PR. A sketch of the kind of kernel I have in mind follows below.
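To make the sparse case concrete, here is a hand-written, hypothetical CSR-style kernel (all names made up, not taken from any TVM workload). The branch condition depends on runtime data, so LoopPartition cannot resolve it at compile time, but it is invariant over the inner loop and could be hoisted by this pass:

// Hypothetical kernel: row_is_dense depends on row_ptr (runtime data), so it is
// dynamic over i, but it does not depend on j, so the if can be hoisted out of
// the inner loop.
void spmv_like(const int* row_ptr, const float* dense_val, const float* x,
               float* y, int rows, int cols) {
  for (int i = 0; i < rows; ++i) {
    bool row_is_dense = (row_ptr[i + 1] - row_ptr[i]) == cols;
    for (int j = 0; j < cols; ++j) {
      if (row_is_dense) {
        y[i] += dense_val[i * cols + j] * x[j];
      }
    }
  }
}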
