Skip to content

Commit

Permalink
[type] Support bit-level load and store (#1996)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanke98 authored Oct 31, 2020
1 parent 969fd56 commit db7d088
Show file tree
Hide file tree
Showing 12 changed files with 213 additions and 66 deletions.
50 changes: 0 additions & 50 deletions misc/test_bit_struct.py

This file was deleted.

100 changes: 89 additions & 11 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,16 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
}
} else if (!is_real(from) && !is_real(to)) {
if (data_type_size(from) < data_type_size(to)) {
// TODO: implement casting into custom integer type
TI_ASSERT(!to->is<CustomIntType>());
auto from_size = 0;
if (from->is<CustomIntType>()) {
// TODO: replace 32 with a customizable type
from_size = 32;
} else {
from_size = data_type_size(from);
}
if (from_size < data_type_size(to)) {
llvm_val[stmt] = builder->CreateSExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else {
Expand Down Expand Up @@ -1079,14 +1088,48 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
TI_ASSERT(!stmt->parent->mask() || stmt->width() == 1);
TI_ASSERT(llvm_val[stmt->data]);
TI_ASSERT(llvm_val[stmt->ptr]);
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]);
if (auto cit = stmt->ptr->ret_type.ptr_removed()->cast<CustomIntType>()) {
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
builder->CreateCall(
get_runtime_function("set_partial_bits_b32"),
{builder->CreateBitCast(byte_ptr,
llvm::Type::getInt32PtrTy(*llvm_context)),
bit_offset, tlctx->get_constant(cit->get_num_bits()),
llvm_val[stmt->data]});
} else {
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]);
}
}

void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
int width = stmt->width();
TI_ASSERT(width == 1);
llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type),
llvm_val[stmt->ptr]);
if (auto cit = stmt->ret_type->cast<CustomIntType>()) {
// 1. load bit pointer
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
auto bit_level_container = builder->CreateLoad(builder->CreateBitCast(
byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context)));
// 2. bit shifting
// first left shift `32 - (offset + num_bits)`
// then right shift `32 - num_bits`
auto bit_end = builder->CreateAdd(bit_offset,
tlctx->get_constant(cit->get_num_bits()));
auto left = builder->CreateSub(tlctx->get_constant(32), bit_end);
auto right = builder->CreateAdd(tlctx->get_constant(32),
tlctx->get_constant(-cit->get_num_bits()));
auto step1 = builder->CreateShl(bit_level_container, left);
llvm::Value *step2 = nullptr;
if (cit->get_is_signed())
step2 = builder->CreateAShr(step1, right);
else
step2 = builder->CreateLShr(step1, right);
llvm_val[stmt] = step2;
} else {
llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type),
llvm_val[stmt->ptr]);
}
}

