From 5d049abd703aec355c52a1bdc08527c6b929acd1 Mon Sep 17 00:00:00 2001
From: Xuanda Yang
Date: Fri, 1 Jan 2021 06:17:14 +0800
Subject: [PATCH] [type] Support offset load in bit vectorized loop (#2127)

---
Notes (not part of the commit message):

* BitLoopVectorize gains a `loop_stmt` member. It is set to the StructForStmt
  right before its body is visited and reset to nullptr afterwards.
* In the hunk that marks `ptr` as bit-vectorized, a pointer with exactly two
  indices whose second index is reported by
  value_diff_loop_index(ptr->indices[1], loop_stmt, 1) as a certain, linear
  offset of +1 or -1 gets extra handling. `stmt` is replaced by the bit_or of:
    - a load through the un-offset index, shifted by one bit, and
    - a load through that index +/- bit_vectorize, shifted by
      (bit_vectorize - 1) bits in the opposite direction.
  All other offsets are left as they were (see the TODO in the hunk).
* The un-offset index is read from loop_stmt->body->statements[1]. This
  presumably relies on the second statement of the loop body being the loop
  index for j -- confirm.
* test_offset_load runs the offsets (0, 1), (1, 0), (0, -1) and (-1, 0).
  verify() only compares y against x; z is written but never checked.
* NOTE(review): the explicit template arguments on std::make_unique
  (GlobalPtrStmt, GlobalLoadStmt, ConstStmt, BinaryOpStmt) and the leading
  indentation were missing from this copy of the patch. They have been
  restored from how each result is used and from the wrapping points --
  confirm against the upstream commit.

 taichi/transforms/bit_loop_vectorize.cpp     | 76 ++++++++++++++++++++
 tests/python/test_bit_array_vectorization.py | 52 ++++++++++++++
 2 files changed, 128 insertions(+)

diff --git a/taichi/transforms/bit_loop_vectorize.cpp b/taichi/transforms/bit_loop_vectorize.cpp
index 20bd9879a209c..02336580cfc21 100644
--- a/taichi/transforms/bit_loop_vectorize.cpp
+++ b/taichi/transforms/bit_loop_vectorize.cpp
@@ -6,6 +6,7 @@
 #include "taichi/ir/statements.h"
 #include "taichi/ir/transforms.h"
 #include "taichi/ir/visitors.h"
+#include "taichi/ir/analysis.h"
 
 TLANG_NAMESPACE_BEGIN
 
@@ -13,6 +14,7 @@ class BitLoopVectorize : public IRVisitor {
  public:
   int bit_vectorize;
   bool in_struct_for_loop;
+  StructForStmt *loop_stmt;
   PrimitiveType *bit_array_physical_type;
 
   BitLoopVectorize() {
@@ -20,6 +22,7 @@ class BitLoopVectorize : public IRVisitor {
     invoke_default_visitor = true;
     bit_vectorize = 1;
     in_struct_for_loop = false;
+    loop_stmt = nullptr;
     bit_array_physical_type = nullptr;
   }
 
@@ -45,6 +48,77 @@ class BitLoopVectorize : public IRVisitor {
         DataType new_ret_type(ptr_physical_type);
         ptr->ret_type = new_ret_type;
         ptr->is_bit_vectorized = true;
+        // check if j has offset
+        if (ptr->indices.size() == 2) {
+          auto diff = irpass::analysis::value_diff_loop_index(ptr->indices[1],
+                                                              loop_stmt, 1);
+          // TODO: temporarily we only support [j - 1] and [j + 1]
+          // the general case should be easy to implement
+          if (diff.linear_related() && diff.certain() &&
+              (diff.low == 1 || diff.low == -1)) {
+            // construct ptr to x[i, j]
+            auto indices = ptr->indices;
+            indices[1] = loop_stmt->body->statements[1].get();
+            auto base_ptr =
+                std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
+            base_ptr->ret_type = new_ret_type;
+            base_ptr->is_bit_vectorized = true;
+            // load x[i, j](base)
+            DataType load_data_type(bit_array_physical_type);
+            auto load_base = std::make_unique<GlobalLoadStmt>(base_ptr.get());
+            load_base->ret_type = load_data_type;
+            // load x[i, j + 1](offsetted)
+            // since we are doing vectorization, the actual data should be x[i,
+            // j + vectorization_width]
+            auto offset_constant =
+                std::make_unique<ConstStmt>(TypedConstant(bit_vectorize));
+            auto offset_index_opcode =
+                diff.low == -1 ? BinaryOpType::sub : BinaryOpType::add;
+            auto offset_index = std::make_unique<BinaryOpStmt>(
+                offset_index_opcode, indices[1], offset_constant.get());
+            indices[1] = offset_index.get();
+            auto offset_ptr =
+                std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
+            offset_ptr->ret_type = new_ret_type;
+            offset_ptr->is_bit_vectorized = true;
+            auto load_offsetted =
+                std::make_unique<GlobalLoadStmt>(offset_ptr.get());
+            load_offsetted->ret_type = load_data_type;
+            // create bit shift and bit and operations
+            auto base_shift_offset =
+                std::make_unique<ConstStmt>(TypedConstant(load_data_type, 1));
+            auto base_shift_opcode =
+                diff.low == -1 ? BinaryOpType::bit_shl : BinaryOpType::bit_sar;
+            auto base_shift_op = std::make_unique<BinaryOpStmt>(
+                base_shift_opcode, load_base.get(), base_shift_offset.get());
+
+            auto offsetted_shift_offset = std::make_unique<ConstStmt>(
+                TypedConstant(load_data_type, bit_vectorize - 1));
+            auto offsetted_shift_opcode =
+                diff.low == -1 ? BinaryOpType::bit_sar : BinaryOpType::bit_shl;
+            auto offsetted_shift_op = std::make_unique<BinaryOpStmt>(
+                offsetted_shift_opcode, load_offsetted.get(),
+                offsetted_shift_offset.get());
+
+            auto or_op = std::make_unique<BinaryOpStmt>(
+                BinaryOpType::bit_or, base_shift_op.get(),
+                offsetted_shift_op.get());
+            // modify IR
+            auto offsetted_shift_op_p = offsetted_shift_op.get();
+            stmt->insert_before_me(std::move(base_ptr));
+            stmt->insert_before_me(std::move(load_base));
+            stmt->insert_before_me(std::move(offset_constant));
+            stmt->insert_before_me(std::move(offset_index));
+            stmt->insert_before_me(std::move(offset_ptr));
+            stmt->insert_before_me(std::move(load_offsetted));
+            stmt->insert_before_me(std::move(base_shift_offset));
+            stmt->insert_before_me(std::move(base_shift_op));
+            stmt->insert_before_me(std::move(offsetted_shift_offset));
+            stmt->insert_before_me(std::move(offsetted_shift_op));
+            stmt->replace_with(or_op.get());
+            offsetted_shift_op_p->insert_after_me(std::move(or_op));
+          }
+        }
       }
     }
   }
@@ -72,10 +146,12 @@ class BitLoopVectorize : public IRVisitor {
     int old_bit_vectorize = bit_vectorize;
     bit_vectorize = stmt->bit_vectorize;
     in_struct_for_loop = true;
+    loop_stmt = stmt;
     bit_array_physical_type = stmt->snode->physical_type;
     stmt->body->accept(this);
     bit_vectorize = old_bit_vectorize;
     in_struct_for_loop = false;
+    loop_stmt = nullptr;
     bit_array_physical_type = nullptr;
   }
 
diff --git a/tests/python/test_bit_array_vectorization.py b/tests/python/test_bit_array_vectorization.py
index 350aa42cf9362..2e1b7dc0cbb9f 100644
--- a/tests/python/test_bit_array_vectorization.py
+++ b/tests/python/test_bit_array_vectorization.py
@@ -41,3 +41,55 @@ def verify():
     init()
     assign_vectorized()
     verify()
+
+
+@ti.test(require=ti.extension.quant)
+def test_offset_load():
+    ci1 = ti.type_factory.custom_int(1, False)
+
+    x = ti.field(dtype=ci1)
+    y = ti.field(dtype=ci1)
+    z = ti.field(dtype=ci1)
+
+    N = 4096
+    n_blocks = 4
+    bits = 32
+    boundary_offset = 1024
+    assert boundary_offset >= N // n_blocks
+
+    block = ti.root.pointer(ti.ij, (n_blocks, n_blocks))
+    block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
+        ti.j, bits, num_bits=bits).place(x)
+    block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
+        ti.j, bits, num_bits=bits).place(y)
+    block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
+        ti.j, bits, num_bits=bits).place(z)
+
+    @ti.kernel
+    def init():
+        for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
+                               (boundary_offset, N - boundary_offset)):
+            x[i, j] = ti.random(dtype=ti.i32) % 2
+
+    @ti.kernel
+    def assign_vectorized(dx: ti.template(), dy: ti.template()):
+        ti.bit_vectorize(32)
+        for i, j in x:
+            y[i, j] = x[i + dx, j + dy]
+            z[i, j] = x[i + dx, j + dy]
+
+    @ti.kernel
+    def verify(dx: ti.template(), dy: ti.template()):
+        for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
+                               (boundary_offset, N - boundary_offset)):
+            assert y[i, j] == x[i + dx, j + dy]
+
+    init()
+    assign_vectorized(0, 1)
+    verify(0, 1)
+    assign_vectorized(1, 0)
+    verify(1, 0)
+    assign_vectorized(0, -1)
+    verify(0, -1)
+    assign_vectorized(-1, 0)
+    verify(-1, 0)