Change the order of generated blocks for block isolation. (apache#35)
* upd

* upd

* upd
yzh119 authored and MasterJH5574 committed Dec 22, 2021
1 parent 437d88c commit 7572d3a
Showing 2 changed files with 77 additions and 82 deletions.
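The change is easiest to see in the csrmm test kernel updated below: the reduction axis J moves from the middle of the iterator list to the end, so every spatial axis is emitted before the reduction it encloses and the pass can isolate the reduction into its own inner block. A before/after excerpt (taken from the test diff in this commit; the buffer and axis declarations are omitted, so this is not a standalone script):

# Before: the reduction axis J sits between the spatial axes I and K ("SRS").
with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]:
    with T.init():
        C[vi, vk] = 0.0
    C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]

# After: spatial axes first, reduction axis last ("SSR"), so the pass can put
# vi and vk in an outer spatial block and vj in its own inner reduction block.
with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]:
    with T.init():
        C[vi, vk] = 0.0
    C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]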
149 changes: 74 additions & 75 deletions src/tir/transforms/lower_sparse_tir.cc
@@ -27,7 +27,6 @@
#include <tvm/tir/transform.h>

#include <set>
#include <stack>
#include <utility>

#include "../../support/utils.h"
@@ -521,84 +520,84 @@ class IndexTransformer : public StmtExprMutator {
var_map[sp_iter_var->var.get()] = loop_var;
}

// Step 4. Collet block iters and iter bindings.
std::set<const AxisNode*> in_stack;
// Step 4. Collect block iters and iter bindings.
/* Whether the axis appears in the stack. */
std::unordered_set<const AxisNode*> in_stack;
/* A stack that stores block itervars in each block. */
std::stack<Array<IterVar>> block_iters_st;
std::vector<Array<IterVar>> block_iters_st;
/* A stack that stores itervar bindings in each block. */
std::stack<Array<PrimExpr>> iter_bindings_st;
std::vector<Array<PrimExpr>> iter_bindings_st;
/* A stack that stores generated loop vars in each block. */
std::stack<Array<Var>> loop_vars_st;
std::vector<Array<Var>> loop_vars_st;
/* A stack that stores whether to place init block in each block. */
std::stack<bool> place_init_st;
std::vector<bool> place_init_st;
/* An indicator that records whether init block has been set. */
bool init_set = false;
do {
/* Block itervars of current block. */
Array<IterVar> block_iters;
/* Itervar bindings of current block. */
Array<PrimExpr> iter_bindings;
/* Axis names of current block. */
Array<Axis> blk_axes;
/* Generated loop vars of current block. */
Array<Var> loop_vars;
/* An indicator that records whether there is reduction axis in current block. */
bool has_reduction_var = false;
for (int i = 0; i < n_iter; ++i) {
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
Axis axis = sp_it_var->axis;

/* Add itervar to current block when
* - it's not used yet (not in stack) and
* - it's parent axis was used in outer blocks or
* - it's an iterator to a fixed axis.
*/
auto parent = axis->GetParentAxis();
bool emit_iter_var = true;
if (in_stack.find(axis.get()) !=
in_stack.end()) { // the iter var has already been emitted.
emit_iter_var = false;
/* Block itervars of current block. */
Array<IterVar> block_iters;
/* Itervar bindings of current block. */
Array<PrimExpr> iter_bindings;
/* Generated loop vars of current block. */
Array<Var> loop_vars;
/* Whether the axis appears in the current block. */
std::unordered_set<const AxisNode*> in_block;
/* An indicator that records whether there is reduction axis in current block. */
bool has_reduction_var = false;

auto UpdateStack = [&]() {
block_iters_st.emplace_back(std::move(block_iters));
iter_bindings_st.emplace_back(std::move(iter_bindings));
loop_vars_st.emplace_back(std::move(loop_vars));
if (init_set) {
place_init_st.emplace_back(false);
} else {
place_init_st.emplace_back(has_reduction_var);
init_set |= has_reduction_var;
}
};

for (int i = 0; i < n_iter; ++i) {
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
Axis axis = sp_it_var->axis;
auto parent = axis->GetParentAxis();
bool create_new_blk = false;
bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
if (!is_fixed_axis && parent.defined()) {
const AxisNode* parent_node = parent.value().get();
if (in_block.find(parent_node) != in_block.end()) {
/* parent node is in the current block, need to create new block. */
create_new_blk = true;
} else if (in_stack.find(parent_node) != in_stack.end()) {
/* parent node is in the previous blocks in the stack, no need to create new block. */
create_new_blk = false;
} else {
if (parent.defined()) { // has parent
if (in_stack.find(parent.value().get()) == in_stack.end()) { // parent not emitted yet
if (axis->kind() == AxisKind::kDenseVariable ||
axis->kind() == AxisKind::kSparseVariable) { // is not fixed axis.
emit_iter_var = false;
}
}
}
CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block.";
}
// LOG(INFO) << axis->name << " " << (parent.defined() ? parent.value()->name : "no-parent")
// << " " << emit_iter_var;
if (emit_iter_var) {
loop_vars.push_back(all_loop_vars[i]);
blk_axes.push_back(axis);
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
iter_bindings.push_back(all_loop_vars[i]);
has_reduction_var |= sp_it_var->is_reduction;
}
if (create_new_blk) {
/* update in stack set. */
for (const AxisNode* node : in_block) {
in_stack.insert(node);
}
/* Update stack. */
UpdateStack();
/* Reset block states. */
loop_vars = {};
block_iters = {};
iter_bindings = {};
has_reduction_var = false;
in_block.clear();
}

/* Tag axes in current block as "in-stack". */
for (const Axis&& axis : blk_axes) {
in_stack.insert(axis.get());
}
loop_vars.push_back(all_loop_vars[i]);
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
iter_bindings.push_back(all_loop_vars[i]);
has_reduction_var |= sp_it_var->is_reduction;
in_block.insert(axis.get());
}

/* Update stack. */
if (!block_iters.empty()) {
block_iters_st.push(std::move(block_iters));
iter_bindings_st.push(std::move(iter_bindings));
loop_vars_st.push(std::move(loop_vars));
if (init_set) {
place_init_st.push(false);
} else {
place_init_st.push(has_reduction_var);
init_set |= has_reduction_var;
}
} else {
break;
}
} while (true);
// Update the last block.
UpdateStack();

// Step 5. Generate the read-region and write-region of the block.
Array<BufferRegion> reads{};
@@ -608,14 +607,14 @@ class IndexTransformer : public StmtExprMutator {
// Step 6. Generate nested blocks and loops from innermost to outermost.
int blk_counter = 0;
while (!block_iters_st.empty()) {
Array<IterVar> block_iters = std::move(block_iters_st.top());
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.top());
Array<Var> loop_vars = std::move(loop_vars_st.top());
bool place_init = place_init_st.top();
block_iters_st.pop();
iter_bindings_st.pop();
loop_vars_st.pop();
place_init_st.pop();
Array<IterVar> block_iters = std::move(block_iters_st.back());
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.back());
Array<Var> loop_vars = std::move(loop_vars_st.back());
bool place_init = place_init_st.back();
block_iters_st.pop_back();
iter_bindings_st.pop_back();
loop_vars_st.pop_back();
place_init_st.pop_back();

Map<String, ObjectRef> mapping;
mapping.Set("sparse", Bool(true));
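To summarize the rewritten Step 4 above in one place (its added and removed lines are interleaved in this rendering), here is a hedged, standalone Python sketch of the grouping rule, using hypothetical Axis/SpIterVar stand-ins rather than the TVM objects: iterators are appended to the current block until a variable axis whose parent already sits in that block is reached, at which point the block is pushed and a new, inner block is started; a variable axis whose parent has not been emitted at all is rejected. The init-block placement and loop-var bookkeeping are left out of this sketch.

from dataclasses import dataclass
from typing import List, Optional, Set

@dataclass(frozen=True)
class Axis:
    name: str
    fixed: bool                    # dense-fixed / sparse-fixed axes never force a split
    parent: Optional["Axis"] = None

@dataclass(frozen=True)
class SpIterVar:
    axis: Axis
    is_reduction: bool = False

def group_into_blocks(sp_iters: List[SpIterVar]) -> List[List[SpIterVar]]:
    """Split sparse iterators into nested blocks, outermost block first."""
    blocks: List[List[SpIterVar]] = []
    in_stack: Set[Axis] = set()    # axes already emitted in outer blocks
    in_block: Set[Axis] = set()    # axes emitted in the current block
    current: List[SpIterVar] = []
    for it in sp_iters:
        axis, parent = it.axis, it.axis.parent
        if not axis.fixed and parent is not None:
            if parent in in_block:
                # The parent iterator lives in the current block: close it and
                # open a new, nested block for this variable axis.
                blocks.append(current)
                in_stack |= in_block
                current, in_block = [], set()
            elif parent not in in_stack:
                raise ValueError(f"The parent axis of {axis.name} should appear "
                                 f"before {axis.name} when defining a sparse block.")
        current.append(it)
        in_block.add(axis)
    blocks.append(current)         # flush the last (innermost) block
    return blocks

# csrmm with the new SSR order: I and K are fixed, J's parent is I,
# so [I, K, J] is split into an outer [I, K] block and an inner [J] block.
I = Axis("I", fixed=True)
K = Axis("K", fixed=True)
J = Axis("J", fixed=False, parent=I)
blocks = group_into_blocks([SpIterVar(I), SpIterVar(K), SpIterVar(J, is_reduction=True)])
print([[it.axis.name for it in blk] for blk in blocks])  # [['I', 'K'], ['J']]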
10 changes: 3 additions & 7 deletions tests/python/sparsetir/test_tir_sparse_lower.py
@@ -40,7 +40,7 @@ def csrmm(
A = T.match_sparse_buffer(a, (I, J), "float32")
B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
C = T.match_sparse_buffer(c, (I, K), "float32")
with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]:
with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]:
with T.init():
C[vi, vk] = 0.0
C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]
@@ -180,12 +180,12 @@ def bsrmm(
B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32")
C = T.match_sparse_buffer(c, (I, BI, F), "float32")

with T.iter([I, J, BI, BJ, F], "SRSRS", "bsrmm") as [
with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [
vi,
vj,
vbi,
vbj,
vf,
vj,
]:
with T.init():
C[vi, vbi, vf] = 0.0
@@ -314,7 +314,6 @@ def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices
def test_csrmm():
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod["main"].script())
tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)

A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr")
@@ -338,14 +337,12 @@ def test_csrmm():
def test_csrmm_dense_iter():
mod = tvm.IRModule.from_expr(csrmm_dense_iter)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod["main"].script())
# tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)


def test_segment_reduce():
mod = tvm.IRModule.from_expr(segment_reduce)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod["main"].script())


def test_csr_reduce():
@@ -412,7 +409,6 @@ def test_bsrmm():
def test_ellpack_mm():
mod = tvm.IRModule.from_expr(ellpack_mm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod["main"].script())
tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True)

nnz_cols = 4