Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] Support offset load in bit vectorized loop #2127

Merged
merged 17 commits into from
Dec 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions taichi/transforms/bit_loop_vectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/analysis.h"

TLANG_NAMESPACE_BEGIN

class BitLoopVectorize : public IRVisitor {
public:
int bit_vectorize;
bool in_struct_for_loop;
StructForStmt *loop_stmt;
PrimitiveType *bit_array_physical_type;

BitLoopVectorize() {
allow_undefined_visitor = true;
invoke_default_visitor = true;
bit_vectorize = 1;
in_struct_for_loop = false;
loop_stmt = nullptr;
bit_array_physical_type = nullptr;
}

Expand All @@ -45,6 +48,77 @@ class BitLoopVectorize : public IRVisitor {
DataType new_ret_type(ptr_physical_type);
ptr->ret_type = new_ret_type;
ptr->is_bit_vectorized = true;
// check if j has offset
if (ptr->indices.size() == 2) {
auto diff = irpass::analysis::value_diff_loop_index(ptr->indices[1],
loop_stmt, 1);
// TODO: temporarily we only support [j - 1] and [j + 1]
// the general case should be easy to implement
if (diff.linear_related() && diff.certain() &&
(diff.low == 1 || diff.low == -1)) {
// construct ptr to x[i, j]
auto indices = ptr->indices;
indices[1] = loop_stmt->body->statements[1].get();
auto base_ptr =
std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
base_ptr->ret_type = new_ret_type;
base_ptr->is_bit_vectorized = true;
// load x[i, j](base)
DataType load_data_type(bit_array_physical_type);
auto load_base = std::make_unique<GlobalLoadStmt>(base_ptr.get());
load_base->ret_type = load_data_type;
// load x[i, j + 1](offsetted)
// since we are doing vectorization, the actual data should be x[i,
// j + vectorization_width]
auto offset_constant =
std::make_unique<ConstStmt>(TypedConstant(bit_vectorize));
auto offset_index_opcode =
diff.low == -1 ? BinaryOpType::sub : BinaryOpType::add;
auto offset_index = std::make_unique<BinaryOpStmt>(
offset_index_opcode, indices[1], offset_constant.get());
indices[1] = offset_index.get();
auto offset_ptr =
std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
offset_ptr->ret_type = new_ret_type;
offset_ptr->is_bit_vectorized = true;
auto load_offsetted =
std::make_unique<GlobalLoadStmt>(offset_ptr.get());
load_offsetted->ret_type = load_data_type;
// create bit shift and bit and operations
auto base_shift_offset =
std::make_unique<ConstStmt>(TypedConstant(load_data_type, 1));
auto base_shift_opcode =
diff.low == -1 ? BinaryOpType::bit_shl : BinaryOpType::bit_sar;
auto base_shift_op = std::make_unique<BinaryOpStmt>(
base_shift_opcode, load_base.get(), base_shift_offset.get());

auto offsetted_shift_offset = std::make_unique<ConstStmt>(
TypedConstant(load_data_type, bit_vectorize - 1));
auto offsetted_shift_opcode =
diff.low == -1 ? BinaryOpType::bit_sar : BinaryOpType::bit_shl;
auto offsetted_shift_op = std::make_unique<BinaryOpStmt>(
offsetted_shift_opcode, load_offsetted.get(),
offsetted_shift_offset.get());

auto or_op = std::make_unique<BinaryOpStmt>(
BinaryOpType::bit_or, base_shift_op.get(),
offsetted_shift_op.get());
// modify IR
auto offsetted_shift_op_p = offsetted_shift_op.get();
stmt->insert_before_me(std::move(base_ptr));
stmt->insert_before_me(std::move(load_base));
stmt->insert_before_me(std::move(offset_constant));
stmt->insert_before_me(std::move(offset_index));
stmt->insert_before_me(std::move(offset_ptr));
stmt->insert_before_me(std::move(load_offsetted));
stmt->insert_before_me(std::move(base_shift_offset));
stmt->insert_before_me(std::move(base_shift_op));
stmt->insert_before_me(std::move(offsetted_shift_offset));
stmt->insert_before_me(std::move(offsetted_shift_op));
stmt->replace_with(or_op.get());
offsetted_shift_op_p->insert_after_me(std::move(or_op));
}
}
}
}
}
Expand Down Expand Up @@ -72,10 +146,12 @@ class BitLoopVectorize : public IRVisitor {
int old_bit_vectorize = bit_vectorize;
bit_vectorize = stmt->bit_vectorize;
in_struct_for_loop = true;
loop_stmt = stmt;
bit_array_physical_type = stmt->snode->physical_type;
stmt->body->accept(this);
bit_vectorize = old_bit_vectorize;
in_struct_for_loop = false;
loop_stmt = nullptr;
bit_array_physical_type = nullptr;
}

Expand Down
52 changes: 52 additions & 0 deletions tests/python/test_bit_array_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,55 @@ def verify():
init()
assign_vectorized()
verify()


@ti.test(require=ti.extension.quant)
def test_offset_load():
ci1 = ti.type_factory.custom_int(1, False)

x = ti.field(dtype=ci1)
y = ti.field(dtype=ci1)
z = ti.field(dtype=ci1)

N = 4096
n_blocks = 4
bits = 32
boundary_offset = 1024
TH3CHARLie marked this conversation as resolved.
Show resolved Hide resolved
assert boundary_offset >= N // n_blocks

block = ti.root.pointer(ti.ij, (n_blocks, n_blocks))
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
ti.j, bits, num_bits=bits).place(x)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
ti.j, bits, num_bits=bits).place(y)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks)))._bit_array(
ti.j, bits, num_bits=bits).place(z)

@ti.kernel
def init():
for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
(boundary_offset, N - boundary_offset)):
x[i, j] = ti.random(dtype=ti.i32) % 2

@ti.kernel
def assign_vectorized(dx: ti.template(), dy: ti.template()):
ti.bit_vectorize(32)
for i, j in x:
y[i, j] = x[i + dx, j + dy]
z[i, j] = x[i + dx, j + dy]

@ti.kernel
def verify(dx: ti.template(), dy: ti.template()):
for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
(boundary_offset, N - boundary_offset)):
assert y[i, j] == x[i + dx, j + dy]

init()
assign_vectorized(0, 1)
verify(0, 1)
assign_vectorized(1, 0)
verify(1, 0)
assign_vectorized(0, -1)
verify(0, -1)
assign_vectorized(-1, 0)
verify(-1, 0)