Skip to content

Commit

Permalink
Add "sparse" block attribute. (apache#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored Nov 23, 2021
1 parent 0107635 commit 40dd745
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
13 changes: 9 additions & 4 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,17 @@ Definition of a scope that is a stage pipeline:
!IsReductionBlock(self, block_sref, scope_root_sref)) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// NOTE(Zihao): check if the block has atomic attribute.
auto&& it = block->annotations.find("atomic");
auto&& it_atomic = block->annotations.find("atomic");
bool is_atomic = false;
if (it != block->annotations.end()) {
is_atomic = ((*it).second).as<IntImmNode>()->value;
if (it_atomic != block->annotations.end()) {
is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
}
if (!is_atomic) {
auto&& it_sparse = block->annotations.find("sparse");
bool is_sparse = false;
if (it_sparse != block->annotations.end()) {
is_sparse = ((*it_sparse).second).as<IntImmNode>()->value;
}
if (!is_sparse && !is_atomic) {
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
GetRef<Block>(block));
}
Expand Down
4 changes: 3 additions & 1 deletion src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,10 @@ class IndexTransformer : public StmtExprMutator {
GenerateReadWriteRegions(sp_block, &reads, &writes);

// Step 5. Create the block and block-realize
Map<String, ObjectRef> mapping;
mapping.Set("sparse", Bool(true));
Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body),
std::move(init));
std::move(init), {}, {}, std::move(mapping));
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));

// Step 6. Create outer loops and the block binding.
Expand Down
12 changes: 9 additions & 3 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha
B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32")
C = T.match_sparse_buffer(c, (I, K), m * k, "float32")
with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]:
T.block_attr({"sparse": True})
with T.init():
C[vi, vk] = 0.0
C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]
Expand All @@ -51,6 +52,7 @@ def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
C[vi * K + vk] = 0.
for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
with T.block("spmm_inner"):
T.block_attr({"sparse": True})
vj = T.axis.R(NNZ, j + A_indptr[vi])
C[vi * K + vk] = C[vi * K + vk] + \
A_data[vj] * B[A_indices[vj] * K + vk]
Expand All @@ -71,6 +73,7 @@ def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
C[(vio * BLOCK_SIZE + vii) * K + vk] = 0.
for jo in T.serial(0, A_indptr[vio + 1] - A_indptr[vio]):
with T.block("spmm_inner"):
T.block_attr({"sparse": True})
vjo = T.axis.R(NNZB, jo + A_indptr[vio])
C[(vio * BLOCK_SIZE + vii) * K + vk] = C[(vio * BLOCK_SIZE + vii) * K + vk] + A_data[(
vjo * BLOCK_SIZE + vii) * BLOCK_SIZE + vji] * B[(A_indices[vjo] * BLOCK_SIZE + vji) * K + vk]
Expand All @@ -85,6 +88,7 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
A_indices = T.match_buffer(indices, (M * NNZ_COLS,), "int32")
for i, j, k in T.grid(M, NNZ_COLS, K):
with T.block("spmm"):
T.block_attr({"sparse": True})
vi, vj, vk = T.axis.remap("SRS", [i, j, k])
with T.init():
C[vi * K + vk] = 0.
Expand All @@ -102,6 +106,7 @@ def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
C_indices = T.match_buffer(indices, (NNZ,), "int32")
for ij, k in T.grid(NNZ, K):
with T.block("sddmm"):
T.block_attr({"sparse": True})
vij, vk = T.axis.remap("SR", [ij, k])
T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
T.writes([C_data[vij]])
Expand Down Expand Up @@ -262,10 +267,10 @@ def test_sddmm():
)
blk = sch.get_block("sddmm")
ij, k = sch.get_loops(blk)
#sch.decompose_reduction(blk, ij)
# TODO(zihao): fix the behavior in the future.
# sch.decompose_reduction(blk, ij)
sch.bind(ij, "blockIdx.x")
ko, ki = sch.split(k, [None, 1])
sch.bind(ki, "threadIdx.x")
sch.bind(k, "threadIdx.x")

# convert numpy tensor to tvm ndarray
C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
Expand All @@ -276,6 +281,7 @@ def test_sddmm():

# build function
f = tvm.build(sch.mod['main'], target="cuda")
# print(f.imported_modules[0].get_source())
f(X_nd, Y_nd, C_data, C_indptr, C_indices)

# assertion
Expand Down
5 changes: 5 additions & 0 deletions tests/python/sparsetir/test_tir_sparse_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def lowered_csrmm(
for v_vi in T.serial(0, n):
for v_vj, v_vk in T.grid(J_indptr[v_vi + 1] - J_indptr[v_vi], k):
with T.block("csrmm"):
T.block_attr({"sparse": True})
vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk])
T.reads(
[
Expand Down Expand Up @@ -125,6 +126,7 @@ def lowered_csr_reduce(
for v_vi in T.serial(0, n):
for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
with T.block("csr_reduce"):
T.block_attr({"sparse": True})
vi, vj = T.axis.remap("SR", [v_vi, v_vj])
T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]])
T.writes([B_data[0:n]])
Expand Down Expand Up @@ -190,6 +192,7 @@ def lowered_bsrmm(
J_indptr[v_vi + 1] - J_indptr[v_vi], blk, blk, feat_size
):
with T.block("bsrmm"):
T.block_attr({"sparse": True})
vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf])
T.reads(
[
Expand Down Expand Up @@ -263,6 +266,7 @@ def lowered_ellpack_mm(
J_indices = T.match_buffer(indices, [nnz], dtype="int32")
for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size):
with T.block("bsrmm"):
T.block_attr({"sparse": True})
vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf])
T.reads(
[
Expand Down Expand Up @@ -359,6 +363,7 @@ def lowered_csr_element_wise(
for v_vi in T.serial(0, m):
for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
with T.block("csr_element_wise"):
T.block_attr({"sparse": True})
vi, vj = T.axis.remap("SS", [v_vi, v_vj])
T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]])
T.writes([B_data[0:nnz]])
Expand Down

0 comments on commit 40dd745

Please sign in to comment.