diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5ac71fdce47b..03078b8be41f 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -377,6 +377,13 @@ Stmt LowerStorageAccessInfo(Stmt stmt); */ Stmt DecorateDeviceScope(Stmt stmt); +/*! + * \brief Loop invariant code motion which locates and hoists if statements. + * \param stmt The stmt to do if statement hoisting. + * \return Transformed stmt. + */ +Stmt HoistIfThenElse(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 25cd5838385f..d2352496c2b4 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -160,5 +160,6 @@ REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); +REGISTER_PASS(HoistIfThenElse); } // namespace ir } // namespace tvm diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc new file mode 100644 index 000000000000..bbdb609e9a08 --- /dev/null +++ b/src/pass/hoist_if_then_else.cc @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file hoist_if_then_else.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../arithmetic/int_set.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +using HoistMap = std::unordered_map>; +using VarMap = std::unordered_map>; + +/* + * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. + * For example, given the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. + * Then we hoist IfThenElse stmt by one For stmt each step: + * + * Step 1: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Step 2: + * for (i = 0; i < 3; i++) + * if (likely(i*2 < 4)) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * In this pass, we only continue detecting possible hoisting chance when visiting For, + * IfThenElse or AttrStmt Node. For example, for the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Only the For with k variable will be considered and the resulting stmt would be: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following + * block won't be optimized: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * if (likely(j > 2)) + * A[i+j+k] = B[i+j+k] + * + */ +class IfThenElseHoist { + public: + Stmt VisitAndMutate(const Stmt& stmt) { + SelectCandidates(stmt); + LocateTopFor(); + return PostOrderMutate(stmt); + } + + private: + void SelectCandidates(const Stmt& stmt); + void LocateTopFor(); + Stmt PostOrderMutate(const Stmt& stmt); + size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); + Stmt HoistIf(const Stmt& if_stmt); + + // Map of all For nodes to all child IfThenElse nodes. + HoistMap for2if_map_; + // Map of all IfThenElse nodes to all For nodes which are loop invariant. + HoistMap if2for_map_; + // Map of highest loop invariant For to child IfThenElse. + HoistMap top_for_var_map_; + // Map of original For to list of update For nodes. + HoistMap for_tracking_map_; + // Map of all IfThenElse nodes to condition variable nodes. + VarMap cond_var_map_; + // List of For nodes added in post order DFS visiting. + std::vector ordered_for_list_; +}; + +// Check whether a given IfThenElse stmt is the first one appearing +// in a For stmt. +bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { + std::vector if_node_list; + const For* for_node = for_stmt.as(); + CHECK(for_node); + CHECK(if_stmt.as()); + + PostOrderVisit(for_node->body, [&](const NodeRef& node) { + if (node.as()) { + if_node_list.push_back(node.get()); + } + }); + return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back(); +} + +// Update upper level For node when current For node is modified. +// With this function we only need to visit and mutate top level For node +// in the main VisitAndMutate function. +Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { + const Node* top_for_node; + const For* parent_for_node = parent_for_stmt.as(); + CHECK(parent_for_node); + CHECK(new_if_stmt.as()); + + PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { + if (node.as()) { + top_for_node = node.get(); + } + }); + + PackedFunc replace_target_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + if (current_for.get() == top_for_node) { + *ret = new_if_stmt; + } + }); + + return IRTransform(parent_for_stmt, nullptr, replace_target_for, + {Expr("For")}); +} + +// Remove IfThenElse node from a For node. +// A pair of For nodes will be generated. +std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { + Stmt then_for; + Stmt else_for; + CHECK(if_stmt.as()); + + PackedFunc replace_then_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); + + PackedFunc replace_else_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); + + then_for = IRTransform(for_stmt, nullptr, replace_then_case, + {Expr("IfThenElse")}); + if (if_stmt.as()->else_case) { + else_for = IRTransform(for_stmt, nullptr, replace_else_case, + {Expr("IfThenElse")}); + } + + return std::make_pair(then_for, else_for); +} + +// Locate all For nodes and capture child IfThenElse nodes. +void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { + PostOrderVisit(stmt, [&](const NodeRef& node){ + const For* for_node = node.as(); + if (!for_node) return; + + std::queue tracker; + tracker.push(for_node->body); + Stmt for_stmt = Downcast(node); + for2if_map_.insert({for_stmt.get(), std::vector()}); + while (!tracker.empty()) { + Stmt head = tracker.front(); + tracker.pop(); + if (head->is_type()) { + for (const auto& if_stmt : for2if_map_.at(head.get())) { + for2if_map_[for_stmt.get()].push_back(if_stmt); + } + } else if (head->is_type()) { + const AttrStmt* attr_node = head.as(); + tracker.push(attr_node->body); + } else if (head->is_type()) { + for2if_map_[for_stmt.get()].push_back(head); + const IfThenElse* if_node = head.as(); + tracker.push(if_node->then_case); + if (if_node->else_case) { + tracker.push(if_node->else_case); + } + + // Record condition variables. + if (!cond_var_map_.count(head.get())) { + std::unordered_set new_var_set; + cond_var_map_.insert({head.get(), new_var_set}); + PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + if (cond_node.as()) { + cond_var_map_[head.get()].insert(cond_node.get()); + } + }); + } + } else { + continue; + } + } + ordered_for_list_.emplace_back(Downcast(node)); + }); +} + +// For each IfThenElse node, find the highest For node which +// meets loop invariant condition. +void IfThenElseHoist::LocateTopFor() { + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; + + // Create IfThenElse -> For map. + for (const Stmt& for_stmt : ordered_for_list_) { + std::vector if_list = for2if_map_[for_stmt.get()]; + const For* for_node = for_stmt.as(); + CHECK(for_node); + top_for_var_map_.insert({for_node->loop_var.get(), if_list}); + for (const Stmt& if_stmt : if_list) { + const Node* if_node = if_stmt.get(); + if2for_map_[if_node].push_back(for_stmt); + } + } + + // Locate the highest For node which is loop invariant. + for (const auto& item : if2for_map_) { + Stmt top_for; + const Node* if_stmt = item.first; + std::vector for_list = item.second; + for (size_t i = 0; i < for_list.size(); ++i) { + const Stmt& for_stmt = for_list.at(i); + const For* for_node = for_stmt.as(); + CHECK(for_node); + std::vector new_for_list{for_stmt}; + for_tracking_map_.insert({for_stmt.get(), new_for_list}); + if (cond_var_map_[if_stmt] + .count(for_node->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), + for_list.begin() + i); + if2for_map_[if_stmt] = updated_for_list; + break; + } else { + top_for = for_stmt; + } + } + if (top_for.as()) { + if_position_map.insert({if_stmt, top_for}); + } + } + + for (const auto& item : if_position_map) { + top_for_var_set.insert(item.second.as()->loop_var.get()); + } + + std::vector removed_for_var_list; + for (const auto& item : top_for_var_map_) { + const Node* top_for_var = item.first; + std::vector if_list = item.second; + if (!top_for_var_set.count(top_for_var)) { + removed_for_var_list.push_back(top_for_var); + } else { + std::vector actual_if_list; + for (const Stmt& if_stmt : if_list) { + if (if_position_map.count(if_stmt.get())) { + actual_if_list.push_back(if_stmt); + } + } + top_for_var_map_[top_for_var] = actual_if_list; + } + } + for (const Node* top_for_var : removed_for_var_list) { + top_for_var_map_.erase(top_for_var); + } +} + +// When we try to mutate a For node, some child For nodes can have already +// been mutated. This function is to get the updated For node and further +// hoisting can be done based on this new node. +// We keep all For nodes tracing in for_tracking_map_. When we get a +// hoisted IfThenElse, we match it with tracing For nodes to pick +// the updated one. +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, + const Stmt& if_stmt) { + std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; + size_t updated_for_idx = 0; + for (size_t i = 0; i < tracked_for_list.size(); ++i) { + const Stmt& current_for = + tracked_for_list.at(tracked_for_list.size() - 1 - i); + if (is_first_if(current_for, if_stmt)) { + updated_for_idx = tracked_for_list.size() - 1 - i; + break; + } + } + return updated_for_idx; +} + +// Hoist an IfThenElse node as high as possible. +// This function iterates on all candidate For nodes. For each For node, +// it first removes IfThenElse nodes. Then it generates a new IfThenElse +// node using mutated For nodes. +Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { + Stmt new_if = if_stmt; + + for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { + const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); + size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); + const Stmt& updated_for_node = + for_tracking_map_[for_stmt.get()].at(updated_for_idx); + auto generated_for_pair = RemoveIf(updated_for_node, new_if); + const Stmt& then_for = generated_for_pair.first; + const Stmt& else_for = generated_for_pair.second;; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; + + if (else_for.get()) { + for_tracking_map_[for_stmt.get()].push_back(else_for); + } + + const IfThenElse* new_if_node = new_if.as(); + CHECK(new_if_node); + new_if = IfThenElse::make(new_if_node->condition, then_for, else_for); + if (i < if2for_map_[if_stmt.get()].size() - 1) { + const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); + const Stmt& actual_next_for = + for_tracking_map_[original_next_for.get()].at(updated_for_idx); + Stmt update_for_stmt = update_for(actual_next_for, new_if); + + for_tracking_map_[original_next_for.get()]. + at(updated_for_idx) = update_for_stmt; + } + } + return new_if; +} + +// Mutate For nodes in post order DFS manner. +Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { + PackedFunc replace_top_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + const For* for_node = current_for.as(); + if (!for_node) return; + + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : + top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } + + const IfThenElse* next_if_node; + const IfThenElse* current_if_node = + new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = + IfThenElse::make(current_if_node->condition, + current_if_node->then_case, + current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, + next_if_node->else_case); + current_if_node = new_for.as(); + } + + if (!new_for.get()) { + const IfThenElse* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElse::make(first_if_node->condition, + first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + }); + return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); +} + +Stmt HoistIfThenElse(Stmt stmt) { + return IfThenElseHoist().VisitAndMutate(stmt); +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py new file mode 100644 index 000000000000..4a28cf6b318a --- /dev/null +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + + +var_list = [] + +def verify_structure(stmt, expected_struct): + node_dict = {} + struct = {} + def _extract_vars(op): + global var_list + if isinstance(op, tvm.expr.Var): + var_list.append(op.name) + + def _visit(op): + key = op + if isinstance(op, tvm.stmt.IfThenElse): + global var_list + tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] + var_list.clear() + elif isinstance(op, tvm.stmt.For): + val = [(op.body,), ("For", op.loop_var.name)] + elif isinstance(op, tvm.stmt.AttrStmt): + val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] + else: + return + node_dict[key] = val + + tvm.ir_pass.PostOrderVisit(stmt, _visit) + for key, val in node_dict.items(): + struct[val[1]] = tuple(node_dict[child][1] if child in node_dict + else None for child in val[0]) + + assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ + % (expected_struct, struct) + var_list.clear() + +def test_basic(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.make.Evaluate(n)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_no_else(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_attr_stmt(): + ib = tvm.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), + ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), + ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), + ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} + verify_structure(new_stmt, expected_struct) + +def test_nested_for(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), + ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_if_block(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + n = tvm.var("n") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + with ib.if_scope(j <5): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1 + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 15, "k") as k: + with ib.if_scope(n >= 3): + data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), + ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), + ('IfThenElse', ('n',)): (('For', 'j'), None)} + verify_structure(new_stmt, expected_struct) + + +if __name__ == "__main__": + test_basic() + test_no_else() + test_attr_stmt() + test_nested_for() + test_if_block()