Skip to content

Commit

Permalink
[type] Support offset load in bit vectorized loop (#2127)
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie authored Dec 31, 2020
1 parent 445bf32 commit 5d049ab
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
76 changes: 76 additions & 0 deletions taichi/transforms/bit_loop_vectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/analysis.h"

TLANG_NAMESPACE_BEGIN

class BitLoopVectorize : public IRVisitor {
public:
int bit_vectorize;
bool in_struct_for_loop;
StructForStmt *loop_stmt;
PrimitiveType *bit_array_physical_type;

BitLoopVectorize() {
allow_undefined_visitor = true;
invoke_default_visitor = true;
bit_vectorize = 1;
in_struct_for_loop = false;
loop_stmt = nullptr;
bit_array_physical_type = nullptr;
}

Expand All @@ -45,6 +48,77 @@ class BitLoopVectorize : public IRVisitor {
DataType new_ret_type(ptr_physical_type);
ptr->ret_type = new_ret_type;
ptr->is_bit_vectorized = true;
// check if j has offset
if (ptr->indices.size() == 2) {
auto diff = irpass::analysis::value_diff_loop_index(ptr->indices[1],
loop_stmt, 1);
// TODO: temporarily we only support [j - 1] and [j + 1]
// the general case should be easy to implement
if (diff.linear_related() && diff.certain() &&
(diff.low == 1 || diff.low == -1)) {
// construct ptr to x[i, j]
auto indices = ptr->indices;
indices[1] = loop_stmt->body->statements[1].get();
auto base_ptr =
std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
base_ptr->ret_type = new_ret_type;
base_ptr->is_bit_vectorized = true;
// load x[i, j](base)
DataType load_data_type(bit_array_physical_type);
auto load_base = std::make_unique<GlobalLoadStmt>(base_ptr.get());
load_base->ret_type = load_data_type;
// load x[i, j + 1](offsetted)
// since we are doing vectorization, the actual data should be x[i,
// j + vectorization_width]
auto offset_constant =
std::make_unique<ConstStmt>(TypedConstant(bit_vectorize));
auto offset_index_opcode =
diff.low == -1 ? BinaryOpType::sub : BinaryOpType::add;
auto offset_index = std::make_unique<BinaryOpStmt>(
offset_index_opcode, indices[1], offset_constant.get());
indices[1] = offset_index.get();
auto offset_ptr =
std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
offset_ptr->ret_type = new_ret_type;
offset_ptr->is_bit_vectorized = true;
auto load_offsetted =
std::make_unique<GlobalLoadStmt>(offset_ptr.get());
load_offsetted->ret_type = load_data_type;
// create bit shift and bit and operations
auto base_shift_offset =
std::make_unique<ConstStmt>(TypedConstant(load_data_type, 1));
auto base_shift_opcode =
diff.low == -1 ? BinaryOpType::bit_shl : BinaryOpType::bit_sar;
auto base_shift_op = std::make_unique<BinaryOpStmt>(
base_shift_opcode, load_base.get(), base_shift_offset.get());

auto offsetted_shift_offset = std::make_unique<ConstStmt>(
TypedConstant(load_data_type, bit_vectorize - 1));
auto offsetted_shift_opcode =
diff.low == -1 ? BinaryOpType::bit_sar : BinaryOpType::bit_shl;
auto offsetted_shift_op = std::make_unique<BinaryOpStmt>(
offsetted_shift_opcode, load_offsetted.get(),
offsetted_shift_offset.get());

auto or_op = std::make_unique<BinaryOpStmt>(
BinaryOpType::bit_or, base_shift_op.get(),
offsetted_shift_op.get());
// modify IR
auto offsetted_shift_op_p = offsetted_shift_op.get();
stmt->insert_before_me(std::move(base_ptr));
stmt->insert_before_me(std::move(load_base));
stmt->insert_before_me(std::move(offset_constant));
stmt->insert_before_me(std::move(offset_index));
stmt->insert_before_me(std::move(offset_ptr));
stmt->insert_before_me(std::move(load_offsetted));
stmt->insert_before_me(std::move(base_shift_offset));
stmt->insert_before_me(std::move(base_shift_op));
stmt->insert_before_me(std::move(offsetted_shift_offset));
stmt->insert_before_me(std::move(offsetted_shift_op));
stmt->replace_with(or_op.get());
offsetted_shift_op_p->insert_after_me(std::move(or_op));
}
}
}
}
}
Expand Down Expand Up @@ -72,10 +146,12 @@ class BitLoopVectorize : public IRVisitor {
int old_bit_vectorize = bit_vectorize;
bit_vectorize = stmt->bit_vectorize;
in_struct_for_loop = true;
loop_stmt = stmt;
bit_array_physical_type = stmt->snode->physical_type;
stmt->body->accept(this);
bit_vectorize = old_bit_vectorize;
in_struct_for_loop = false;
loop_stmt = nullptr;
bit_array_physical_type = nullptr;
}

Expand Down
52 changes: 52 additions & 0 deletions tests/python/test_bit_array_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,55 @@ def verify():
init()
assign_vectorized()
verify()


@ti.test(require=ti.extension.quant)
def test_offset_load():
    """Check offset loads inside a bit-vectorized struct-for loop.

    Three 1-bit fields ``x``, ``y`` and ``z`` are placed in identically
    shaped bit arrays.  ``x`` is filled with random bits in an interior
    region, then a bit-vectorized kernel copies ``x[i + dx, j + dy]`` into
    both ``y[i, j]`` and ``z[i, j]``.  A plain (non-vectorized) kernel then
    checks both destinations against ``x`` for each of the four unit
    offsets.
    """
    ci1 = ti.type_factory.custom_int(1, False)

    x = ti.field(dtype=ci1)
    y = ti.field(dtype=ci1)
    z = ti.field(dtype=ci1)

    N = 4096
    n_blocks = 4
    bits = 32
    boundary_offset = 1024
    # The checked region keeps at least one full block of margin on each side.
    assert boundary_offset >= N // n_blocks

    block = ti.root.pointer(ti.ij, (n_blocks, n_blocks))
    # All three fields share the same layout; each gets its own dense node.
    for field in (x, y, z):
        block.dense(
            ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
                ti.j, bits, num_bits=bits).place(field)

    @ti.kernel
    def init():
        for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
                               (boundary_offset, N - boundary_offset)):
            x[i, j] = ti.random(dtype=ti.i32) % 2

    @ti.kernel
    def assign_vectorized(dx: ti.template(), dy: ti.template()):
        # Literal width kept as in the original; it equals `bits` above.
        ti.bit_vectorize(32)
        for i, j in x:
            y[i, j] = x[i + dx, j + dy]
            z[i, j] = x[i + dx, j + dy]

    @ti.kernel
    def verify(dx: ti.template(), dy: ti.template()):
        for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
                               (boundary_offset, N - boundary_offset)):
            assert y[i, j] == x[i + dx, j + dy]
            # z is written by the same kernel, so check it too; previously
            # the second load/store pair was never verified.
            assert z[i, j] == x[i + dx, j + dy]

    init()
    # Same order as before: +1 / -1 along j, interleaved with +1 / -1 along i.
    for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        assign_vectorized(dx, dy)
        verify(dx, dy)

0 comments on commit 5d049ab

Please sign in to comment.