void CodeGenLLVM::visit(ElementShuffleStmt *stmt){
Expand Down Expand Up @@ -1120,6 +1163,8 @@ std::string CodeGenLLVM::get_runtime_snode_name(SNode *snode) {
return "Hash";
} else if (snode->type == SNodeType::bitmasked) {
return "Bitmasked";
} else if (snode->type == SNodeType::bit_struct) {
return "BitStruct";
} else {
TI_P(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
Expand Down Expand Up @@ -1204,21 +1249,54 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
}
llvm_val[stmt] = call(snode, llvm_val[stmt->input_snode], "lookup_element",
{llvm_val[stmt->input_index]});
} else if (snode->type == SNodeType::bit_struct) {
llvm_val[stmt] = parent;
} else {
TI_INFO(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
}
}

void CodeGenLLVM::visit(GetChStmt *stmt) {
auto ch = create_call(
stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(llvm_val[stmt->input_ptr],
if (stmt->output_snode->is_bit_level) {
// 1. create bit pointer struct
// struct bit_pointer {
// i8* byte_ptr;
// i32 offset;
// };
auto struct_type = llvm::StructType::get(
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
// 2. alloca the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);

// 3. store `input_ptr` into `bit_ptr_struct`
auto byte_ptr =
builder->CreateBitCast(llvm_val[stmt->input_ptr],
llvm::PointerType::getInt8PtrTy(*llvm_context));

builder->CreateStore(
byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(0)}));
// 4. store `offset` in `bit_ptr_struct`
auto bit_struct = stmt->input_snode->dt.get_ptr()->cast<BitStructType>();
auto offset = bit_struct->get_member_bit_offset(
stmt->input_snode->child_id(stmt->output_snode));
builder->CreateStore(
tlctx->get_constant(offset),
builder->CreateGEP(bit_ptr_struct,
{tlctx->get_constant(0), tlctx->get_constant(1)}));
llvm_val[stmt] = bit_ptr_struct;
} else {
auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(
llvm_val[stmt->input_ptr],
llvm::PointerType::getInt8PtrTy(*llvm_context))});
llvm_val[stmt] = builder->CreateBitCast(
ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
llvm_val[stmt] = builder->CreateBitCast(
ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
}
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ std::string data_type_format(DataType dt) {
return "%f";
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return "%.12f";
} else if (dt->is<CustomIntType>()) {
return "%d";
} else {
TI_NOT_IMPLEMENTED
}
Expand Down
4 changes: 3 additions & 1 deletion taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ inline bool is_integral(DataType dt) {
dt->is_primitive(PrimitiveTypeID::u8) ||
dt->is_primitive(PrimitiveTypeID::u16) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
dt->is_primitive(PrimitiveTypeID::u64);
dt->is_primitive(PrimitiveTypeID::u64) || dt->is<CustomIntType>();
}

inline bool is_signed(DataType dt) {
TI_ASSERT(is_integral(dt));
if (auto t = dt->cast<CustomIntType>())
return t->get_is_signed();
return dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::i32) ||
Expand Down
16 changes: 16 additions & 0 deletions taichi/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ class LLVMModuleBuilder {
llvm::Value *call(const std::string &func_name, Args &&... args) {
return call(this->builder.get(), func_name, std::forward<Args>(args)...);
}

void read_bit_pointer(llvm::Value *ptr,
llvm::Value *&byte_ptr,
llvm::Value *&bit_offset) {
// 1. load byte pointer
auto byte_ptr_in_bit_struct = builder->CreateGEP(
ptr, {tlctx->get_constant(0), tlctx->get_constant(0)});
byte_ptr = builder->CreateLoad(byte_ptr_in_bit_struct);
TI_ASSERT(byte_ptr->getType()->getPointerElementType()->isIntegerTy(8));

// 2. load bit offset
auto bit_offset_in_bit_struct = builder->CreateGEP(
ptr, {tlctx->get_constant(0), tlctx->get_constant(1)});
bit_offset = builder->CreateLoad(bit_offset_in_bit_struct);
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
}
};

class RuntimeObject {
Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ CompileConfig::CompileConfig() {
advanced_optimization = true;
max_vector_width = 8;
debug = false;
cfg_optimization = true;
check_out_of_bound = false;
lazy_compilation = true;
serial_schedule = false;
Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ TLANG_NAMESPACE_BEGIN
struct CompileConfig {
Arch arch;
bool debug;
bool cfg_optimization;
bool check_out_of_bound;
int simd_width;
bool lazy_compilation;
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ void export_lang(py::module &m) {
.def_readwrite("arch", &CompileConfig::arch)
.def_readwrite("print_ir", &CompileConfig::print_ir)
.def_readwrite("debug", &CompileConfig::debug)
.def_readwrite("cfg_optimization", &CompileConfig::cfg_optimization)
.def_readwrite("check_out_of_bound", &CompileConfig::check_out_of_bound)
.def_readwrite("print_accessor_ir", &CompileConfig::print_accessor_ir)
.def_readwrite("print_evaluator_ir", &CompileConfig::print_evaluator_ir)
Expand Down
12 changes: 12 additions & 0 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,18 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
}

#include "internal_functions.h"

void set_partial_bits_b32(u32 *ptr, u32 offset, u32 bits, u32 value) {
u32 mask = ((((u32)1 << bits) - 1) << offset);
u32 new_value = 0;
u32 old_value = *ptr;
do {
old_value = *ptr;
new_value = (old_value & (~mask)) | (value << offset);
} while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true,
std::memory_order::memory_order_seq_cst,
std::memory_order::memory_order_seq_cst));
}
}

#endif
8 changes: 5 additions & 3 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ void compile_to_offloads(IRNode *ir,
print("Offloaded");
irpass::analysis::verify(ir);

irpass::cfg_optimization(ir, false);
print("Optimized by CFG");
irpass::analysis::verify(ir);
if (config.cfg_optimization) {
irpass::cfg_optimization(ir, false);
print("Optimized by CFG");
irpass::analysis::verify(ir);
}

irpass::flag_access(ir);
print("Access flagged II");
Expand Down
16 changes: 15 additions & 1 deletion taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class TypeCheck : public IRVisitor {
}

void visit(GlobalStoreStmt *stmt) {
if (stmt->ptr->ret_type.ptr_removed()->is<CustomIntType>()) {
return;
}
auto promoted =
promoted_type(stmt->ptr->ret_type.ptr_removed(), stmt->data->ret_type);
auto input_type = stmt->data->ret_data_type_name();
Expand Down Expand Up @@ -264,9 +267,20 @@ class TypeCheck : public IRVisitor {
}

if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
auto promote_custom_int_type = [&](Stmt *stmt, Stmt *hs) {
if (hs->ret_type->is<CustomIntType>()) {
if (hs->ret_type->cast<CustomIntType>()->get_is_signed())
return insert_type_cast_before(stmt, hs, get_data_type<int32>());
else
return insert_type_cast_before(stmt, hs, get_data_type<uint32>());
}
return hs;
};
stmt->lhs = promote_custom_int_type(stmt, stmt->lhs);
stmt->rhs = promote_custom_int_type(stmt, stmt->rhs);
auto ret_type = promoted_type(stmt->lhs->ret_type, stmt->rhs->ret_type);
if (ret_type != stmt->lhs->ret_type) {
// promote rhs
// promote lhs
auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type);
stmt->lhs = cast_stmt;
}
Expand Down
68 changes: 68 additions & 0 deletions tests/python/test_bit_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import taichi as ti
import numpy as np


def test_simple_array():
ti.init(arch=ti.cpu, debug=True, print_ir=True, cfg_optimization=False)
ci13 = ti.type_factory_.get_custom_int_type(13, True)
cu19 = ti.type_factory_.get_custom_int_type(19, False)

x = ti.field(dtype=ci13)
y = ti.field(dtype=cu19)

N = 12

ti.root.dense(ti.i, N)._bit_struct(num_bits=32).place(x, y)

ti.get_runtime().materialize()

@ti.kernel
def set_val():
for i in range(N):
x[i] = -2**i
y[i] = 2**i - 1

@ti.kernel
def verify_val():
for i in range(N):
assert x[i] == -2**i
assert y[i] == 2**i - 1

set_val()
verify_val()


def test_custom_int_load_and_store():
ti.init(arch=ti.cpu, debug=True, print_ir=True, cfg_optimization=False)
ci13 = ti.type_factory_.get_custom_int_type(13, True)
cu14 = ti.type_factory_.get_custom_int_type(14, False)
ci5 = ti.type_factory_.get_custom_int_type(5, True)

x = ti.field(dtype=ci13)
y = ti.field(dtype=cu14)
z = ti.field(dtype=ci5)

test_case_np = np.array(
[[2**12 - 1, 2**14 - 1, -(2**3)], [2**11 - 1, 2**13 - 1, -(2**2)],
[0, 0, 0], [123, 4567, 8], [10, 31, 11]],
dtype=np.int32)

ti.root._bit_struct(num_bits=32).place(x, y, z)
test_case = ti.Vector.field(3, dtype=ti.i32, shape=len(test_case_np))
test_case.from_numpy(test_case_np)

@ti.kernel
def set_val(idx: ti.i32):
x[None] = test_case[idx][0]
y[None] = test_case[idx][1]
z[None] = test_case[idx][2]

@ti.kernel
def verify_val(idx: ti.i32):
assert x[None] == test_case[idx][0]
assert y[None] == test_case[idx][1]
assert z[None] == test_case[idx][2]

for idx in range(len(test_case_np)):
set_val(idx)
verify_val(idx)

0 comments on commit db7d088

Please sign in to comment.